# Customize Process Group Backends Using Cpp Extensions
Created On: Feb 01, 2022 | Last Updated: Nov 14, 2024 | Last Verified: Nov 05, 2024
Authors: Howard Huang, Feng Tian, Shen Li, Min Si
This tutorial demonstrates how to implement a custom `Backend` and plug that into the PyTorch distributed package using cpp extensions. This is helpful when you need a specialized software stack for your hardware, or when you would like to experiment with new collective communication algorithms.
## Basics
PyTorch collective communications power several widely adopted distributed training features, including DistributedDataParallel and ZeroRedundancyOptimizer. In order to make the same collective communication API work with different communication backends, the distributed package abstracts collective communication operations into a `Backend` class. Different backends can then be implemented as subclasses of `Backend` using preferred third-party libraries. PyTorch distributed comes with three default backends, `ProcessGroupNCCL`, `ProcessGroupGloo`, and `ProcessGroupMPI`. However, beyond these three backends, there are also other communication libraries (e.g., UCC, OneCCL), different types of hardware (e.g., TPU, Trainium), and emerging communication algorithms (e.g., Herring, Reduction Server). Therefore, the distributed package exposes extension APIs to allow customizing collective communication backends.
The four steps below show how to implement a dummy `Backend` and use it in Python application code. Please note that this tutorial focuses on demonstrating the extension APIs, instead of developing a functioning communication backend. Hence, the `dummy` backend just covers a subset of the APIs (`all_reduce` and `all_gather`), and simply sets the values of tensors to 0.
## Step 1: Implement a Subclass of `Backend`
The first step is to implement a `Backend` subclass that overrides the target collective communication APIs and runs the custom communication algorithm. The extension also needs to implement a `Work` subclass, which serves as a future of communication results and allows asynchronous execution in application code. If the extension uses third-party libraries, it can include the headers and call into the library APIs from the `BackendDummy` subclass. The two code snippets below present the implementation of `dummy.hpp` and `dummy.cpp`. See the dummy collectives repository for the full implementation.
```cpp
// file name: dummy.hpp
#include <torch/python.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <pybind11/chrono.h>

namespace c10d {

class BackendDummy : public Backend {
  public:
    BackendDummy(int rank, int size);

    c10::intrusive_ptr<Work> allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& opts = AllgatherOptions()) override;

    c10::intrusive_ptr<Work> allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts = AllreduceOptions()) override;

    // The collective communication APIs without a custom implementation
    // will error out if invoked by application code.
};

class WorkDummy : public Work {
  public:
    WorkDummy(
        OpType opType,
        c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
        : Work(
              -1, // rank, only used by recvAnySource, irrelevant in this demo
              opType),
          future_(std::move(future)) {}

    bool isCompleted() override;
    bool isSuccess() const override;
    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
    virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

  private:
    c10::intrusive_ptr<c10::ivalue::Future> future_;
};
} // namespace c10d
```
```cpp
// file name: dummy.cpp
#include "dummy.hpp"

namespace c10d {

// This is a dummy allgather that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<Work> BackendDummy::allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& /* unused */) {
    for (auto& outputTensorVec : outputTensors) {
        for (auto& outputTensor : outputTensorVec) {
            outputTensor.zero_();
        }
    }

    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
    future->markCompleted(c10::IValue(outputTensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLGATHER, std::move(future));
}

// This is a dummy allreduce that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<Work> BackendDummy::allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts) {
    for (auto& tensor : tensors) {
        tensor.zero_();
    }

    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::TensorType::get()));
    future->markCompleted(c10::IValue(tensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
}
} // namespace c10d
```
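To make the dummy semantics concrete, the pure-Python sketch below contrasts what a real `allreduce(SUM)` is expected to compute across ranks with what the dummy backend above actually does (zero everything). Plain lists stand in for tensors; both function names are illustrative, not part of any API.

```python
# Reference semantics of allreduce with SUM: every rank ends up with the
# elementwise sum of all ranks' inputs.
def allreduce_sum(rank_tensors):
    summed = [sum(vals) for vals in zip(*rank_tensors)]
    return [list(summed) for _ in rank_tensors]

# What the dummy backend does instead: zero each rank's tensor.
def dummy_allreduce(rank_tensors):
    return [[0] * len(t) for t in rank_tensors]

inputs = [[1, 2], [3, 4], [5, 6]]     # 3 ranks, 2 elements each
print(allreduce_sum(inputs))          # each rank sees [9, 12]
print(dummy_allreduce(inputs))        # each rank sees [0, 0]
```

A functioning backend would replace the zeroing loop in `BackendDummy::allreduce` with real communication that produces the summed result on every rank.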
## Step 2: Expose The Extension Python APIs
The backend constructors are called from the Python side, so the extension also needs to expose the constructor APIs to Python. This can be done by adding the following methods. In this example, `store` and `timeout` are ignored by the `BackendDummy` instantiation method, as those are not used in this dummy implementation. However, real-world extensions should consider using the `store` to perform rendezvous and supporting the `timeout` argument.
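As a rough illustration of what "using the store to perform rendezvous" means, the sketch below models the key-value pattern in pure Python: each rank publishes its address under a rank-indexed key, then reads every peer's address back. The `InMemoryStore` class and key names are hypothetical stand-ins for the c10d `Store` interface, not real PyTorch code.

```python
# Hypothetical stand-in for the c10d Store's set/get key-value interface.
class InMemoryStore:
    def __init__(self):
        self._kv = {}
    def set(self, key, value):
        self._kv[key] = value
    def get(self, key):
        return self._kv[key]

def rendezvous(store, rank, world_size, address):
    """Publish this rank's address, then collect every rank's address."""
    store.set(f"worker/{rank}", address)
    # In a real backend, get() would block (up to `timeout`) until the
    # peer rank has published its key.
    return [store.get(f"worker/{r}") for r in range(world_size)]

store = InMemoryStore()
for r in range(3):                       # simulate 3 ranks publishing
    store.set(f"worker/{r}", f"10.0.0.{r}:29500")
peers = rendezvous(store, rank=0, world_size=3, address="10.0.0.0:29500")
print(peers)
```

The `timeout` argument passed to the constructor is the natural bound for how long those blocking `get()` calls may wait.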
```cpp
// file name: dummy.hpp
class BackendDummy : public Backend {
    ...
    <Step 1 code>
    ...

    static c10::intrusive_ptr<Backend> createBackendDummy(
        const c10::intrusive_ptr<::c10d::Store>& store,
        int rank,
        int size,
        const std::chrono::duration<float>& timeout);

    static void BackendDummyConstructor() __attribute__((constructor)) {
        py::object module = py::module::import("torch.distributed");
        py::object register_backend =
            module.attr("Backend").attr("register_backend");
        // torch.distributed.Backend.register_backend will add `dummy` as a
        // new valid backend.
        register_backend("dummy", py::cpp_function(createBackendDummy));
    }
};
```
```cpp
// file name: dummy.cpp
c10::intrusive_ptr<Backend> BackendDummy::createBackendDummy(
        const c10::intrusive_ptr<::c10d::Store>& /* unused */,
        int rank,
        int size,
        const std::chrono::duration<float>& /* unused */) {
    return c10::make_intrusive<BackendDummy>(rank, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("createBackendDummy", &BackendDummy::createBackendDummy);
}
```
## Step 3: Build The Custom Extension
Now, the extension source code files are ready. We can then use cpp extensions to build it. To do that, create a `setup.py` file that prepares the paths and commands. Then call `python setup.py develop` to install the extension.
If the extension depends on third-party libraries, you can also specify `library_dirs` and `libraries` to the cpp extension APIs. See the torch-ucc project as a real-world example.
```python
# file name: setup.py
import os

import torch
from setuptools import setup
from torch.utils import cpp_extension

sources = ["src/dummy.cpp"]
include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"]

if torch.cuda.is_available():
    module = cpp_extension.CUDAExtension(
        name="dummy_collectives",
        sources=sources,
        include_dirs=include_dirs,
    )
else:
    module = cpp_extension.CppExtension(
        name="dummy_collectives",
        sources=sources,
        include_dirs=include_dirs,
    )

setup(
    name="Dummy-Collectives",
    version="0.0.1",
    ext_modules=[module],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
```
## Step 4: Use The Extension in Application
After installation, you can conveniently use the `dummy` backend when calling `init_process_group` as if it were a builtin backend.
We can specify dispatching based on backend by changing the `backend` argument of `init_process_group`. We can dispatch collectives with CPU tensors to the `gloo` backend and collectives with CUDA tensors to the `dummy` backend by specifying `cpu:gloo,cuda:dummy` as the backend argument.
To send all tensors to the `dummy` backend, we can simply specify `dummy` as the backend argument.
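The `device:backend` string can be sketched as a simple mapping from device type to backend name. The parser below is a hypothetical illustration of that mapping, not PyTorch's actual parsing code; as in `init_process_group`, a bare backend name applies to every device type.

```python
# Illustrative parser for the "device:backend" dispatch syntax,
# e.g. "cpu:gloo,cuda:dummy". A bare name like "dummy" maps all devices.
def parse_backend_spec(spec, devices=("cpu", "cuda")):
    if ":" not in spec:
        return {device: spec for device in devices}
    mapping = {}
    for entry in spec.split(","):
        device, backend = entry.split(":")
        mapping[device] = backend
    return mapping

print(parse_backend_spec("cpu:gloo,cuda:dummy"))  # {'cpu': 'gloo', 'cuda': 'dummy'}
print(parse_backend_spec("dummy"))                # {'cpu': 'dummy', 'cuda': 'dummy'}
```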
```python
import os

import torch
# importing dummy_collectives makes torch.distributed recognize `dummy`
# as a valid backend.
import dummy_collectives

import torch.distributed as dist

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

# Alternatively:
# dist.init_process_group("dummy", rank=0, world_size=1)
dist.init_process_group("cpu:gloo,cuda:dummy", rank=0, world_size=1)

# this goes through gloo
x = torch.ones(6)
dist.all_reduce(x)
print(f"cpu allreduce: {x}")

# this goes through dummy
if torch.cuda.is_available():
    y = x.cuda()
    dist.all_reduce(y)
    print(f"cuda allreduce: {y}")

    try:
        dist.broadcast(y, 0)
    except RuntimeError:
        print("got RuntimeError when calling broadcast")
```