Custom Python Operators#

Created On: Jun 18, 2024 | Last Updated: Mar 19, 2025 | Last Verified: Nov 05, 2024

What you will learn
  • How to integrate custom operators written in Python with PyTorch

  • How to test custom operators using torch.library.opcheck

Prerequisites
  • PyTorch 2.4 or later

PyTorch offers a large library of operators that work on Tensors (e.g. torch.add, torch.sum, etc.). However, you might wish to use a new customized operator with PyTorch, perhaps written by a third-party library. This tutorial shows how to wrap Python functions so that they behave like PyTorch native operators. Reasons why you may wish to create a custom operator in PyTorch include:

  • Treating an arbitrary Python function as an opaque callable with respect to torch.compile (that is, preventing torch.compile from tracing into the function).

  • Adding training support to an arbitrary Python function.

Use torch.library.custom_op() to create Python custom operators. Use the C++ TORCH_LIBRARY APIs to create C++ custom operators (these work in Python-less environments). See the Custom Operators Landing Page for more details.

Please note that if your operation can be expressed as a composition of existing PyTorch operators, then there is usually no need to use the custom operator API; everything (for example torch.compile, training support) should just work.
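As a rough illustration of that point, a crop on a (C, H, W) tensor can be written with ordinary indexing, in which case autograd works with no extra registration. The helper name `crop_by_slicing` is hypothetical and not part of the tutorial; the box order (left, upper, right, lower) is assumed to match the convention PIL's crop uses.

```python
import torch


def crop_by_slicing(pic: torch.Tensor, box) -> torch.Tensor:
    """Hypothetical helper: crop a (C, H, W) tensor with plain indexing.

    ``box`` is (left, upper, right, lower).
    """
    x0, y0, x1, y1 = box
    return pic[:, y0:y1, x0:x1]


img = torch.rand(3, 64, 64, requires_grad=True)
out = crop_by_slicing(img, (10, 10, 50, 50))

# Autograd already knows how to differentiate indexing, so no
# register_autograd call is needed for this version.
out.sum().backward()
print(out.shape, img.grad.shape)
```

Because the function is built only from existing PyTorch operations, torch.compile should also be able to trace it directly. The custom operator API becomes relevant when the work is done by code PyTorch cannot see into, such as a third-party library.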

Example: Wrapping PIL’s crop into a custom operator#

Let’s say that we are using PIL’s crop operation.

    import torch
    from torchvision.transforms.functional import to_pil_image, pil_to_tensor
    import PIL
    import IPython
    import matplotlib.pyplot as plt

    def crop(pic, box):
        img = to_pil_image(pic.cpu())
        cropped_img = img.crop(box)
        return pil_to_tensor(cropped_img).to(pic.device) / 255.

    def display(img):
        plt.imshow(img.numpy().transpose((1, 2, 0)))

    img = torch.ones(3, 64, 64)
    img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
    display(img)

    cropped_img = crop(img, (10, 10, 50, 50))
    display(cropped_img)

crop is not handled effectively out-of-the-box by torch.compile: torch.compile induces a “graph break” on functions it is unable to handle, and graph breaks are bad for performance. The following code demonstrates this by raising an error (torch.compile with fullgraph=True raises an error if a graph break occurs).

    @torch.compile(fullgraph=True)
    def f(img):
        return crop(img, (10, 10, 50, 50))

    # The following raises an error. Uncomment the line to see it.
    # cropped_img = f(img)

In order to black-box crop for use with torch.compile, we need to do two things:

  1. wrap the function into a PyTorch custom operator.

  2. add a “FakeTensor kernel” (aka “meta kernel”) to the operator. Given some FakeTensor inputs (dummy Tensors that don’t have storage), this function should return dummy Tensors of your choice with the correct Tensor metadata (shape/strides/dtype/device).

    from typing import Sequence

    # Use torch.library.custom_op to define a new custom operator.
    # If your operator mutates any input Tensors, their names must be specified
    # in the ``mutates_args`` argument.
    @torch.library.custom_op("mylib::crop", mutates_args=())
    def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
        img = to_pil_image(pic.cpu())
        cropped_img = img.crop(box)
        return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

    # Use register_fake to add a ``FakeTensor`` kernel for the operator
    @crop.register_fake
    def _(pic, box):
        channels = pic.shape[0]
        x0, y0, x1, y1 = box
        result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1)
        # The result should have the same metadata (shape/strides/``dtype``/device)
        # as running the ``crop`` function above.
        return result
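The FakeTensor kernel above allocates an (H, W, C) buffer and permutes it to (C, H, W). The shape and strides this produces can be worked out by hand; the sketch below does that arithmetic in plain Python, without PyTorch. The function name `crop_fake_metadata` is hypothetical, and strides are counted in elements, as PyTorch reports them.

```python
from typing import Sequence, Tuple


def crop_fake_metadata(
    pic_shape: Sequence[int], box: Sequence[int]
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]:
    """Return (shape, strides) of new_empty(H, W, C).permute(2, 0, 1)."""
    channels = pic_shape[0]
    x0, y0, x1, y1 = box  # (left, upper, right, lower)
    height, width = y1 - y0, x1 - x0
    # A contiguous (H, W, C) buffer has strides (W * C, C, 1).
    # permute(2, 0, 1) reorders them to follow the (C, H, W) view.
    shape = (channels, height, width)
    strides = (1, width * channels, channels)
    return shape, strides


shape, strides = crop_fake_metadata((3, 64, 64), (10, 10, 50, 50))
print(shape, strides)
```

The strides are not those of a contiguous (C, H, W) tensor, which is why the kernel permutes rather than calling new_empty(channels, height, width) directly.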

With the operator and its FakeTensor kernel registered, crop now works without graph breaks:

    @torch.compile(fullgraph=True)
    def f(img):
        return crop(img, (10, 10, 50, 50))

    cropped_img = f(img)
    display(img)

    display(cropped_img)

Adding training support for crop#

Use torch.library.register_autograd to add training support for an operator. Prefer this over directly using torch.autograd.Function; some compositions of autograd.Function with PyTorch operator registration APIs can lead to (and have led to) silent incorrectness when composed with torch.compile.

If you don’t need training support, there is no need to use torch.library.register_autograd. If you end up training with a custom_op that doesn’t have an autograd registration, we’ll raise an error message.

The gradient formula for crop is essentially PIL.paste (we’ll leave the derivation as an exercise to the reader). Let’s first wrap paste into a custom operator:

    @torch.library.custom_op("mylib::paste", mutates_args=())
    def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
        assert im1.device == im2.device
        assert im1.dtype == im2.dtype
        im1_pil = to_pil_image(im1.cpu())
        im2_pil = to_pil_image(im2.cpu())
        PIL.Image.Image.paste(im1_pil, im2_pil, coord)
        return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

    @paste.register_fake
    def _(im1, im2, coord):
        assert im1.device == im2.device
        assert im1.dtype == im2.dtype
        return torch.empty_like(im1)

And now let’s use register_autograd to specify the gradient formula for crop:

    def backward(ctx, grad_output):
        grad_input = grad_output.new_zeros(ctx.pic_shape)
        grad_input = paste(grad_input, grad_output, ctx.coords)
        return grad_input, None

    def setup_context(ctx, inputs, output):
        pic, box = inputs
        ctx.coords = box[:2]
        ctx.pic_shape = pic.shape

    crop.register_autograd(backward, setup_context=setup_context)

Note that the backward must be a composition of PyTorch-understood operators, which is why we wrapped paste into a custom operator instead of directly using PIL’s paste.

    img = img.requires_grad_()
    result = crop(img, (10, 10, 50, 50))
    result.sum().backward()
    display(img.grad)

This is the correct gradient, with 1s (white) in the cropped region and 0s (black) in the unused region.
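The same pattern can be reproduced on a tiny single-channel case in plain Python, which also serves as an informal check of the claim that crop's gradient is a paste into zeros. This is only a sketch on nested lists; `crop2d`, `paste2d`, `crop2d_backward` and `loss` are hypothetical helpers, not PIL or PyTorch functions.

```python
def crop2d(img, box):
    """Crop a 2-D list of rows; box is (left, upper, right, lower)."""
    x0, y0, x1, y1 = box
    return [row[x0:x1] for row in img[y0:y1]]


def paste2d(base, patch, coord):
    """Return a copy of base with patch written at (left, upper) = coord."""
    x0, y0 = coord
    out = [row[:] for row in base]
    for dy, patch_row in enumerate(patch):
        out[y0 + dy][x0:x0 + len(patch_row)] = patch_row
    return out


def crop2d_backward(grad_output, img_height, img_width, box):
    """Gradient of crop2d: paste grad_output into zeros of the input size."""
    zeros = [[0.0] * img_width for _ in range(img_height)]
    return paste2d(zeros, grad_output, box[:2])


def loss(img, box):
    """Scalar used for the check: sum of the cropped pixels."""
    return sum(sum(row) for row in crop2d(img, box))


H, W, box = 4, 5, (1, 1, 4, 3)
img = [[float(r * W + c) for c in range(W)] for r in range(H)]

# d(loss)/d(crop output) is all ones for a plain sum.
grad_out = [[1.0] * (box[2] - box[0]) for _ in range(box[3] - box[1])]
analytic = crop2d_backward(grad_out, H, W, box)

# Central finite differences on every input pixel, for comparison.
eps = 1e-3
numeric = [[0.0] * W for _ in range(H)]
for r in range(H):
    for c in range(W):
        plus = [row[:] for row in img]
        minus = [row[:] for row in img]
        plus[r][c] += eps
        minus[r][c] -= eps
        numeric[r][c] = (loss(plus, box) - loss(minus, box)) / (2 * eps)

for row in analytic:
    print(row)
```

Pixels inside the box each contribute once to the sum, so their gradient is 1; pixels outside never reach the output, so their gradient is 0.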

Testing Python Custom operators#

Use torch.library.opcheck to test that the custom operator was registered correctly. This does not test that the gradients are mathematically correct; please write separate tests for that (either manual ones or torch.autograd.gradcheck).

To use opcheck, pass it a set of example inputs to test against. If your operator supports training, then the examples should include Tensors that require grad. If your operator supports multiple devices, then the examples should include Tensors from each device.

    examples = [
        [torch.randn(3, 64, 64), [0, 0, 10, 10]],
        [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
        [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
        [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
    ]

    for example in examples:
        torch.library.opcheck(crop, example)
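Since opcheck does not verify gradient values, a separate numerical test is still needed. The sketch below shows the general shape of a torch.autograd.gradcheck call. It uses a hypothetical slicing-based stand-in, `crop_by_slicing`, rather than the PIL-backed operator, on the assumption that round-tripping through an 8-bit PIL image would likely defeat the small finite-difference perturbations gradcheck relies on.

```python
import torch


def crop_by_slicing(pic: torch.Tensor, box) -> torch.Tensor:
    """Hypothetical stand-in operator: crop (C, H, W) by indexing."""
    x0, y0, x1, y1 = box
    return pic[:, y0:y1, x0:x1]


# gradcheck compares analytical and finite-difference gradients, so it
# expects double-precision inputs that require grad.
pic = torch.randn(3, 8, 8, dtype=torch.double, requires_grad=True)
box = (2, 1, 6, 5)

ok = torch.autograd.gradcheck(lambda t: crop_by_slicing(t, box), (pic,))
print("gradcheck passed:", ok)
```

For an operator whose forward pass quantizes its input, a hand-written test that compares the backward result against a known expected gradient may be the more practical option.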

Mutable Python Custom operators#

You can also wrap a Python function that mutates its inputs into a custom operator. Functions that mutate inputs are common because that is how many low-level kernels are written; for example, a kernel that computes sin may take in the input and an output tensor and write input.sin() to the output tensor.
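PyTorch's own operators expose the same calling pattern through their out= argument, which may help make the idea concrete. This small sketch uses only the built-in torch.sin; the variable names are arbitrary.

```python
import torch

x = torch.randn(4)
out = torch.empty(4)

# The out= form writes the result into an existing tensor instead of
# allocating a new one, so ``out`` is mutated in place.
result = torch.sin(x, out=out)

print(torch.allclose(out, x.sin()))
```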

We’ll use numpy.sin to demonstrate an example of a mutable Python custom operator.

    import numpy as np

    @torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
    def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
        assert input.device == output.device
        assert input.device.type == "cpu"
        input_np = input.numpy()
        output_np = output.numpy()
        np.sin(input_np, out=output_np)

Because the operator doesn’t return anything, there is no need to register a FakeTensor kernel (meta kernel) to get it to work with torch.compile.

    @torch.compile(fullgraph=True)
    def f(x):
        out = torch.empty(3)
        numpy_sin(x, out)
        return out

    x = torch.randn(3)
    y = f(x)
    assert torch.allclose(y, x.sin())

And here’s an opcheck run telling us that we did indeed register the operator correctly. opcheck would error out if we forgot to add the output to mutates_args, for example.

    example_inputs = [
        [torch.randn(3), torch.empty(3)],
        [torch.randn(0, 3), torch.empty(0, 3)],
        [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
    ]

    for example in example_inputs:
        torch.library.opcheck(numpy_sin, example)

Conclusion#

In this tutorial, we learned how to use torch.library.custom_op to create a custom operator in Python that works with PyTorch subsystems such as torch.compile and autograd.

This tutorial provides a basic introduction to custom operators. For more detailed information, see the Custom Operators Landing Page.
