Extending dispatcher for a new backend in C++
Created On: Feb 01, 2021 | Last Updated: Sep 23, 2024 | Last Verified: Nov 05, 2024
In this tutorial we will walk through all necessary steps to extend the dispatcher to add a new device living outside the pytorch/pytorch repo and maintain it to keep in sync with native PyTorch devices. Here we’ll assume that you’re familiar with how to register a dispatched operator in C++ and how to write a custom autograd function.
Note
This tutorial touches a lot of internal components inside PyTorch which are being actively improved, so please expect API changes if you decide to follow this tutorial. We’ll keep this tutorial up to date with the latest APIs.
What’s a new backend?
Adding a new backend to PyTorch requires a lot of development and maintenance from backend extenders. Before adding a new backend, let’s first consider a few common use cases and recommended solutions for them:
If you have new algorithms for an existing PyTorch operator, send a PR to PyTorch.
If you want to propose a new operator, send a feature request/PR to PyTorch.
If you want to add support for a new device/hardware like Google TPU and customized chips, which often requires using hardware-specific APIs to write kernels, follow this tutorial and add an out-of-tree backend to PyTorch.
If you want to add support for existing operators but with a different Tensor layout/representation like sparse and quantized, which requires your kernels to be written in a way that’s more efficient given the layout/representation limitation, follow this tutorial and add an out-of-tree backend to PyTorch.
In this tutorial we’ll mainly focus on adding a new out-of-tree device below. Adding out-of-tree support for a different tensor layout might share many common steps with devices, but we haven’t seen an example of such an integration yet, so it might require additional work from PyTorch to support it.
Get a dispatch key for your backend
PyTorch operators are implemented in C++ and made available in the Python frontend through Python bindings. The PyTorch dispatcher divides the implementation of an operator into multiple kernels, each of which is associated with a specific dispatch key. Supporting a new backend in PyTorch essentially means writing a kernel for each PyTorch operator in C++ and then registering them to a dispatch key representing your customized backend in the dispatcher.
A dispatch key is your identifier in the dispatcher system. The dispatcher looks at the dispatch keys carried on input tensors and calls the right kernel accordingly. PyTorch provides three reserved dispatch keys (and their corresponding Autograd keys) for prototyping out-of-tree backend extensions:
PrivateUse1/AutogradPrivateUse1
PrivateUse2/AutogradPrivateUse2
PrivateUse3/AutogradPrivateUse3
You can choose any of the keys above to prototype your customized backend. To create a Tensor on the PrivateUse1 backend, you need to set the dispatch key in the TensorImpl constructor.
/* Example TensorImpl constructor */
TensorImpl(
    Storage&& storage,
    DispatchKeySet ks,
    const caffe2::TypeMeta data_type);

// To create a TensorImpl on PrivateUse1 backend, pass in the following ks to TensorImpl creation.
DispatchKeySet ks = c10::DispatchKeySet{c10::DispatchKey::PrivateUse1, c10::DispatchKey::AutogradPrivateUse1};
Note that the TensorImpl class above assumes your Tensor is backed by a storage like CPU/CUDA. We also provide OpaqueTensorImpl for backends without a storage. And you might need to tweak/override certain methods to fit your customized hardware. One example in the pytorch repo is Vulkan TensorImpl.
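As a concrete illustration, here is a minimal sketch of a tensor factory that tags a storage-backed TensorImpl with the PrivateUse1 keys. The function name is hypothetical and the CPU allocator is only a stand-in; a real backend would allocate memory through its own device allocator:

#include <ATen/ATen.h>
#include <c10/core/CPUAllocator.h>

// Minimal sketch (not part of PyTorch's API): build a tensor carrying PrivateUse1 keys.
at::Tensor make_privateuse1_tensor(at::IntArrayRef sizes, caffe2::TypeMeta dtype) {
  c10::DispatchKeySet ks{
      c10::DispatchKey::PrivateUse1,
      c10::DispatchKey::AutogradPrivateUse1};
  int64_t nelements = 1;
  for (auto s : sizes) {
    nelements *= s;
  }
  c10::Storage storage(
      c10::Storage::use_byte_size_t(),
      nelements * dtype.itemsize(),
      c10::GetCPUAllocator(),  // placeholder allocator; use your device allocator instead
      /*resizable=*/true);
  auto tensor = at::detail::make_tensor<c10::TensorImpl>(std::move(storage), ks, dtype);
  tensor.unsafeGetTensorImpl()->set_sizes_contiguous(sizes);
  return tensor;
}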
Note
Once the prototype is done and you plan to do regular releases for your backend extension, please feel free to submit a PR to pytorch/pytorch to reserve a dedicated dispatch key for your backend.
Get the full list of PyTorch operators
PyTorch provides a full list of extensible C++ operators in the generated file build/aten/src/ATen/RegistrationDeclarations.h. This file is only available after building PyTorch from source. Here’s a snippet of the file:
Tensor abs(const Tensor& self); // {"schema": "aten::abs(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
Tensor& abs_(Tensor& self); // {"schema": "aten::abs_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
Tensor& abs_out(Tensor& out, const Tensor& self); // {"schema": "aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor absolute(const Tensor& self); // {"schema": "aten::absolute(Tensor self) -> Tensor", "dispatch": "False", "default": "False"}
Tensor& absolute_(Tensor& self); // {"schema": "aten::absolute_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "False"}
Tensor& absolute_out(Tensor& out, const Tensor& self); // {"schema": "aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "False"}
Tensor angle(const Tensor& self); // {"schema": "aten::angle(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
Tensor& angle_out(Tensor& out, const Tensor& self); // {"schema": "aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor sgn(const Tensor& self); // {"schema": "aten::sgn(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
There are multiple fields associated with a single operator. Let’s break it down using abs_out as an example:
Tensor& abs_out(Tensor& out, const Tensor& self); is the C++ signature of the operator, your C++ kernel should match this signature exactly.
aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) is the unique schema representing the operator, which also contains aliasing and mutation annotations compared to the C++ signature. This is the unique identifier the dispatcher uses to find an operator.
dispatch and default are boolean fields that provide information about what native PyTorch kernels can do, thus implying whether it’s required for backend extenders to implement the kernel. More details can be found in register kernels for the new backend.
Register kernels for the new backend
To register your kernels to the PyTorch dispatcher, you can use the TORCH_LIBRARY_IMPL API described in Registering a Dispatched Operator in C++:
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(<schema_my_op1>, &my_op1);
  m.impl(<schema_my_op2>, &my_op2);
  m.impl(<schema_my_op2_backward>, &my_op2_backward);
}
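For a concrete, purely illustrative example, a backend kernel for the abs.out operator from the snippet above could be registered as follows; the kernel body is a placeholder for your hardware-specific call:

// The signature must match the C++ declaration in RegistrationDeclarations.h exactly.
at::Tensor& abs_out_privateuse1(at::Tensor& out, const at::Tensor& self) {
  // Call your hardware-specific API here to compute |self| into `out`.
  return out;
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  // "abs.out" is the name.overload_name portion of the schema
  // "aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)".
  m.impl("abs.out", &abs_out_privateuse1);
}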
Now let’s zoom in on which operators require a kernel from a customized backend and what’s inside those kernels exactly.
PyTorch currently has more than 1600 operators and it’s still growing. It’s unrealistic for backend extensions to keep up with this speed. Even for native backends like CPU or CUDA, it often requires a lot of work to write dedicated kernels for every new op.
Fortunately, some native PyTorch kernels are written in a way that they decompose into a combination of several known operators. In other words, you only need to implement a set of known operators (ops that require registration below) instead of all PyTorch operators.
PyTorch operators can be classified into two categories:
Ops that require registration: the PyTorch native implementation for these ops is backend specific, so it’s required to provide a kernel for your customized backend. Otherwise calling such an op on the customized backend will error out. In RegistrationDeclarations.h these operators have dispatch set to True and default set to False in the metadata found in their accompanying comments.
Registration is optional: backend extenders can skip registering these ops without sacrificing any support. However, if a backend extender wants to override the default kernel provided by PyTorch, they can still register their customized kernel to their backend and the dispatcher will use it for that backend only. For example, the current implementation of PyTorch’s max_pool2d returns indices as part of the forward outputs, which creates overhead in torch_xla, so torch_xla registers its own kernel for max_pool2d instead (a sketch of such an optional override is shown after this list). In RegistrationDeclarations.h these operators have dispatch set to False or default set to True in the metadata found in their accompanying comments.
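For illustration, an optional override of such a default kernel could look like the sketch below; my_backend_max_pool2d is a hypothetical hardware-specific helper, not a PyTorch API:

// Optional override: aten::max_pool2d already has a default kernel that works on
// any backend, but registering a PrivateUse1 kernel lets you skip its decomposition.
at::Tensor max_pool2d_privateuse1(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode) {
  // Hypothetical hardware-specific pooling call; avoids materializing indices.
  return my_backend_max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode);
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("max_pool2d", &max_pool2d_privateuse1);
}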
Autograd support for the new backend
Gradient formulas are mostly purely mathematical and thus are general for all backends. PyTorch often registers a kernel to the alias dispatch key Autograd, which means it can be used by all backends.
For these operators you don’t have to worry about their derivative formulas, you can just write forward definitions for operators in RegistrationDeclarations.h and PyTorch handles backward for you automatically.
Tensor my_op1(const Tensor& self, const Tensor& other) {
  // call your backend-specific APIs to implement my_op1 so that
  // it matches PyTorch's native behavior
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(<schema_my_op1>, &my_op1);
}
In some cases, PyTorch backward kernel implementations are also device specific so that they can squeeze max performance out of each backend. For those operators you’ll see op_backward showing up in RegistrationDeclarations.h as required registration as well.
Tensor my_op2_backward(const Tensor& self, const Tensor& other) {
  // call your backend-specific APIs to implement my_op2_backward so that
  // it matches PyTorch's native behavior
}

// Note backward kernel is still registered to PrivateUse1 instead of AutogradPrivateUse1.
// PyTorch will wrap your backward kernel with proper autograd setup and then link to it in
// my_op2's AutogradPrivateUse1 kernel.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(<schema_my_op2>, &my_op2);
  m.impl(<schema_my_op2_backward>, &my_op2_backward);
}
In a few rare cases, PyTorch’s gradient formula for certain operators may have assumptions that don’t generalize for all backends. In those cases backend extenders can optionally override the PyTorch Autograd layer by registering a kernel from torch::autograd::Function to the corresponding dispatch key (for example, AutogradPrivateUse1 if you’re using PrivateUse1 for your backend):
class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
 public:
  static Tensor forward(AutogradContext* ctx, torch::Tensor self, torch::Tensor other) {
    at::AutoNonVariableTypeMode g;
    return myadd(self, other);
  }

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
    auto grad_output = grad_outputs[0];
    return {grad_output, grad_output};
  }
};

Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
  return MyAddFunction::apply(self, other);
}

// Register the autograd kernel to AutogradPrivateUse1
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

// Register the inference kernel to PrivateUse1
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd);
}
With this trick you have full control over both training and inference behavior for the myadd operator in your backend. Here’s an example in the pytorch/xla repository.
Build an extension
An out-of-tree backend is supported by adding a C++ extension to PyTorch. Once you have kernels and registrations ready, you can build a C++ extension by writing a setup.py script that uses setuptools to compile C++ code. Here’s a simplified example from the pytorch/xla repo:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='torch_xla',
    ext_modules=[
        CppExtension(
            '_XLAC',
            torch_xla_sources,
            include_dirs=include_dirs,
            extra_compile_args=extra_compile_args,
            library_dirs=library_dirs,
            extra_link_args=extra_link_args + \
                [make_relative_rpath('torch_xla/lib')],
        ),
    ],
    cmdclass={
        'build_ext': Build,  # Build is a derived class of BuildExtension
    }
    # more configs...
)
See our C++ extension tutorial for more details.
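The shared library built this way still needs a Python module entry point. A minimal sketch might look like the following; the is_available helper is purely illustrative:

// Minimal module entry point for the C++ extension. Importing the module runs the
// static TORCH_LIBRARY_IMPL registrations linked into it; extra helpers are optional.
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("is_available", []() { return true; });  // illustrative helper
}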
Custom operator support
Your new backend should work seamlessly with customized operators extended in Python without writing any new kernels, as long as the customized operator is composed of existing PyTorch operators (which are already supported by your backend).
Custom operators extended in C++ often come with a backend-specific C++ kernel implementation, e.g. the nms kernel in torchvision, as well as a customized Python API, e.g. torch.ops.torchvision.nms. To support these operators, backend extenders need to write a C++ kernel for their backend and properly register it to the corresponding namespace in the dispatcher, similar to supporting PyTorch native operators. Alternatively you could also add a customized API in your extension, e.g. torch_xla.core.functions.nms, for these ad hoc requests.
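As a sketch, supporting torchvision’s nms on your backend amounts to registering a kernel in the torchvision namespace for your dispatch key; the kernel body below is only a placeholder that errors out:

// The schema "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"
// is defined by torchvision; we only add a kernel for the PrivateUse1 dispatch key.
at::Tensor nms_privateuse1(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
  // Call your hardware-specific NMS implementation here.
  TORCH_CHECK(false, "nms is not yet implemented for this backend");
}

TORCH_LIBRARY_IMPL(torchvision, PrivateUse1, m) {
  m.impl("nms", &nms_privateuse1);
}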
JIT support
As we mentioned in Registering a Dispatched Operator in C++, kernels registered through the m.impl() API support being called in both unboxed and boxed ways. In other words, your customized backend can also work with our JIT tracing/scripting frontend just like in-tree backends like CPU or CUDA do. You could potentially also write specialized optimization passes for your backend on a JIT graph. But we will not discuss that here since we haven’t finalized the integration point in JIT, so the current backend support will focus on the eager frontend for now.
Testing your backend against native PyTorch backends
PyTorch lets tests run on multiple device types using its generic device type testing framework. You can find details about how tests use it and information about how to add a new device type. Once added, PyTorch tests using the generic device type testing framework will be run using your device type, too. See this Wiki page for an example of how tests are instantiated.
Running PyTorch’s existing test suites with your device type is important to ensure correctness, but not all PyTorch features are supported by every device type. The generic device type testing framework allows for considerable customization so that device types can select which tests to run, which dtypes they support, and even which precisions to use when comparing tensors for equality.
An example device type that uses the generic device type testing framework and doesn’t ship with PyTorch is XLA. See its extension of the generic device type testing framework, which contains examples of block listing tests, block listing dtypes, and overriding test precision.
The generic device type testing framework is actively developed. To request a feature please file anissue on PyTorch’s Github.
Backward Compatibility
Currently PyTorch can’t guarantee backward compatibility for registered operators. Operators, as well as their schemas, might be added/modified/deleted as needed. Registered kernels must be exactly the same as the PyTorch version. If PyTorch adds more parameters (even with defaults) for an operator, your old registration won’t work until it’s updated to match PyTorch’s new signature.
As a result, we highly recommend out-of-tree backend extenders only sync with major PyTorch releases to minimize interruptions in development. PyTorch is on a quarterly release cadence. Backend extenders should join the #announcement channel at pytorch.slack.com to get the latest updates on releases.
Known issues & additional notes
Not all test suites are device generic yet. Extensible test classes can be found by searching instantiate_device_type_tests in the PyTorch codebase, e.g. TestTorchDeviceType, TestViewOps, TestTensorDeviceOps, TestTypePromotion etc.
There’s no extension point in C++ for serializing a Python Tensor object on a customized backend. Currently you can only extend it by modifying the PyTorch Tensor __reduce_ex__ method or monkey patching in an out-of-tree repository.
If your backend doesn’t allow direct memory access, you should pay additional attention to supporting view ops since they’re supposed to share storage. Changes to a view tensor need to be propagated to its base tensor and vice versa.
There’s no extension point in C++ for Optimizer if your backend doesn’t work with the native PyTorch Optimizers, e.g. if it needs to carry the states to be updated in backward like torch-xla. Such use cases can currently only be supported by adding a customized API or monkey patching in an out-of-tree repository.
Future Work
Making every component in PyTorch extensible for an out-of-tree backend seamless requires a lot of changes to PyTorch internals. Here are a few items that we’re actively working on that might improve the experience in the future:
Improve test coverage of generic testing framework.
Improve Math kernel coverage and add more comprehensive tests to make sure Math kernel behavior matches other backends like CPU/CUDA.
Refactor RegistrationDeclarations.h to carry the minimal information and reuse PyTorch’s codegen as much as possible.
Support a backend fallback kernel to automatically convert inputs to CPU and convert the result back to the customized backend. This will allow “full” operator coverage even though you don’t have kernels written for every operator (see the sketch of registering a backend fallback after this list).
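PyTorch already exposes a boxed fallback registration that such a CPU fallback would build on. Below is a minimal sketch under that assumption; the fallback here just errors out, whereas a real CPU fallback would move inputs to CPU, redispatch, and copy results back to the device:

// Catch-all fallback for every operator without an explicit PrivateUse1 kernel.
void privateuse1_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK(false, "Operator ", op.schema().operator_name(),
              " is not yet implemented for the PrivateUse1 backend.");
}

TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&privateuse1_fallback>());
}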
Stay in touch
Please use PyTorch dev discussions for questions and discussions. If you have any feature requests or bug reports, please file an issue on GitHub.
If you’re interested in helping with any of the future work items above (e.g. adding more Math kernels for PyTorch operators in C++), please reach out to us through GitHub or Slack!