Generic Join Context Manager#

Created On: Jun 06, 2025 | Last Updated On: Jun 06, 2025

The generic join context manager facilitates distributed training on uneven inputs. This page outlines the API of the relevant classes: Join, Joinable, and JoinHook. For a tutorial, see Distributed Training with Uneven Inputs Using the Join Context Manager.

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[source]#

This class defines the generic join context manager, which allows custom hooks to be called after a process joins.

These hooks should shadow the collective communications of non-joined processes to prevent hanging and erroring and to ensure algorithmic correctness. Refer to JoinHook for details about the hook definition.

Warning

The context manager requires each participating Joinable to call the method notify_join_context() before its own per-iteration collective communications to ensure correctness.

Warning

The context manager requires that all process_group attributes in the JoinHook objects are the same. If there are multiple JoinHook objects, then the device of the first is used. The process group and device information is used for checking for non-joined processes and for notifying processes to throw an exception if throw_on_early_termination is enabled, both of which use an all-reduce.

Parameters
  • joinables (List[Joinable]) – a list of the participating Joinable s; their hooks are iterated over in the given order.

  • enable (bool) – a flag enabling uneven input detection; setting to False disables the context manager's functionality and should only be set when the user knows the inputs will not be uneven (default: True).

  • throw_on_early_termination (bool) – a flag controlling whether to throw an exception upon detecting uneven inputs (default: False).

Example:

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>> # All ranks reach here without hanging/erroring
static notify_join_context(joinable)[source]#

Notifies the join context manager that the calling process has not yet joined.

Then, if throw_on_early_termination=True, checks if uneven inputs have been detected (i.e. if one process has already joined) and throws an exception if so.

This method should be called from a Joinable object before its per-iteration collective communications. For example, this should be called at the beginning of the forward pass in DistributedDataParallel.

Only the first Joinable object passed into the context manager performs the collective communications in this method, and for the others, this method is vacuous.

Parameters

joinable (Joinable) – the Joinable object calling this method.

Returns

An async work handle for the all-reduce meant to notify the context manager that the process has not yet joined if joinable is the first one passed into the context manager; None otherwise.
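The "first Joinable performs the collective, all others are vacuous" behavior can be illustrated with a torch-free sketch. ToyJoinContext and ToyWorkHandle below are hypothetical stand-ins invented for illustration, not part of the PyTorch API; the real Join issues an actual asynchronous all-reduce here.

```python
class ToyWorkHandle:
    """Stand-in for the async work handle an all-reduce would return."""
    def wait(self):
        return True

class ToyJoinContext:
    """Hypothetical stand-in for Join: only the first joinable passed
    into the context manager performs the notifying collective."""
    def __init__(self, joinables):
        self._first = joinables[0] if joinables else None

    def notify_join_context(self, joinable):
        if joinable is self._first:
            # The real Join launches an async all-reduce here that tells
            # peer processes this process has not yet joined.
            return ToyWorkHandle()
        return None  # vacuous for every other joinable

model, optim = object(), object()
ctx = ToyJoinContext([model, optim])
handle = ctx.notify_join_context(model)  # first joinable: gets a handle
noop = ctx.notify_join_context(optim)    # any other joinable: None
```

This mirrors why, with Join([model, optim]), only the model's per-iteration notification results in communication while the optimizer's call is a no-op.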

class torch.distributed.algorithms.Joinable[source]#

This defines an abstract base class for joinable classes.

A joinable class (inheriting from Joinable) should implement join_hook(), which returns a JoinHook instance, in addition to the join_device and join_process_group properties, which return the device and process group information, respectively.
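The three required members can be sketched with a torch-free skeleton of the contract. The Joinable base class below is a simplified stand-in for the real ABC, and Counter is a hypothetical subclass (the string device and process-group values, and the placeholder join_hook return value, are illustration-only; a real implementation returns a torch.device, a real process group, and a JoinHook instance):

```python
from abc import ABC, abstractmethod

class Joinable(ABC):
    """Torch-free skeleton of the documented contract: subclasses must
    provide join_hook(), join_device, and join_process_group."""
    @abstractmethod
    def join_hook(self, **kwargs):
        ...

    @property
    @abstractmethod
    def join_device(self):
        ...

    @property
    @abstractmethod
    def join_process_group(self):
        ...

class Counter(Joinable):
    """Hypothetical joinable; a real one would shadow its per-iteration
    collectives via the hook it returns from join_hook()."""
    def __init__(self, device, process_group):
        self._device = device
        self._pg = process_group

    def join_hook(self, **kwargs):
        # A real implementation returns a JoinHook instance; this
        # placeholder only shows that kwargs are forwarded here.
        return ("CounterJoinHook", kwargs)

    @property
    def join_device(self):
        return self._device

    @property
    def join_process_group(self):
        return self._pg

c = Counter("cpu", "default_pg")
```

A subclass that omits any of the three members cannot be instantiated, which is how the ABC enforces the contract.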

abstract property join_device: device#

Return the device from which to perform collective communications needed by the join context manager.

abstract join_hook(**kwargs)[source]#

Return a JoinHook instance for the given Joinable.

Parameters

kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.

Return type

JoinHook

abstract property join_process_group: Any#

Return the process group for the collective communications needed by the join context manager itself.

class torch.distributed.algorithms.JoinHook[source]#

This defines a join hook, which provides two entry points into the join context manager: a main hook, which is called repeatedly while there exists a non-joined process, and a post-hook, which is called once all processes have joined.

To implement a join hook for the generic join context manager, define a class that inherits from JoinHook and override main_hook() and post_hook() as appropriate.

main_hook()[source]#

This hook is called repeatedly while there exists a non-joined process, to shadow the collective communications of one training iteration (i.e., one forward pass, backward pass, and optimizer step).

post_hook(is_last_joiner)[source]#

This hook is called after all processes have joined.

It is passed an additional bool argument is_last_joiner, which indicates if the rank is one of the last to join.

Parameters

is_last_joiner (bool) – True if the rank is one of the last to join; False otherwise.