torchcomms: a modern PyTorch communications API

torchcomms is a new experimental communications API for PyTorch. It provides both the high-level collectives API and several out-of-the-box backends.
torchcomms requires the following software and hardware:
- Python 3.10 or higher
- PyTorch 2.8 or higher
- CUDA-capable GPU (for NCCL/NCCLX or RCCL backends)
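If you are scripting an environment setup, the Python floor above can be checked programmatically. A minimal sketch using only the standard library; the `MIN_PYTHON` constant and `python_ok` helper are illustrative, not part of torchcomms:

```python
import sys

# Minimum Python version from the requirements list above
MIN_PYTHON = (3, 10)

def python_ok(version_info=sys.version_info) -> bool:
    """Return True if the given interpreter version meets the floor."""
    return tuple(version_info[:2]) >= MIN_PYTHON

print(python_ok((3, 10, 4)))  # True: 3.10 meets the requirement
print(python_ok((3, 9, 18)))  # False: below 3.10
```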
torchcomms is available on PyPI and can be installed using pip. Alternatively, you can build torchcomms from source.
You can install torchcomms and PyTorch nightly builds using pip:
```shell
# CUDA 12.6
pip install --pre torch torchcomms --index-url https://download.pytorch.org/whl/nightly/cu126

# CUDA 12.8
pip install --pre torch torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128

# CUDA 12.9
pip install --pre torch torchcomms --index-url https://download.pytorch.org/whl/nightly/cu129

# CUDA 13.0
pip install --pre torch torchcomms --index-url https://download.pytorch.org/whl/nightly/cu130
```
Alternatively, you can build torchcomms from source. Building from source additionally requires:

- CMake 3.22 or higher
- Ninja 1.10 or higher

If you want to build the NCCLX backend, we recommend building it in a conda virtual environment. Run the following commands to build and install torchcomms:
```shell
# Create a conda environment
conda create -n torchcomms python=3.10
conda activate torchcomms

# Clone the repository
git clone git@github.com:meta-pytorch/torchcomms.git
cd torchcomms
```
For the NCCL backend, no build is needed: it uses the library provided by PyTorch.
If you want to install the third-party dependencies directly from conda, run the following command:
```shell
USE_SYSTEM_LIBS=1 ./build_ncclx.sh
```
If you want to build and install the third-party dependencies from source, run the following command:
```shell
./build_ncclx.sh
```
For the RCCL backend, install some prerequisites and set environment variables so the build can find the ROCm/RCCL headers:

```shell
# Install some prerequisites
conda install conda-forge::glog=0.4.0 conda-forge::gflags conda-forge::fmt -y

# Environment variables to find rocm/rccl headers
export ROCM_HOME=/opt/rocm
export RCCL_INCLUDE=$ROCM_HOME/include/rccl

./build_rccl.sh
```
For the RCCLX backend, install the same prerequisites and set environment variables so the build can find the ROCm/RCCLX headers and libraries:

```shell
# Install some prerequisites
conda install conda-forge::glog=0.4.0 conda-forge::gflags conda-forge::fmt -y

# Environment variables to find rocm/rcclx headers
export BUILD_DIR=${PWD}/comms/rcclx/develop/build/release/build
export ROCM_HOME=/opt/rocm
export RCCLX_INCLUDE=${BUILD_DIR}/include/rccl
export RCCLX_LIB=${BUILD_DIR}/lib

./build_rcclx.sh
```
```shell
# Install PyTorch (if not already installed)
pip install -r requirements.txt

pip install --no-build-isolation -v .
```
You can customize the build by setting environment variables before running pip install:
```shell
# Enable/disable specific backends (ON/OFF or 1/0)
export USE_NCCL=ON    # Default: ON
export USE_NCCLX=ON   # Default: ON
export USE_GLOO=ON    # Default: ON
export USE_RCCL=OFF   # Default: OFF
export USE_RCCLX=OFF  # Default: OFF
```
Then run:
```shell
# Install PyTorch (if not already installed)
pip install -r requirements.txt

pip install --no-build-isolation -v .
```
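The backend flags above accept either ON/OFF or 1/0. As an illustration of how such flags are typically interpreted (the `backend_enabled` helper is ours, not part of the torchcomms build system):

```python
import os

def backend_enabled(name: str, default: str = "OFF") -> bool:
    """Interpret a USE_<BACKEND> environment flag as a boolean.

    Accepts ON/OFF (case-insensitive) or 1/0, matching the options above.
    """
    value = os.environ.get(name, default).strip().upper()
    return value in ("ON", "1")

os.environ["USE_RCCL"] = "1"
print(backend_enabled("USE_RCCL"))          # True: "1" counts as ON
print(backend_enabled("USE_RCCLX", "OFF"))  # False: unset, default OFF
```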
Here's a simple example demonstrating synchronous AllReduce communication across multiple GPUs:
```python
#!/usr/bin/env python3
# example.py
import torch
from torchcomms import new_comm, ReduceOp

def main():
    # Initialize TorchComm with the NCCL backend
    device = torch.device("cuda")
    torchcomm = new_comm("nccl", device, name="main_comm")

    # Get rank and world size
    rank = torchcomm.get_rank()
    world_size = torchcomm.get_size()

    # Calculate device ID
    num_devices = torch.cuda.device_count()
    device_id = rank % num_devices
    target_device = torch.device(f"cuda:{device_id}")

    print(f"Rank {rank}/{world_size}: Running on device {device_id}")

    # Create a tensor with rank-specific data
    tensor = torch.full(
        (1024,), float(rank + 1), dtype=torch.float32, device=target_device
    )
    print(f"Rank {rank}: Before AllReduce: {tensor[0].item()}")

    # Perform synchronous AllReduce (sum across all ranks)
    torchcomm.all_reduce(tensor, ReduceOp.SUM, async_op=False)

    # Synchronize CUDA stream
    torch.cuda.current_stream().synchronize()
    print(f"Rank {rank}: After AllReduce: {tensor[0].item()}")

    # Cleanup
    torchcomm.finalize()

if __name__ == "__main__":
    main()
```
To run this example with multiple processes (one per GPU):
```shell
# Using torchrun (recommended)
torchrun --nproc_per_node=2 example.py

# Or using python -m torch.distributed.launch
python -m torch.distributed.launch --nproc_per_node=2 example.py
```
To run this example with multiple nodes:
On node 0:

```shell
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 --rdzv-endpoint="<master-node>:<master-port>" example.py
```

On node 1:

```shell
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 --rdzv-endpoint="<master-node>:<master-port>" example.py
```

In the example above, we perform the following steps:
- `new_comm()` creates a communicator with the specified backend
- Each process gets its unique rank and total world size
- Each rank creates a tensor with rank-specific values
- All tensors are summed across all ranks
- Communication resources are cleaned up with `finalize()`
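For the two-node launch above, torchrun derives each process's global rank from its node rank and local rank; the mapping can be sketched as simple arithmetic (the `global_rank` helper is illustrative):

```python
def global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """Global rank of a process launched by torchrun:
    node_rank * nproc_per_node + local_rank."""
    return node_rank * nproc_per_node + local_rank

# With --nnodes=2 --nproc_per_node=8, as in the example above:
print(global_rank(0, 8, 0))  # 0: first process on node 0
print(global_rank(1, 8, 0))  # 8: first process on node 1
print(global_rank(1, 8, 7))  # 15: last of the 16 ranks
```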
torchcomms also supports asynchronous operations for better performance. Here is the same example as above, but with asynchronous AllReduce:
```python
import torch
from torchcomms import new_comm, ReduceOp

device = torch.device("cuda")
torchcomm = new_comm("nccl", device, name="main_comm")

rank = torchcomm.get_rank()
device_id = rank % torch.cuda.device_count()
target_device = torch.device(f"cuda:{device_id}")

# Create tensor
tensor = torch.full((1024,), float(rank + 1), dtype=torch.float32, device=target_device)

# Start async AllReduce
work = torchcomm.all_reduce(tensor, ReduceOp.SUM, async_op=True)

# Do other work while communication happens
print(f"Rank {rank}: Doing other work while AllReduce is in progress...")

# Wait for completion
work.wait()
print(f"Rank {rank}: AllReduce completed")

torchcomm.finalize()
```
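To see why every rank ends up holding the same value after the AllReduce, here is a pure-Python simulation of the SUM reduction across ranks; no GPUs or torchcomms are needed, and the `simulate_all_reduce_sum` helper is purely illustrative:

```python
def simulate_all_reduce_sum(world_size: int, length: int = 4):
    """Simulate AllReduce(SUM): every rank starts with its own tensor
    filled with rank + 1, and ends with the elementwise sum of all of them."""
    tensors = {rank: [float(rank + 1)] * length for rank in range(world_size)}
    total = [sum(t[i] for t in tensors.values()) for i in range(length)]
    # After AllReduce, every rank holds the same reduced result
    return {rank: total[:] for rank in range(world_size)}

result = simulate_all_reduce_sum(world_size=2)
print(result[0][0])  # 3.0: rank 0's 1.0 + rank 1's 2.0
print(result[1][0])  # 3.0: identical on every rank
```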
See the CONTRIBUTING file for how to help out.
Source code is made available under a BSD 3-Clause license; however, you may have other legal obligations that govern your use of other content linked in this repository, such as the license or terms of service for third-party data and models.
torchcomms backends include third-party source code that may be covered by other licenses. Please check the relevant directories and files to verify the licenses.