Generic Join Context Manager
Created On: Jun 06, 2025 | Last Updated On: Jun 06, 2025
The generic join context manager facilitates distributed training on uneven inputs. This page outlines the API of the relevant classes: Join, Joinable, and JoinHook. For a tutorial, see Distributed Training with Uneven Inputs Using the Join Context Manager.
- class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)
This class defines the generic join context manager, which allows custom hooks to be called after a process joins.
These hooks should shadow the collective communications of non-joined processes to prevent hanging and erroring and to ensure algorithmic correctness. Refer to JoinHook for details about the hook definition.

Warning
The context manager requires each participating Joinable to call the method notify_join_context() before its own per-iteration collective communications to ensure correctness.

Warning
The context manager requires that all process_group attributes in the JoinHook objects are the same. If there are multiple JoinHook objects, then the device of the first is used. The process group and device information is used for checking for non-joined processes and for notifying processes to throw an exception if throw_on_early_termination is enabled, both of which use an all-reduce.

Parameters
- joinables (List[Joinable]) – a list of the participating Joinables; their hooks are iterated over in the given order.
- enable (bool) – a flag enabling uneven input detection; setting to False disables the context manager’s functionality and should only be set when the user knows the inputs will not be uneven (default: True).
- throw_on_early_termination (bool) – a flag controlling whether to throw an exception upon detecting uneven inputs (default: False).
Example:
>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring

- static notify_join_context(joinable)
Notifies the join context manager that the calling process has not yet joined.
Then, if throw_on_early_termination=True, checks if uneven inputs have been detected (i.e. if one process has already joined) and throws an exception if so.

This method should be called from a Joinable object before its per-iteration collective communications. For example, this should be called at the beginning of the forward pass in DistributedDataParallel.

Only the first Joinable object passed into the context manager performs the collective communications in this method, and for the others, this method is vacuous.

- class torch.distributed.algorithms.Joinable
This defines an abstract base class for joinable classes.
A joinable class (inheriting from Joinable) should implement join_hook(), which returns a JoinHook instance, in addition to join_device() and join_process_group() that return device and process group information, respectively.

- abstract property join_device: device
Return the device from which to perform collective communications needed by the join context manager.
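
The requirements above can be sketched with a small custom joinable. This is a minimal illustration, not the library's own implementation: the names Counter, CounterJoinHook, and run_single_process_demo are hypothetical, and the demo assumes a CPU-only, single-process gloo group initialized through a temporary file store so that it can run without a launcher. With only one process there are no uneven inputs, so the main hook is never needed here; the point is to show where notify_join_context() and the three required members fit.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook


class CounterJoinHook(JoinHook):
    """Hypothetical hook that shadows Counter's per-iteration all-reduce."""

    def __init__(self, counter):
        self.counter = counter

    def main_hook(self):
        # Contribute zero so the all-reduce issued by non-joined ranks completes.
        t = torch.zeros(1, device=self.counter.join_device)
        dist.all_reduce(t, group=self.counter.join_process_group)

    def post_hook(self, is_last_joiner):
        # Runs once after every process has joined.
        self.counter.finished = True


class Counter(Joinable):
    """Hypothetical joinable that counts inputs across all ranks."""

    def __init__(self, device, process_group):
        super().__init__()  # sets up the join configuration used by Join
        self._device = device
        self._process_group = process_group
        self.total = 0
        self.finished = False

    def __call__(self):
        # Must come before this object's own per-iteration collective.
        Join.notify_join_context(self)
        t = torch.ones(1, device=self._device)
        dist.all_reduce(t, group=self._process_group)
        self.total += int(t.item())

    def join_hook(self, **kwargs):
        return CounterJoinHook(self)

    @property
    def join_device(self):
        return self._device

    @property
    def join_process_group(self):
        return self._process_group


def run_single_process_demo(num_steps=3):
    """Run Counter under Join in a one-process gloo group (demo assumption)."""
    with tempfile.TemporaryDirectory() as tmp:
        dist.init_process_group(
            "gloo",
            init_method=f"file://{os.path.join(tmp, 'store')}",
            rank=0,
            world_size=1,
        )
        try:
            counter = Counter(torch.device("cpu"), dist.group.WORLD)
            with Join([counter]):
                for _ in range(num_steps):
                    counter()
            return counter.total, counter.finished
        finally:
            dist.destroy_process_group()
```

Calling run_single_process_demo() returns the accumulated count and whether the post-hook ran. In a real multi-process job, each rank would construct its own Counter and iterate over a different number of inputs inside the same with Join([counter]) block.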
- class torch.distributed.algorithms.JoinHook
This defines a join hook, which provides two entry points in the join context manager.
Entry points: a main hook, which is called repeatedly while there exists a non-joined process, and a post-hook, which is called once all processes have joined.
To implement a join hook for the generic join context manager, define a class that inherits from JoinHook and override main_hook() and post_hook() as appropriate.
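
To make the timing of the two entry points concrete, the following is a pure-Python simulation that does not use torch at all. SimulatedRank and simulate_join are hypothetical names, and the lockstep "rounds" are a simplifying assumption standing in for real collective communications. It only illustrates the scheduling described above: a rank that has run out of inputs calls its main hook once per round while some other rank is still training, and every rank calls its post-hook once at the end.

```python
from dataclasses import dataclass, field


@dataclass
class SimulatedRank:
    """Hypothetical stand-in for one process; not a torch class."""

    rank: int
    num_inputs: int
    log: list = field(default_factory=list)

    def train_step(self, step):
        self.log.append(f"step{step}")

    def main_hook(self):
        # In real code this would shadow one iteration's collectives.
        self.log.append("main_hook")

    def post_hook(self, is_last_joiner):
        self.log.append(f"post_hook(last={is_last_joiner})")


def simulate_join(ranks):
    """Advance all ranks in lockstep rounds until every rank has joined."""
    round_idx = 0
    while any(round_idx < r.num_inputs for r in ranks):
        for r in ranks:
            if round_idx < r.num_inputs:
                r.train_step(round_idx)  # non-joined: normal iteration
            else:
                r.main_hook()  # joined: shadow the others' collectives
        round_idx += 1
    # All ranks have joined; each runs its post-hook exactly once.
    most_inputs = max(r.num_inputs for r in ranks)
    for r in ranks:
        r.post_hook(is_last_joiner=(r.num_inputs == most_inputs))
    return ranks


ranks = simulate_join([SimulatedRank(0, 2), SimulatedRank(1, 3)])
for r in ranks:
    print(r.rank, r.log)
```

Here rank 0 has two inputs and rank 1 has three, so rank 0 runs its main hook once while rank 1 finishes its last step, and only rank 1 sees is_last_joiner=True in its post-hook.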