Facilitating New Backend Integration by PrivateUse1#
Created On: Oct 03, 2023 | Last Updated: May 07, 2024 | Last Verified: Nov 05, 2024
In this tutorial we will walk through the steps needed to integrate a new backend living outside the pytorch/pytorch repo via PrivateUse1. Note that this tutorial assumes that you already have a solid understanding of PyTorch and are an advanced user.
Note
This tutorial covers only the parts of the PrivateUse1 mechanism that facilitate the integration of new devices; other parts are not covered. Not all the modules discussed in this tutorial are required, so you can choose the ones that are helpful to you according to your actual needs.
What is PrivateUse1?#
Prior to PyTorch 2.0, PyTorch provided three reserved dispatch keys (and their corresponding Autograd keys) for prototyping out-of-tree backend extensions. The three dispatch keys are as follows:
PrivateUse1/AutogradPrivateUse1
PrivateUse2/AutogradPrivateUse2
PrivateUse3/AutogradPrivateUse3
After the prototype was verified, you could apply for a dedicated dispatch key for the new backend, like those of CUDA, XLA, MPS, and so on.
However, with the rapid development of PyTorch, more and more hardware manufacturers are trying to integrate their backends into PyTorch, which might cause the following problems:
Every new backend integration involves a lot of file modification.
There is currently a hard limit on the number of Dispatch Keys (the DispatchKeySet 64-bit limit).
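To see why a 64-bit key set caps the number of backends, consider a toy model in which every dispatch key occupies one bit of a 64-bit mask. This is only an illustrative sketch: the class `ToyKeySet` and its key numbering are hypothetical and much simpler than the actual `DispatchKeySet` layout in PyTorch.

```python
class ToyKeySet:
    """Hypothetical stand-in for a 64-bit dispatch key set (not PyTorch's real layout)."""

    MAX_KEYS = 64  # one bit per key in a 64-bit integer

    def __init__(self):
        self.bits = 0

    def add(self, key_index):
        # A key whose index does not fit in 64 bits cannot be represented at all.
        if not 0 <= key_index < self.MAX_KEYS:
            raise ValueError(f"key index {key_index} does not fit in {self.MAX_KEYS} bits")
        self.bits |= 1 << key_index

    def has(self, key_index):
        return 0 <= key_index < self.MAX_KEYS and bool((self.bits >> key_index) & 1)


keys = ToyKeySet()
keys.add(0)   # first slot
keys.add(63)  # last available slot
try:
    keys.add(64)  # a 65th key has nowhere to go
    overflowed = False
except ValueError:
    overflowed = True
```

Under this model, every additional in-tree backend consumes one of a fixed number of slots, which is why sharing a reserved key among out-of-tree backends is attractive.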
Note
Integrating a new backend into PyTorch through the PrivateUse1 key has a limitation of its own: several backends cannot be integrated at the same time. Fortunately, these out-of-tree backends are rarely used simultaneously.
For these reasons, the community began to recommend that new backends be integrated into PyTorch via PrivateUse1.
However, the previous PrivateUse1 mechanism was not fully capable of integrating a new backend, because it lacked support in certain modules, such as Storage, AMP, Distributed, and so on.
With the arrival of PyTorch 2.1.0, a series of optimizations and enhancements have been made for PrivateUse1 in terms of new backend integration, and it is now possible to support the integration of new devices rapidly and efficiently.
How to integrate new backend via PrivateUse1#
In this section, we will discuss the details of integrating the new backend into PyTorch via PrivateUse1, which mainly consists of the following parts:
Register kernels for the new backend.
Register generator for the new backend.
Register device guard for the new backend.
Register serialization and deserialization functions for new backend metadata.
Other Modules.
Register kernels for the new backend#
The new backend may have some high-performance operator implementations, which can be registered to the dispatcher by the TORCH_LIBRARY_IMPL API described in Registering a Dispatched Operator in C++. This involves several situations:
Register all the forward operators supported by the new backend to the dispatcher, and register the fallback at the same time, so that operators the new backend does not support can fall back to the CPU for execution and remain available.
    at::Tensor wrapper_Custom_Tensor_add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
      // Implementation of add kernel in new backend
      ...
    }

    TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
      ...
      m.impl("add.Tensor", TORCH_FN(wrapper_Custom_Tensor_add));
      ...
    }

    void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
      // Add some hints about new devices that do not support and need to fall back to cpu
      at::native::cpu_fallback(op, stack);
    }

    TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
      m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
    }
Register kernels from torch::autograd::Function to the dispatcher by AutogradPrivateUse1 if the new backend needs to override the PyTorch Autograd layer. The dispatcher and autograd system will automatically call the forward and backward implementations of these operators.
    class CustomSeluFunction : public torch::autograd::Function<CustomSeluFunction> {
      // Implementation of selu kernel in new backend
    };

    at::Tensor wrapper_AutogradCustom__selu(const at::Tensor & self) {
      return CustomSeluFunction::apply(self);
    }

    TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
      ...
      m.impl("selu", TORCH_FN(wrapper_AutogradCustom__selu));
      ...
    }
Register kernels that need to support automatic mixed precision (AMP) and the fallback mechanism to the dispatcher by AutocastPrivateUse1. The autocast system will automatically call these kernels when needed.
    TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
      ...
      KERNEL_PRIVATEUSEONE(<operator>, <policy>)
      ...
    }

    TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
      m.fallback(torch::CppFunction::makeFallthrough());
    }
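The three registration cases above share one idea: kernels are looked up per operator and per dispatch key, and a fallback is consulted when no kernel is registered for that key. The following pure-Python toy illustrates that lookup order only. `ToyDispatcher` and its `impl`, `fallback`, and `call` methods are hypothetical names invented for this sketch; the real dispatcher is implemented in C++ and is considerably more elaborate.

```python
class ToyDispatcher:
    """Hypothetical, simplified model of per-backend kernel lookup with a fallback."""

    def __init__(self):
        self.kernels = {}    # (op_name, key) -> kernel
        self.fallbacks = {}  # key -> fallback taking (op_name, *args)

    def impl(self, op_name, key, fn):
        self.kernels[(op_name, key)] = fn

    def fallback(self, key, fn):
        self.fallbacks[key] = fn

    def call(self, op_name, key, *args):
        kernel = self.kernels.get((op_name, key))
        if kernel is not None:
            return kernel(*args)
        if key in self.fallbacks:
            return self.fallbacks[key](op_name, *args)
        raise NotImplementedError(f"{op_name} has no kernel for {key}")


dispatcher = ToyDispatcher()
log = []

# Reference "CPU" kernels that the fallback can reuse.
dispatcher.impl("add", "CPU", lambda a, b: a + b)
dispatcher.impl("mul", "CPU", lambda a, b: a * b)


def custom_add(a, b):
    # The new backend only provides its own "add" kernel.
    log.append("custom add")
    return a + b


def cpu_fallback(op_name, *args):
    # Any other operator on the new backend is redirected to the CPU kernel.
    log.append(f"{op_name} fell back to CPU")
    return dispatcher.call(op_name, "CPU", *args)


dispatcher.impl("add", "PrivateUse1", custom_add)
dispatcher.fallback("PrivateUse1", cpu_fallback)

sum_result = dispatcher.call("add", "PrivateUse1", 2, 3)
product_result = dispatcher.call("mul", "PrivateUse1", 2, 3)
```

In this toy, `add` runs the backend's own kernel while `mul` is served through the fallback, which mirrors the role of the `TORCH_LIBRARY_IMPL(_, PrivateUse1, m)` block.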
Note that if you want to support AMP in a new backend, you also need to register a new BackendModule by torch._register_device_module("backend_name", BackendModule), and the BackendModule needs to have the following APIs:
get_amp_supported_dtype() -> List[torch.dtype]: get the dtypes supported on the new backend in AMP, which may be more than one dtype.
is_autocast_enabled() -> bool: check whether AMP is enabled on the new backend.
get_autocast_dtype() -> torch.dtype: get the dtype used on the new backend in AMP, which is either the one set by set_autocast_dtype or the default dtype, torch.float16.
set_autocast_enabled(bool) -> None: enable or disable AMP on the new backend.
set_autocast_dtype(dtype) -> None: set the dtype used on the new backend in AMP; the dtype must be one of the dtypes returned by get_amp_supported_dtype.
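As a rough illustration of the interface listed above, here is a minimal sketch of the five functions as they might appear in a backend module. It is not taken from any real backend: the choice of `float16` and `bfloat16` as supported dtypes and the module-level state variables are assumptions, and the sketch uses string placeholders for dtypes when PyTorch is not installed so that it stays runnable on its own.

```python
try:
    import torch
    FLOAT16, BFLOAT16 = torch.float16, torch.bfloat16
except ImportError:  # placeholders so the sketch also runs without PyTorch
    FLOAT16, BFLOAT16 = "float16", "bfloat16"

# Module-level state; a real backend would likely keep this in its runtime.
_autocast_enabled = False
_autocast_dtype = FLOAT16  # default dtype, as described above


def get_amp_supported_dtype():
    """Dtypes this hypothetical backend accepts under AMP (an assumption)."""
    return [FLOAT16, BFLOAT16]


def is_autocast_enabled():
    return _autocast_enabled


def set_autocast_enabled(enabled):
    global _autocast_enabled
    _autocast_enabled = bool(enabled)


def get_autocast_dtype():
    return _autocast_dtype


def set_autocast_dtype(dtype):
    global _autocast_dtype
    # Reject any dtype that is not in the supported list.
    if dtype not in get_amp_supported_dtype():
        raise ValueError(f"unsupported autocast dtype: {dtype}")
    _autocast_dtype = dtype
```

A module exposing these functions is what would be passed as the BackendModule argument of `torch._register_device_module`.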
Register generator for the new backend#
It is necessary to support generators corresponding to new devices. Currently, PrivateUse1 can dynamically register custom generators, which mainly involves the following steps:
Inherit the GeneratorImpl class to implement the generator class corresponding to the new backend, and implement various general methods.
Define a new backend builder with a single parameter: device index.
Call the REGISTER_GENERATOR_PRIVATEUSE1 macro to complete dynamic registration.
    struct CustomGeneratorImpl : public c10::GeneratorImpl {
      // Implementation of generator in new backend
    };

    at::Generator make_custom_generator(c10::DeviceIndex device_index) {
      return at::make_generator<CustomGeneratorImpl>(device_index);
    }

    REGISTER_GENERATOR_PRIVATEUSE1(make_custom_generator)
Register device guard for the new backend#
PyTorch provides functionality related to device, stream, and event switching via DeviceGuard. This functionality also applies to the PrivateUse1 key. It involves the following steps:
Inherit the DeviceGuardImplInterface class to implement the various general methods corresponding to the new backend.
Call the C10_REGISTER_GUARD_IMPL macro to complete dynamic registration.
    struct CustomGuardImpl final : public c10::impl::DeviceGuardImplInterface {
      // Implementation of guard in new backend
    };

    C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomGuardImpl);
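Conceptually, a device guard switches to a target device on entry and restores the previously active device on exit, even when guards are nested. The Python context manager below is a toy model of that behavior only. `ToyRuntime` and `ToyDeviceGuard` are hypothetical names for this sketch and do not correspond to the C++ interface a backend has to implement.

```python
class ToyRuntime:
    """Hypothetical device runtime that tracks a single 'current device' index."""

    def __init__(self):
        self.current_device = 0

    def exchange_device(self, index):
        # Set the new device and return the one that was active before.
        previous, self.current_device = self.current_device, index
        return previous


class ToyDeviceGuard:
    """Switches to `index` on entry and restores the previous device on exit."""

    def __init__(self, runtime, index):
        self.runtime = runtime
        self.index = index
        self.previous = None

    def __enter__(self):
        self.previous = self.runtime.exchange_device(self.index)
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore even if the body raised; returning False lets the exception propagate.
        self.runtime.exchange_device(self.previous)
        return False


runtime = ToyRuntime()
seen = []
with ToyDeviceGuard(runtime, 1):
    seen.append(runtime.current_device)      # inside the outer guard
    with ToyDeviceGuard(runtime, 2):
        seen.append(runtime.current_device)  # inside the nested guard
    seen.append(runtime.current_device)      # nested guard has been undone
seen.append(runtime.current_device)          # outer guard has been undone
```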
Register serialization and deserialization functions for new backend metadata#
PyTorch is currently able to dynamically register serialization/deserialization functions to support the serialization and deserialization of additional metadata for the new backend, named backend_meta_ in class TensorImpl.ExtraMeta. You can refer to the following steps:
Inherit the BackendMeta class to implement CustomBackendMetadata corresponding to the new backend; various fields of the new backend can be customized in the class.
Implement the serialization and deserialization functions of the new backend. The function signatures are void(const at::Tensor&, std::unordered_map<std::string, bool>&).
Call the TensorBackendMetaRegistry macro to complete dynamic registration.
    struct CustomBackendMetadata : public c10::BackendMeta {
      // Implementation of backend metadata in new backend
    };

    void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
      // Implementation of serialization
    }

    void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
      // Implementation of deserialization
    }

    TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1, &for_serialization, &for_deserialization);
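The steps above can be pictured as a pair of hooks registered per device type: one copies backend metadata into a string-to-bool map when a tensor is saved, the other rebuilds the metadata from that map when it is loaded. The following pure-Python toy sketches that round trip under stated assumptions. `ToyMetaTensor`, `register_backend_meta`, `save`, `load`, and the `is_compressed` field are all hypothetical and only mimic the shape of the C++ signatures.

```python
_meta_registry = {}  # device type -> (serialize_fn, deserialize_fn)


def register_backend_meta(device_type, serialize_fn, deserialize_fn):
    _meta_registry[device_type] = (serialize_fn, deserialize_fn)


class ToyMetaTensor:
    """Hypothetical tensor stand-in carrying backend-specific metadata."""

    def __init__(self, device_type, backend_meta=None):
        self.device_type = device_type
        self.backend_meta = dict(backend_meta or {})


def save(tensor):
    flags = {}  # plays the role of the string-to-bool map
    hooks = _meta_registry.get(tensor.device_type)
    if hooks is not None:
        hooks[0](tensor, flags)
    return {"device_type": tensor.device_type, "flags": flags}


def load(payload):
    tensor = ToyMetaTensor(payload["device_type"])
    hooks = _meta_registry.get(tensor.device_type)
    if hooks is not None:
        hooks[1](tensor, payload["flags"])
    return tensor


def for_serialization(tensor, flags):
    # Copy the backend's metadata into the map that gets saved.
    flags["is_compressed"] = bool(tensor.backend_meta.get("is_compressed", False))


def for_deserialization(tensor, flags):
    # Rebuild the backend's metadata from the saved map.
    tensor.backend_meta["is_compressed"] = flags.get("is_compressed", False)


register_backend_meta("privateuse1", for_serialization, for_deserialization)

original = ToyMetaTensor("privateuse1", {"is_compressed": True})
restored = load(save(original))
cpu_roundtrip = load(save(ToyMetaTensor("cpu")))
```

Device types without registered hooks pass through untouched, which is why registration can be added by an out-of-tree backend without affecting existing ones.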
Other Modules#
In addition to the above-mentioned parts, there are some other modules that can be expanded through PrivateUse1, such as distributed collective communication, benchmark timer, and others, which will be added in the future. One example of PrivateUse1 integration is Ascend NPU.
How to Improve User Experience with PrivateUse1#
The primary goal of integrating new devices through PrivateUse1 is to meet the basic functional requirements. The next step is to improve usability, which mainly involves the following aspects:
Register new backend module to PyTorch.
Rename PrivateUse1 to a custom name for the new backend.
Generate methods and properties related to the new backend.
Register new backend module to PyTorch#
Some CUDA-related interfaces in PyTorch can be called in the form torch.cuda.xxx. Therefore, to match user habits, a new backend implemented through the PrivateUse1 mechanism should provide similar interfaces.
For example, using Ascend NPU:
    torch._register_device_module('npu', torch_npu.npu)
After doing the above, users can call APIs specific to Ascend NPU through torch.npu.xxx.
Rename PrivateUse1 to a custom name for the new backend#
The PrivateUse1 key is the internal mechanism by which the new backend is integrated into PyTorch. For users, a custom name strongly related to the new backend is friendlier than PrivateUse1.
Taking Ascend NPU as an example, the first usage below is more user-friendly than the second:
    torch.rand((2,2),device='npu:0')
    torch.rand((2,2),device='privateuse1:0')
PyTorch now provides a C++/Python API for renaming the PrivateUse1 backend, which is very simple to use.
In Python:
    torch.rename_privateuse1_backend("npu")
In C++:
    c10::register_privateuse1_backend("npu")
Generate methods and properties related to the new backend#
After renaming PrivateUse1 to a custom name, you can automatically generate properties and methods related to the new backend name in the Tensor, nn, and Storage modules.
Here is an example for Ascend NPU:
    torch.rename_privateuse1_backend("npu")
    unsupported_dtype = [torch.quint8]
    torch.utils.generate_methods_for_privateuse1_backend(for_tensor=True, for_module=True, for_storage=True, unsupported_dtype=unsupported_dtype)
Then, you can use the following methods and properties:
    torch.Tensor.npu()
    torch.Tensor.is_npu
    torch.Storage.npu()
    torch.Storage.is_npu
    ...
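To make the idea of generated methods more concrete, the toy below attaches a method and a property named after the backend to an existing class at runtime. It only illustrates the general technique of dynamic attribute generation. `ToyTensor` and `generate_methods` are hypothetical and are not how PyTorch implements this feature.

```python
class ToyTensor:
    """Hypothetical tensor-like class that only records which device it lives on."""

    def __init__(self, device="cpu"):
        self.device = device


def generate_methods(cls, backend_name):
    """Attach `<backend_name>()` and `is_<backend_name>` to `cls` at runtime."""

    def to_backend(self, index=0):
        # Return a new object placed on the custom backend, e.g. "npu:0".
        return cls(device=f"{backend_name}:{index}")

    def is_backend(self):
        return self.device.split(":")[0] == backend_name

    setattr(cls, backend_name, to_backend)
    setattr(cls, f"is_{backend_name}", property(is_backend))


generate_methods(ToyTensor, "npu")

cpu_tensor = ToyTensor()
npu_tensor = cpu_tensor.npu()
```

After the call to `generate_methods`, every `ToyTensor` instance gains `npu()` and `is_npu`, in the same spirit as the generated `torch.Tensor.npu()` and `torch.Tensor.is_npu` listed above.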
Future Work#
The improvement of the PrivateUse1 mechanism is still in progress, so integration methods for additional modules will be added over time. Here are a few items that we are actively working on:
Add the integration method of distributed collective communication.
Add the integration method of benchmark timer.
Conclusion#
This tutorial walked you through the process of integrating new backends into PyTorch via PrivateUse1, including but not limited to operator registration, generator registration, device guard registration, and so on. It also introduced some methods to improve the user experience.