torch.mps.compile_shader
torch.mps.compile_shader(source)
Compiles a compute shader from source and lets you invoke the kernels defined in it directly from the Python runtime.

Example:
>>> lib = torch.mps.compile_shader(
...     "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
... )
>>> x = torch.zeros(16, device="mps")
>>> lib.full(x, 3.14)
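The following sketch is a plain-Python model of what each GPU thread does in the `full` kernel. It does not call Metal or the torch.mps API. `emulate_full` and its `grid_size` parameter are hypothetical names used only for illustration. The model assumes one thread per element of `x`, and each loop iteration stands in for the thread whose `[[thread_position_in_grid]]` index is `idx`. It uses `array("f")` so that stored values are rounded to 32-bit floats, as they are in the `device float*` buffer and in the float32 tensor created by `torch.zeros(16, device="mps")`.

```python
from array import array


def emulate_full(out: array, val: float, grid_size: int) -> None:
    """Hypothetical CPU model of the `full` Metal kernel (illustration only).

    Each iteration plays the role of one GPU thread with index `idx`
    (its [[thread_position_in_grid]]). Real GPU threads run in parallel;
    here they simply run one after another.
    """
    for idx in range(grid_size):
        out[idx] = val  # same body as the kernel: out[idx] = val;


# array("f") holds C floats, mirroring the float32 buffer behind `x`.
x = array("f", [0.0] * 16)
emulate_full(x, 3.14, grid_size=len(x))

print(x[0])          # 3.140000104904175, i.e. 3.14 rounded to float32
print(x[0] == 3.14)  # False: float32(3.14) differs from the Python float 3.14
```

The kernel has no bounds check, so the number of threads launched must not exceed the number of elements in `out`. In this Python model an overrun raises `IndexError`, but on the GPU an out-of-bounds write is not caught.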