PyTorch Symmetric Memory#
Created On: Oct 24, 2025 | Last Updated On: Jan 08, 2026
Note
torch.distributed._symmetric_memory is currently in an alpha state and under development. APIs may change.
Why Symmetric Memory?#
With rapidly evolving parallelization techniques, existing frameworks and libraries often struggle to keep up, and developers increasingly rely on custom implementations that directly schedule communications and computations. In recent years we’ve witnessed a shift from relying primarily on one-dimensional data-parallelism techniques to multi-dimensional ones. The latter have different latency requirements for different types of communications and thus require fine-grained overlapping of compute and communications.
To minimize compute interference, they also require the use of copy engines and network interface cards (NICs) to drive communication. Network transport protocols such as remote direct memory access (RDMA) enhance performance by enabling direct, high-speed, and low-latency communication between processors and memory. This variety indicates the need for finer-grained communication primitives than today’s high-level collective APIs offer, ones that would enable developers to implement specific algorithms tailored to their use cases, such as low-latency collectives, fine-grained compute-communication overlap, or custom fusions.
Furthermore, today’s advanced AI systems connect GPUs with high-bandwidth links (such as NVLink, InfiniBand, or RoCE), making GPU global memory directly accessible to peers. Such connections present a great opportunity for programmers to program the system as a single, gigantic GPU with vast accessible memory, instead of programming singular “GPU islands.”
In this document, we will show how you can use PyTorch Symmetric Memory to program modern GPU systems as a “single GPU” and achieve fine-grained remote access.
What does PyTorch Symmetric Memory unlock?#
PyTorch Symmetric Memory unlocks three new capabilities:
Customized communication patterns: Increased flexibility in kernel writing allows developers to write custom kernels that implement their own computations and communications, directly tailored to the needs of the application. It is also straightforward to add support for new data types, along with the special compute that those data types might require, even if it is not yet present in the standard libraries.
In-kernel compute-comm fusion: Device-initiated communication capability allows developers to write kernels with both computation and communication instructions, fusing computation and data movement at the smallest possible granularity.
Low-latency remote access: Network transport protocols like RDMA enhance the performance of symmetric memory in networked environments by enabling direct, high-speed, and low-latency communication between processors and memory. RDMA eliminates the overhead associated with the traditional network stack and CPU involvement. It also offloads data transfer from the compute to the NICs, freeing up compute resources for computational tasks.
Next, we will show you how PyTorch Symmetric Memory (SymmMem) enables new applications with the above capabilities.
A “Hello World” example#
The PyTorch SymmMem programming model involves two key elements:
creating symmetric tensors
creating SymmMem kernels
To create symmetric tensors, one can use the torch.distributed._symmetric_memory package:
import torch.distributed._symmetric_memory as symm_mem

t = symm_mem.empty(128, device=torch.device("cuda", rank))
hdl = symm_mem.rendezvous(t, group)
The symm_mem.empty function creates a tensor that is backed by a symmetric memory allocation. The rendezvous function establishes a rendezvous with peers in the group and returns a handle to the symmetric memory allocation. The handle provides methods to access information related to the symmetric memory allocation, such as pointers to the symmetric buffers on peer ranks, the multicast pointer (if supported), and the signal pads.
The empty and rendezvous functions must be called in the same order on all ranks in the group.
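For context, here is a minimal sketch of how these two calls fit into a complete program; the launch command, file name, and process-group backend choice are illustrative assumptions rather than requirements:

# hello_symm_mem.py: a minimal sketch, launched with e.g.
#   torchrun --nproc-per-node 2 hello_symm_mem.py
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)   # single-node assumption: rank == local rank

group = dist.group.WORLD

# Same order on all ranks: allocate, then rendezvous
t = symm_mem.empty(128, device=torch.device("cuda", rank))
hdl = symm_mem.rendezvous(t, group)

# ... continue with collectives or custom kernels (see below)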
Then, collectives can be called on these tensors. For example, to perform a one-shot all-reduce:
# Most SymmMem ops are under the torch.ops.symm_mem namespace
torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group)
Please note that torch.ops.symm_mem is an “op namespace” rather than a Python module. Therefore, you can’t import it via import torch.ops.symm_mem, nor can you import an op via from torch.ops.symm_mem import one_shot_all_reduce. You can call the op directly, as in the example above.
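You can, however, bind an op to a local Python name through ordinary attribute access. A small illustration, reusing t and group from above:

# torch.ops.symm_mem is an object whose attributes are the ops
one_shot_all_reduce = torch.ops.symm_mem.one_shot_all_reduce
result = one_shot_all_reduce(t, "sum", group)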
Write your own kernel#
To write your own kernel doing communications with symmetric memory, you’ll need access to the addresses of mapped peer buffers and access to the signal pads that are required for synchronization. In the kernel you’ll also need to perform correct synchronizations to make sure that peers are ready for communication, and signal to them that this GPU is ready.
PyTorch Symmetric Memory provides CUDA Graph-compatible synchronization primitives that operate on the signal pad accompanying each symmetric memory allocation. Kernels using symmetric memory can be written both in CUDA and in Triton. Here’s an example that allocates a symmetric tensor and exchanges handles:
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group()
rank = dist.get_rank()

# Allocate a tensor
t = symm_mem.empty(4096, device=f"cuda:{rank}")

# Establish symmetric memory and obtain the handle
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
Access to buffer pointers, multimem pointer, and signal pads is provided via:
hdl.buffer_ptrs
hdl.multicast_ptr
hdl.signal_pad_ptrs
Data pointed to by buffer_ptrs can be accessed just like regular local data, and any necessary compute can also be performed in the usual ways. As with local data, you can and should use vectorized accesses to improve efficiency.
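For quick host-side experiments, the same peer buffers can also be viewed as ordinary tensors through the handle. The sketch below reuses t, hdl, and rank from the previous example and assumes the handle exposes get_buffer and barrier helpers; treat these names as assumptions and check your PyTorch version:

# Assumed helpers: hdl.get_buffer / hdl.barrier; verify against your PyTorch version.
peer = (rank + 1) % dist.get_world_size()
peer_buf = hdl.get_buffer(peer, t.shape, t.dtype)   # tensor view of the peer's buffer

hdl.barrier()                  # wait until the peer has finished writing
local_copy = peer_buf.clone()  # read remote data with a regular torch op
hdl.barrier()                  # don't let the peer reuse the buffer until reads finish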
Symmetric memory is especially convenient for writing kernels in Triton. While previously Triton removed the barriers to writing efficient CUDA code, now communications can be added easily to Triton kernels. The kernel below demonstrates a low-latency all-reduce kernel written in Triton.
@triton.jit
def one_shot_all_reduce_kernel(
    buf_tuple,
    signal_pad_ptrs,
    output_ptr,
    numel: tl.constexpr,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
    )

    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    while block_start < numel:
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < numel
        acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)

        for i in tl.static_range(world_size):
            buffer_rank = buf_tuple[i]
            x = tl.load(buffer_rank + offsets, mask=mask)
            acc += x

        tl.store(output_ptr + offsets, acc, mask=mask)
        block_start += tl.num_programs(axis=0) * BLOCK_SIZE

    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
    )
Synchronizations at the beginning and the end of the kernel above guarantee that all the processes see consistent data. The bulk of the kernel is recognizable Triton code, and Triton will optimize it behind the scenes, making sure memory accesses are performed in an efficient way with vectorization and unrolling. As with all Triton kernels, it is easily modifiable to add extra computations or change the communication algorithm. Visit https://github.com/meta-pytorch/kraken/blob/main/kraken to see additional utilities and examples of using symmetric memory to implement common patterns in Triton.
Scale out#
Large language models distribute experts onto more than 8 GPUs, hence requiring multi-node access capability. RDMA-capable NICs address this need. In addition, software libraries such as NVSHMEM or rocSHMEM abstract away the programming difference between intra-node access and inter-node access with primitives that are slightly higher level than pointer access, such as put and get.
PyTorch provides NVSHMEM plugins to augment Triton kernels’ cross-node capabilities. As shown in the code snippet below, one can initiate a cross-node put command within the kernel.
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem

@requires_nvshmem
@triton.jit
def my_put_kernel(
    dest,
    src,
    nelems,
    pe,
):
    nvshmem.put(dest, src, nelems, pe)
The requires_nvshmem decorator is used to indicate that the kernel requires the NVSHMEM device library as an external dependency. When Triton compiles the kernel, the decorator will search your system paths for the NVSHMEM device library. If it is available, Triton will include the necessary device assembly to use the NVSHMEM functions.
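For completeness, below is a hedged sketch of launching my_put_kernel from the host. It assumes the process group is already initialized, that NVSHMEM has been selected as the allocation backend before the first symmetric allocation, and that the sizes, peer choice, and single-program grid are purely illustrative:

import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

symm_mem.set_backend("NVSHMEM")   # must happen before the first symmetric allocation

rank = dist.get_rank()
src = symm_mem.empty(1024, device=f"cuda:{rank}")
dst = symm_mem.empty(1024, device=f"cuda:{rank}")
symm_mem.rendezvous(src, dist.group.WORLD)
symm_mem.rendezvous(dst, dist.group.WORLD)

src.fill_(rank)
peer = (rank + 1) % dist.get_world_size()

# Put the local src into dst on the peer rank; a single program is enough here.
my_put_kernel[(1,)](dst, src, src.numel(), peer)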
Using Memory Pool#
A memory pool allows PyTorch SymmMem to cache memory allocations that have been rendezvoused, saving time when creating new tensors. For convenience, PyTorch SymmMem provides a get_mem_pool API that returns a symmetric memory pool. Users can use the returned MemPool with the torch.cuda.use_mem_pool context manager. In the example below, tensor x will be created from symmetric memory:
import torch.distributed._symmetric_memory as symm_mem

mempool = symm_mem.get_mem_pool(device)
with torch.cuda.use_mem_pool(mempool):
    x = torch.arange(128, device=device)

torch.ops.symm_mem.one_shot_all_reduce(x, "sum", group_name)
Similarly, you can put a compute operation under the MemPool context, and the result tensor will be created from symmetric memory too.
dim = 1024
w = torch.ones(dim, dim, device=device)
x = torch.ones(1, dim, device=device)

mempool = symm_mem.get_mem_pool(device)
with torch.cuda.use_mem_pool(mempool):
    # y will be in symmetric memory
    y = torch.mm(x, w)
As of torch 2.11, the CUDA and NVSHMEM backends support MemPool. MemPool support for the NCCL backend is in progress.
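If you need a pool backed by a specific backend, select the backend before the first symmetric allocation. A brief sketch, with device as in the examples above:

import torch.distributed._symmetric_memory as symm_mem

# Global setting; cannot be changed once a symmetric tensor has been allocated.
symm_mem.set_backend("NVSHMEM")   # or "CUDA"
mempool = symm_mem.get_mem_pool(device)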
API Reference#
- torch.distributed._symmetric_memory.empty(*size: _int, dtype: _dtype | None = None, device: _device | None = None) → Tensor[source]#
- torch.distributed._symmetric_memory.empty(size: Sequence[_int], *, dtype: _dtype | None = None, device: _device | None = None) → Tensor
Similar to torch.empty(). The returned tensor can be used by torch.distributed._symmetric_memory.rendezvous() to establish a symmetric memory tensor among participating processes.
- Parameters:
size (int...) – a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired data type of the returned tensor. Default: if None, uses a global default (see torch.set_default_dtype()).
device (torch.device, optional) – the desired device of the returned tensor. Default: if None, uses the current device for the default tensor type (see torch.set_default_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
- torch.distributed._symmetric_memory.rendezvous(tensor, group) → _SymmetricMemory[source]#
Establish a symmetric memory tensor among participating processes. This is a collective operation.
- Parameters:
tensor (torch.Tensor) – the local tensor used to establish the symmetric memory tensor. It must be allocated via torch.distributed._symmetric_memory.empty(). The shape, dtype, and device type must be identical across all participating processes.
group (Union[str, torch.distributed.ProcessGroup]) – The group identifying the participating processes. This can be either a group name or a process group object.
- Return type:
_SymmetricMemory
- torch.distributed._symmetric_memory.is_nvshmem_available() → bool[source]#
Check if NVSHMEM is available in the current build and on the current system.
- Return type:
bool
- torch.distributed._symmetric_memory.set_backend(name)[source]#
Set the backend for symmetric memory allocation. This is a global setting and affects all subsequent calls to torch.distributed._symmetric_memory.empty(). Note that the backend cannot be changed once a symmetric memory tensor has been allocated.
- Parameters:
backend (str) – the backend for symmetric memory allocation. Currently, only “NVSHMEM”, “CUDA”, and “NCCL” are supported.
- torch.distributed._symmetric_memory.get_backend(device)[source]#
Get the backend for symmetric memory allocation for a given device. If not found, return None.
- Parameters:
device (torch.device or str) – the device for which to get the backend.
- Return type:
str | None
- torch.distributed._symmetric_memory.get_mem_pool(device)[source]#
Get the symmetric memory pool for a given device. If not found, create a new pool.
The tensor allocations with this pool must be symmetric across ranks. The allocated tensors can be used with symmetric operations, for example, operations defined under torch.ops.symm_mem.
- Parameters:
device (torch.device or str) – the device for which to get the symmetric memory pool.
- Returns:
the symmetric memory pool for the given device.
- Return type:
torch.cuda.MemPool
Example:
>>> pool = torch.distributed._symmetric_memory.get_mem_pool("cuda:0")
>>> with torch.cuda.use_mem_pool(pool):
>>>     tensor = torch.randn(1000, device="cuda:0")
>>>     tensor = torch.ops.symm_mem.one_shot_all_reduce(tensor, "sum", group_name)
Op Reference#
Note
The following ops are hosted in the torch.ops.symm_mem namespace. You can call them directly via torch.ops.symm_mem.<op_name>.
- torch.ops.symm_mem.multimem_all_reduce_(input: Tensor, reduce_op: str, group_name: str) → Tensor#
Performs a multimem all-reduce operation on the input tensor. This operation requires hardware support for multimem operations. On NVIDIA GPUs, NVLink SHARP is required.
- torch.ops.symm_mem.multimem_all_gather_out(input: Tensor, group_name: str, out: Tensor) → Tensor#
Performs a multimem all-gather operation on the input tensor. This operation requires hardware support for multimem operations. On NVIDIA GPUs, NVLink SHARP is required.
- torch.ops.symm_mem.one_shot_all_reduce(input: Tensor, reduce_op: str, group_name: str) → Tensor#
Performs a one-shot all-reduce operation on the input tensor.
- torch.ops.symm_mem.one_shot_all_reduce_out(input: Tensor, reduce_op: str, group_name: str, out: Tensor) → Tensor#
Performs a one-shot all-reduce operation based on the input tensor and writes the result to the output tensor.
- Parameters:
input (Tensor) – Input tensor to perform all-reduce on. Must be symmetric.
reduce_op (str) – Reduction operation to perform. Currently only “sum” is supported.
group_name (str) – Name of the group to perform all-reduce on.
out (Tensor) – Output tensor to store the result of the all-reduce operation. Can be a regular tensor.
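Example (a sketch: t is assumed to be a rendezvoused symmetric tensor and group_name the name of the group):

>>> out = torch.empty_like(t)   # a regular (non-symmetric) tensor is fine for `out`
>>> torch.ops.symm_mem.one_shot_all_reduce_out(t, "sum", group_name, out)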
- torch.ops.symm_mem.two_shot_all_reduce_(input: Tensor, reduce_op: str, group_name: str) → Tensor#
Performs a two-shot all-reduce operation on the input tensor.
- torch.ops.symm_mem.all_to_all_vdev(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str) → None#
Performs an all-to-all-v operation using NVSHMEM, with split information provided on device.
- Parameters:
input (Tensor) – Input tensor to perform all-to-all on. Must be symmetric.
out (Tensor) – Output tensor to store the result of the all-to-all operation. Must be symmetric.
in_splits (Tensor) – Tensor containing splits of data to send to each peer. Must be symmetric. Must be of size (group_size,). The splits are in the unit of elements in the 1st dimension.
out_splits_offsets (Tensor) – Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size). The rows are (in order): output splits and output offsets.
group_name (str) – Name of the group to perform all-to-all on.
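Example (a hedged sketch: the sizes, dtypes, and equal splits below are illustrative assumptions based on the parameter descriptions above; the output is simply sized for the worst case):

>>> world = dist.get_world_size()
>>> inp = symm_mem.empty(world * 16, device=device)   # 16 elements for each peer
>>> out = symm_mem.empty(world * 16, device=device)   # capacity for the worst case
>>> in_splits = symm_mem.empty(world, dtype=torch.int64, device=device)
>>> out_splits_offsets = symm_mem.empty(2, world, dtype=torch.int64, device=device)
>>> symm_mem.rendezvous(inp, group_name)
>>> symm_mem.rendezvous(out, group_name)
>>> symm_mem.rendezvous(in_splits, group_name)
>>> symm_mem.rendezvous(out_splits_offsets, group_name)
>>> in_splits.fill_(16)   # send 16 elements to every peer
>>> torch.ops.symm_mem.all_to_all_vdev(inp, out, in_splits, out_splits_offsets, group_name)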
- torch.ops.symm_mem.all_to_all_vdev_2d(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str[, major_align: int = None]) → None#
Perform a 2D all-to-all-v operation using NVSHMEM, with split information provided on device. In Mixture of Experts models, this operation can be used to dispatch tokens.
- Parameters:
input (Tensor) – Input tensor to perform all-to-all on. Must be symmetric.
out (Tensor) – Output tensor to store the result of the all-to-all operation. Must be symmetric.
in_splits (Tensor) – Tensor containing the splits of data to send to each expert. Must be symmetric. Must be of size (group_size * ne,), where ne is the number of experts per rank. The splits are in the unit of elements in the 1st dimension.
out_splits_offsets (Tensor) – Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.
group_name (str) – Name of the group to perform all-to-all on.
major_align (int) – Optional alignment for the major dimension of the output chunk for each expert. If not provided, the alignment is assumed to be 1. Any alignment adjustment will be reflected in the output offsets.
A 2D AllToAllv shuffle is illustrated below (world_size = 2, ne = 2, total number of experts = 4):
Source: |       Rank 0      |       Rank 1      |
        | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

Dest:   |       Rank 0      |       Rank 1      |
        | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
where each c_i / d_i is a slice of the input tensor, targeting expert i, with length indicated by the input splits. That is, the 2D AllToAllv shuffle achieves a transpose from rank-major order at the input to expert-major order at the output.
If major_align is not 1, the output offsets of c1, c2, c3 will be up-aligned to this value. For example, if c0 has length 5 and d0 has length 7 (making a total of 12), and if major_align is set to 16, the output offset of c1 will be 16. Similar for c2 and c3. This value has no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3. Note: since cutlass does not support empty bins, we set the aligned length to major_align if it is 0. See pytorch/pytorch#152668.
- torch.ops.symm_mem.all_to_all_vdev_2d_offset(input: Tensor, out: Tensor, in_splits_offsets: Tensor, out_splits_offsets: Tensor, group_name: str) → None#
Perform a 2D AllToAllv shuffle operation, with input split and offset information provided on device. The input offsets are not required to be an exact prefix sum of the input splits, i.e. paddings are allowed between the split chunks. The paddings, however, will not be transferred to peer ranks.
In Mixture of Experts models, this operation can be used to combine tokens processed by experts on parallel ranks. It can be viewed as a “reverse” operation to the all_to_all_vdev_2d operation (which shuffles tokens to experts).
- Parameters:
input (Tensor) – Input tensor to perform all-to-all on. Must be symmetric.
out (Tensor) – Output tensor to store the result of the all-to-all operation. Must be symmetric.
in_splits_offsets (Tensor) – Tensor containing the splits and offsets of data to send to each expert. Must be symmetric. Must be of size (2, group_size * ne), where ne is the number of experts. The rows are (in order): input splits and input offsets. The splits are in the unit of elements in the 1st dimension.
out_splits_offsets (Tensor) – Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.
group_name (str) – Name of the group to perform all-to-all on.
- torch.ops.symm_mem.tile_reduce(in_tile: Tensor, out_tile: Tensor, root: int, group_name: str[, reduce_op: str = 'sum']) → None#
Reduces a 2D tile from all ranks to a specified root rank within a process group.
- Parameters:
in_tile (Tensor) – Input 2D tensor to be reduced. Must be symmetrically allocated.
out_tile (Tensor) – Output 2D tensor to contain the result of the reduction. Must be symmetric and have the same shape, dtype, and device as in_tile.
root (int) – The rank of the process in the specified group that will receive the reduced result.
group_name (str) – The name of the symmetric memory process group to perform the reduction in.
reduce_op (str) – The reduction operation to perform. Currently, only "sum" is supported. Defaults to "sum".
This function reduces in_tile tensors from all members of the group, writing the result to out_tile at the root rank. All ranks must participate and provide the same group_name and tensor shapes.
Example:
>>> # Reduce the bottom-right quadrant of a tensor
>>> tile_size = full_size // 2
>>> full_inp = symm_mem.empty(full_size, full_size)
>>> full_out = symm_mem.empty(full_size, full_size)
>>> s = slice(tile_size, 2 * tile_size)
>>> in_tile = full_inp[s, s]
>>> out_tile = full_out[s, s]
>>> torch.ops.symm_mem.tile_reduce(in_tile, out_tile, root=0, group_name=group_name)
- torch.ops.symm_mem.multi_root_tile_reduce(in_tiles: list[Tensor], out_tile: Tensor, roots: list[int], group_name: str[, reduce_op: str = 'sum']) → None#
Perform multiple tile reductions concurrently, with each tile reduced to a separate root.
- Parameters:
in_tiles (list[Tensor]) – A list of input tensors.
out_tile (Tensor) – Output tensor to contain the reduced tile.
roots (list[int]) – A list of root ranks, each corresponding to an input tile in in_tiles, in the same order. A rank cannot be a root more than once.
group_name (str) – Name of the group to use for the collective operation.
reduce_op (str) – Reduction operation to perform. Currently only “sum” is supported.
Example:
>>> # Reduce four quadrants of a tensor, each to a different root
>>> tile_size = full_size // 2
>>> full_inp = symm_mem.empty(full_size, full_size)
>>> s0 = slice(0, tile_size)
>>> s1 = slice(tile_size, 2 * tile_size)
>>> in_tiles = [full_inp[s0, s0], full_inp[s0, s1], full_inp[s1, s0], full_inp[s1, s1]]
>>> out_tile = symm_mem.empty(tile_size, tile_size)
>>> roots = [0, 1, 2, 3]
>>> torch.ops.symm_mem.multi_root_tile_reduce(in_tiles, out_tile, roots, group_name)