torch.mps.compile_shader

torch.mps.compile_shader(source)

Compiles a compute shader from source and allows one to invoke the kernels defined there from the comfort of the Python runtime.

Example:

>>> import torch
>>> lib = torch.mps.compile_shader(
...     "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
... )
>>> x = torch.zeros(16, device="mps")
>>> lib.full(x, 3.14)
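The sketch below extends the example to a kernel with several tensor arguments (an element-wise vector add). It assumes a PyTorch build that provides torch.mps.compile_shader, that tensors bind to the kernel's buffer parameters in argument order, and that the launch size defaults to the element count of the first tensor argument, as the full example relies on. The names vector_add, shader_lib, and add_vectors are hypothetical, and the pure-Python fallback exists only so the snippet still runs on machines without an MPS device.

```python
from functools import lru_cache

try:
    import torch
except ImportError:  # lets the sketch run even where PyTorch is not installed
    torch = None

# Metal Shading Language source: each GPU thread writes one output element.
ADD_SOURCE = """
kernel void vector_add(device float* out,
                       device const float* a,
                       device const float* b,
                       uint idx [[thread_position_in_grid]]) {
    out[idx] = a[idx] + b[idx];
}
"""


def mps_available() -> bool:
    """True only if PyTorch is importable and an MPS device is usable."""
    return torch is not None and torch.backends.mps.is_available()


@lru_cache(maxsize=None)
def shader_lib():
    """Compile the Metal source once and reuse the resulting library object."""
    return torch.mps.compile_shader(ADD_SOURCE)


def add_vectors(a_vals, b_vals):
    """Add two equal-length lists of floats, on the GPU when MPS is available."""
    if len(a_vals) != len(b_vals):
        raise ValueError("inputs must have the same length")
    if mps_available() and a_vals:
        a = torch.tensor(a_vals, dtype=torch.float32, device="mps")
        b = torch.tensor(b_vals, dtype=torch.float32, device="mps")
        out = torch.empty_like(a)
        # Assumption: tensors bind to the kernel's buffers in order, and the
        # launch size defaults to out.numel(), i.e. one thread per element.
        shader_lib().vector_add(out, a, b)
        return out.cpu().tolist()
    # Fallback (no MPS device, or empty input): same arithmetic in Python.
    return [x + y for x, y in zip(a_vals, b_vals)]
```

For example, add_vectors([1.0, 2.0], [0.5, 0.25]) yields [1.5, 2.25] on either path. Caching the compiled library with lru_cache avoids recompiling the Metal source on every call. Note that the MPS path rounds inputs to float32, so results for values that are not exactly representable may differ slightly from the Python fallback.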