nemo_rl.distributed.worker_groups

Module Contents

Classes

MultiWorkerFuture

Container for Ray futures with associated worker information.

RayWorkerBuilder

RayWorkerGroup

Manages a group of distributed Ray worker/actor processes that execute tasks in parallel.

API

class nemo_rl.distributed.worker_groups.MultiWorkerFuture

Container for Ray futures with associated worker information.

futures: list[ray.ObjectRef]

None

return_from_workers: Optional[list[int]]

None

called_workers: Optional[list[int]]

None

get_results(
worker_group: nemo_rl.distributed.worker_groups.RayWorkerGroup,
return_generators_as_proxies: bool = False,
) -> list[Any]

Get results from the futures, optionally respecting tied workers.

The method uses worker_group.worker_to_tied_group_index to identify which tied worker group each worker belongs to, then selects only the first result from each group.

Parameters:
  • worker_group – The RayWorkerGroup that spawned the futures. The mapping contained in worker_group.worker_to_tied_group_index is required for the deduplication path.

  • return_generators_as_proxies – If True and a future is an ObjectRefGenerator, return the ObjectRefGenerator itself instead of consuming it.

Returns:

List of results
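
The deduplication described above can be modeled in plain Python. This is a simplified sketch, not the library's implementation: it assumes the futures have already been resolved into results, that called_workers[i] is the worker that produced results[i], and that worker_to_tied_group_index behaves like a dict from worker index to tied-group index. The helper name dedup_by_tied_group is hypothetical, and the return_from_workers filter is not modeled.

```python
from typing import Any


def dedup_by_tied_group(
    results: list[Any],
    called_workers: list[int],
    worker_to_tied_group_index: dict[int, int],
) -> list[Any]:
    """Keep only the first result produced by each tied worker group.

    Hypothetical helper: results[i] is assumed to be the already-resolved
    value returned by worker called_workers[i].
    """
    seen_groups: set[int] = set()
    kept: list[Any] = []
    for worker_idx, result in zip(called_workers, results):
        group = worker_to_tied_group_index[worker_idx]
        if group in seen_groups:
            continue  # a sibling in this tied group already contributed
        seen_groups.add(group)
        kept.append(result)
    return kept


# Workers 0 and 1 form tied group 0; workers 2 and 3 form tied group 1.
print(dedup_by_tied_group(["a0", "a1", "b0", "b1"], [0, 1, 2, 3], {0: 0, 1: 0, 2: 1, 3: 1}))
# prints ['a0', 'b0']
```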

class nemo_rl.distributed.worker_groups.RayWorkerBuilder(ray_actor_class_fqn: str, *args, **kwargs)

Initialization

class IsolatedWorkerInitializer(
ray_actor_class_fqn: str,
*init_args,
**init_kwargs,
)

Initialization

create_worker(
placement_group: ray.util.placement_group.PlacementGroup,
placement_group_bundle_index: int,
num_gpus: int,
bundle_indices: Optional[tuple] = None,
**extra_options: Optional[dict[str, Any]],
)

Create a Ray worker with the specified configuration.

Order of precedence for worker options configuration (from lowest to highest):

  1. Options passed by the user to __call__ (extra_options)

  2. Options required by the worker via configure_worker (may override user options with a warning)

  3. Options set by RayWorkerBuilder.__call__ (specifically, the scheduling strategy)

If the worker needs to override user-provided options, it should log a warning to inform the user about the change and the reason for it.

Parameters:
  • placement_group – Ray placement group for resource allocation

  • placement_group_bundle_index – Index of the bundle in the placement group

  • num_gpus – Number of GPUs to allocate to this worker

  • bundle_indices – Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)

  • extra_options – Additional options to pass to the Ray actor (may be overridden by actor’s configure_worker(…) method)

Returns:

A Ray actor reference to the created worker
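
The three-level precedence above amounts to an ordered dictionary merge. The sketch below assumes each layer is a plain dict of Ray actor options; merge_worker_options is a hypothetical name, and the function only illustrates the documented ordering (user options, then worker-required options with a warning on conflict, then builder-set options such as the scheduling strategy). The option values are placeholders.

```python
import logging
from typing import Any

logger = logging.getLogger(__name__)


def merge_worker_options(
    user_options: dict[str, Any],
    worker_required: dict[str, Any],
    builder_options: dict[str, Any],
) -> dict[str, Any]:
    """Merge option layers from lowest to highest precedence (hypothetical)."""
    merged = dict(user_options)  # 1. user-supplied extra_options
    for key, value in worker_required.items():  # 2. configure_worker requirements
        if key in merged and merged[key] != value:
            logger.warning("worker overrides %r: %r -> %r", key, merged[key], value)
        merged[key] = value
    merged.update(builder_options)  # 3. builder-set options always win
    return merged


opts = merge_worker_options(
    {"num_cpus": 4, "runtime_env": {"env_vars": {"A": "1"}}},
    {"runtime_env": {"env_vars": {"B": "2"}}},  # conflicts with the user value, so a warning is logged
    {"scheduling_strategy": "placement-group-strategy"},  # placeholder value
)
print(opts)
```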

create_worker_async(
placement_group: ray.util.placement_group.PlacementGroup,
placement_group_bundle_index: int,
num_gpus: float | int,
bundle_indices: Optional[tuple[int, list[int]]] = None,
**extra_options: Any,
) -> tuple[ray.ObjectRef, ray.actor.ActorHandle]

Create a Ray worker asynchronously, returning futures.

This method returns immediately with futures that can be awaited later.

Parameters:
  • placement_group – Ray placement group for resource allocation

  • placement_group_bundle_index – Index of the bundle in the placement group

  • num_gpus – Number of GPUs to allocate to this worker (can be fractional)

  • bundle_indices – Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)

  • extra_options – Additional options to pass to the Ray actor

Returns:

  • worker_future: A Ray ObjectRef that will resolve to the worker actor

  • initializer_actor: The initializer actor (needed to prevent GC)

Return type:

Tuple of (worker_future, initializer_actor)
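
The value of create_worker_async is that many workers can be launched without blocking on each one, while the caller keeps the initializer actor referenced (Ray can terminate a non-detached actor once no handle refers to it). The sketch below mimics that launch-then-wait pattern with concurrent.futures instead of Ray so it runs anywhere; FakeInitializer and create_worker_async_model are hypothetical stand-ins, and with Ray the worker futures would be resolved with ray.get rather than Future.result.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


class FakeInitializer:
    """Hypothetical stand-in for the isolated initializer actor."""

    def __init__(self, bundle_index: int) -> None:
        self.bundle_index = bundle_index

    def build(self) -> str:
        time.sleep(0.01)  # pretend actor start-up takes a while
        return f"worker-{self.bundle_index}"


def create_worker_async_model(
    pool: ThreadPoolExecutor, bundle_index: int
) -> tuple[Future, FakeInitializer]:
    """Start building a worker and return (future, initializer) immediately."""
    initializer = FakeInitializer(bundle_index)
    return pool.submit(initializer.build), initializer


def create_all(num_workers: int) -> list[str]:
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # Launch every worker first; nothing blocks in this loop.
        pending = [create_worker_async_model(pool, i) for i in range(num_workers)]
        # Only now wait. The initializers stay referenced through `pending`.
        return [future.result() for future, _initializer in pending]


print(create_all(3))  # prints ['worker-0', 'worker-1', 'worker-2']
```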

__call__(
placement_group: ray.util.placement_group.PlacementGroup,
placement_group_bundle_index: int,
num_gpus: float | int,
bundle_indices: Optional[tuple[int, list[int]]] = None,
**extra_options: Any,
) -> ray.actor.ActorHandle

Create a Ray worker with the specified configuration.

Order of precedence for worker options configuration (from lowest to highest):

  1. Options passed by the user to __call__ (extra_options)

  2. Options required by the worker via configure_worker (may override user options with a warning)

  3. Options set by RayWorkerBuilder.__call__ (specifically, the scheduling strategy)

If the worker needs to override user-provided options, it should log a warning to inform the user about the change and the reason for it.

Parameters:
  • placement_group – Ray placement group for resource allocation

  • placement_group_bundle_index – Index of the bundle in the placement group

  • num_gpus – Number of GPUs to allocate to this worker (can be fractional)

  • bundle_indices – Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)

  • extra_options – Additional options to pass to the Ray actor (may be overridden by actor’s configure_worker(…) method)

Returns:

A Ray actor reference to the created worker

class nemo_rl.distributed.worker_groups.RayWorkerGroup(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder,
workers_per_node: Optional[Union[int, list[int]]] = None,
name_prefix: str = '',
bundle_indices_list: Optional[list[tuple[int, list[int]]]] = None,
sharding_annotations: Optional[nemo_rl.distributed.named_sharding.NamedSharding] = None,
env_vars: dict[str, str] = {},
)

Manages a group of distributed Ray worker/actor processes that execute tasks in parallel.

This class creates and manages Ray actor instances that run on resources allocated by a RayVirtualCluster. It handles:

  • Worker creation and placement on specific GPU resources

  • Setting up distributed training environment variables (rank, world size, etc.)

  • Executing methods across all workers in parallel

  • Collecting and aggregating results

  • Support for tied worker groups where multiple workers process the same data

Initialization

Initialize a group of distributed Ray workers.

Parameters:
  • cluster – The RayVirtualCluster whose resources the workers are placed on

  • remote_worker_builder – Callable (a RayWorkerBuilder) that launches a Ray worker and has updatable options

  • workers_per_node – Defaults to launching one worker per bundle in the cluster. Alternatively, specify an int or a list to launch a different number of workers per node.

  • name_prefix – Optional prefix for the names of the workers

  • bundle_indices_list – Explicit list of (node_idx, [local_bundle_indices]) tuples. Each tuple defines a tied group of workers placed on the same node. If provided, workers_per_node is ignored.

  • sharding_annotations – NamedSharding object mapping named axes to ranks (e.g., for TP, PP)
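
How workers_per_node could be interpreted is sketched below, under stated assumptions: the cluster is described by a hypothetical bundles_per_node list, and requesting more workers than a node has bundles is treated as an error (a guess at validation the real class may or may not perform). expand_workers_per_node is a hypothetical name, and when bundle_indices_list is given this expansion would not apply.

```python
from typing import Optional, Union


def expand_workers_per_node(
    bundles_per_node: list[int],
    workers_per_node: Optional[Union[int, list[int]]] = None,
) -> list[int]:
    """Return how many workers to launch on each node (hypothetical helper)."""
    if workers_per_node is None:
        # Documented default: one worker per bundle.
        return list(bundles_per_node)
    if isinstance(workers_per_node, int):
        counts = [workers_per_node] * len(bundles_per_node)  # same count on every node
    else:
        if len(workers_per_node) != len(bundles_per_node):
            raise ValueError("workers_per_node needs one entry per node")
        counts = list(workers_per_node)
    for node, (wanted, available) in enumerate(zip(counts, bundles_per_node)):
        if wanted > available:  # assumed validation
            raise ValueError(f"node {node}: {wanted} workers but only {available} bundles")
    return counts


print(expand_workers_per_node([8, 8]))          # prints [8, 8]
print(expand_workers_per_node([8, 8], 4))       # prints [4, 4]
print(expand_workers_per_node([8, 8], [2, 6]))  # prints [2, 6]
```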

get_dp_leader_worker_idx(dp_shard_idx: int) -> int

Returns the index of the primary worker for a given data parallel shard.
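
One way to picture the data-parallel leader: if workers are laid out on a named grid in row-major order, the leader of shard i is the worker at coordinate i on the data-parallel axis and 0 on every other axis. The sketch below assumes exactly that layout; dp_leader_worker_idx and the axis name "dp" are illustrative only, and the real method works from the group's sharding annotations, whose layout may differ.

```python
def dp_leader_worker_idx(
    axis_names: list[str], axis_sizes: list[int], dp_axis: str, dp_shard_idx: int
) -> int:
    """Row-major index of the worker at (dp_axis=dp_shard_idx, all other axes=0).

    Hypothetical helper: assumes workers are numbered in row-major order
    over the named axes.
    """
    rank = 0
    for name, size in zip(axis_names, axis_sizes):
        coord = dp_shard_idx if name == dp_axis else 0
        if not 0 <= coord < size:
            raise IndexError(f"index {coord} out of range for axis {name!r} of size {size}")
        rank = rank * size + coord  # Horner-style row-major flattening
    return rank


# 4 data-parallel shards x 2 tensor-parallel ranks: leaders are workers 0, 2, 4, 6.
print([dp_leader_worker_idx(["dp", "tp"], [4, 2], "dp", i) for i in range(4)])
# prints [0, 2, 4, 6]
```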

_create_workers_from_bundle_indices(
remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder,
bundle_indices_list: list[tuple[int, list[int]]],
env_vars: dict[str, str] = {},
) -> None

Create workers based on explicit bundle indices for tied worker groups.

Parameters:
  • remote_worker_builder – Builder function for Ray actors

  • bundle_indices_list – List of (node_idx, local_bundle_indices) tuples, where each tuple specifies a tied group with its node and local bundle indices. If the local bundle indices span multiple nodes, node_idx is the index of the first node in the tied group.
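
The bookkeeping implied by tied groups can be sketched as flattening bundle_indices_list into a worker list plus a worker-to-group map, the same kind of map that get_results consults through worker_to_tied_group_index. tied_group_layout is a hypothetical helper: it records every worker against its group's node_idx, which matches the description above only when a group does not span nodes, and it returns the mapping as a plain dict.

```python
def tied_group_layout(
    bundle_indices_list: list[tuple[int, list[int]]],
) -> tuple[list[tuple[int, int]], dict[int, int]]:
    """Flatten tied groups into launch order plus a worker -> group map.

    Hypothetical helper. Returns:
        workers: (node_idx, local_bundle_idx) per worker, in launch order.
        worker_to_tied_group_index: global worker index -> tied group index.
    """
    workers: list[tuple[int, int]] = []
    worker_to_tied_group_index: dict[int, int] = {}
    for group_idx, (node_idx, local_bundle_indices) in enumerate(bundle_indices_list):
        for local_bundle_idx in local_bundle_indices:
            # The next global worker index is the current length of the list.
            worker_to_tied_group_index[len(workers)] = group_idx
            workers.append((node_idx, local_bundle_idx))
    return workers, worker_to_tied_group_index


# Two tied groups of 2 on node 0, one tied group of 2 on node 1.
workers, groups = tied_group_layout([(0, [0, 1]), (0, [2, 3]), (1, [0, 1])])
print(groups)  # prints {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}
```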

property workers: list[ray.actor.ActorHandle]

property worker_metadata: list[dict[str, Any]]

property dp_size: int

Number of data parallel shards.

run_single_worker_single_data(
method_name: str,
worker_idx: int,
*args,
**kwargs,
) -> ray.ObjectRef

Run a method on a single, specific worker.

Parameters:
  • method_name – Name of the method to call on the worker.

  • worker_idx – The index of the worker to run the method on.

  • *args – Arguments to pass to the method.

  • **kwargs – Keyword arguments to pass to the method.

Returns:

A Ray future for the result.

Return type:

ray.ObjectRef

run_all_workers_multiple_data(
method_name: str,
*args,
run_rank_0_only_axes: list[str] | None = None,
common_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> list[ray.ObjectRef]

Run a method on all workers in parallel with different data.

Parameters:
  • method_name – Name of the method to call on each worker

  • *args – List of arguments to pass to workers/groups, e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]

  • run_rank_0_only_axes – List of named axes for which only rank 0 should run the method.

  • common_kwargs – Keyword arguments to pass to all workers

  • **kwargs – Keyword arguments to pass to workers/groups, e.g. {“key1”: [value_for_worker_1, value_for_worker_2], “key2”: [value_for_worker_1, value_for_worker_2]}

Returns:

A list of ray futures

Return type:

list[ray.ObjectRef]
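
The argument layout for run_all_workers_multiple_data can be illustrated by building the (args, kwargs) each worker would receive: entry w of every positional list and every keyword list goes to worker w, and common_kwargs goes to all of them. per_worker_calls is a hypothetical helper; it ignores run_rank_0_only_axes, and letting common_kwargs win on a key collision is an assumption.

```python
from typing import Any, Optional


def per_worker_calls(
    num_workers: int,
    *args: list[Any],
    common_kwargs: Optional[dict[str, Any]] = None,
    **kwargs: list[Any],
) -> list[tuple[tuple[Any, ...], dict[str, Any]]]:
    """Return the (args, kwargs) that each worker would receive (hypothetical)."""
    for i, values in enumerate(args):
        if len(values) != num_workers:
            raise ValueError(f"positional arg {i}: {len(values)} entries, expected {num_workers}")
    for key, values in kwargs.items():
        if len(values) != num_workers:
            raise ValueError(f"kwarg {key!r}: {len(values)} entries, expected {num_workers}")
    calls = []
    for w in range(num_workers):
        worker_args = tuple(values[w] for values in args)  # w-th entry of each list
        worker_kwargs = {key: values[w] for key, values in kwargs.items()}
        worker_kwargs.update(common_kwargs or {})  # shared kwargs (assumed to win on collision)
        calls.append((worker_args, worker_kwargs))
    return calls


for call in per_worker_calls(2, ["a1", "a2"], ["b1", "b2"], common_kwargs={"lr": 0.1}, key1=["k1", "k2"]):
    print(call)
# prints (('a1', 'b1'), {'key1': 'k1', 'lr': 0.1})
#        (('a2', 'b2'), {'key1': 'k2', 'lr': 0.1})
```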

run_all_workers_single_data(
method_name: str,
*args,
run_rank_0_only_axes: list[str] | None = None,
**kwargs,
) -> list[ray.ObjectRef]

Run a method on all workers in parallel with the same data.

Parameters:
  • method_name – Name of the method to call on each worker

  • *args – Arguments to pass to the method

  • **kwargs – Keyword arguments to pass to the method

  • run_rank_0_only_axes – List of named axes for which only rank 0 should run the method.

Returns:

A list of ray futures

Return type:

list[ray.ObjectRef]
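
run_rank_0_only_axes restricts which workers execute the call. Assuming workers are laid out on a named grid in row-major order, the selected workers are those whose coordinate is 0 on every listed axis, as in the hypothetical helper workers_to_call below; rejecting unknown axis names is also an assumption.

```python
from itertools import product
from typing import Optional


def workers_to_call(
    axis_names: list[str],
    axis_sizes: list[int],
    run_rank_0_only_axes: Optional[list[str]] = None,
) -> list[int]:
    """Row-major indices of the workers that would run the method (hypothetical)."""
    rank_0_only = set(run_rank_0_only_axes or [])
    unknown = rank_0_only - set(axis_names)
    if unknown:
        raise ValueError(f"unknown axes: {sorted(unknown)}")
    selected = []
    # product() enumerates coordinates in row-major order (last axis fastest).
    for idx, coords in enumerate(product(*(range(size) for size in axis_sizes))):
        if all(c == 0 for name, c in zip(axis_names, coords) if name in rank_0_only):
            selected.append(idx)
    return selected


print(workers_to_call(["dp", "tp"], [2, 2]))          # prints [0, 1, 2, 3]
print(workers_to_call(["dp", "tp"], [2, 2], ["tp"]))  # prints [0, 2]
```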

run_all_workers_sharded_data(
method_name: str,
*args,
in_sharded_axes: list[str] | None = None,
replicate_on_axes: list[str] | None = None,
output_is_replicated: list[str] | None = None,
make_dummy_calls_to_free_axes: bool = False,
common_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> nemo_rl.distributed.worker_groups.MultiWorkerFuture

Run a method on all workers in parallel with sharded data.

  • Axes in in_sharded_axes: data is already split across these axes, so each worker receives the appropriate slice along the axis.

  • Axes in replicate_on_axes: data is replicated to all workers along these axes.

  • Free axes (axes in neither list): data is only sent to workers at index 0 of these axes.

Parameters:
  • method_name – Name of the method to call on each worker

  • *args – List of arguments to pass to workers/groups, e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]

  • in_sharded_axes – List of axes that are sharded

  • replicate_on_axes – List of axes that are to be replicated

  • output_is_replicated – List of axes along which the output is replicated (so only the first result is returned). Results are also returned only from rank 0 of free axes.

  • make_dummy_calls_to_free_axes – Whether to make dummy calls (with None) to workers that aren’t rank 0 on ‘free axes’ (axes not in in_sharded_axes or replicate_on_axes).

  • common_kwargs – Keyword arguments to pass to all workers

  • **kwargs – Keyword arguments to pass to workers/groups, e.g. {“key1”: [value_for_worker_1, value_for_worker_2], “key2”: [value_for_worker_1, value_for_worker_2]}

Returns:

Object containing futures and their associated worker information

Return type:

MultiWorkerFuture
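
The routing rules for in_sharded_axes, replicate_on_axes, and free axes can be modeled on a small grid. This sketch assumes a row-major worker layout and that the pre-split data is a nested list indexed by the coordinates of in_sharded_axes in the order listed. route_sharded_data is hypothetical and ignores Ray futures, output_is_replicated, and common_kwargs.

```python
from itertools import product
from typing import Any


def route_sharded_data(
    axis_names: list[str],
    axis_sizes: list[int],
    data: Any,
    in_sharded_axes: list[str],
    replicate_on_axes: list[str],
    make_dummy_calls_to_free_axes: bool = False,
) -> dict[int, Any]:
    """Map each called worker's row-major index to its payload (hypothetical).

    data is nested lists indexed by the coordinates of in_sharded_axes,
    in the order those axes are listed.
    """
    free_axes = set(axis_names) - set(in_sharded_axes) - set(replicate_on_axes)
    routed: dict[int, Any] = {}
    for idx, coords in enumerate(product(*(range(size) for size in axis_sizes))):
        coord_of = dict(zip(axis_names, coords))
        if any(coord_of[axis] != 0 for axis in free_axes):
            # Not index 0 on some free axis: no real data for this worker.
            if make_dummy_calls_to_free_axes:
                routed[idx] = None  # dummy call with None
            continue
        payload = data
        for axis in in_sharded_axes:
            payload = payload[coord_of[axis]]  # pick this worker's slice
        # Replicated axes never index into data, so every index along them
        # receives the same slice.
        routed[idx] = payload
    return routed


# pp=2, dp=2, tp=2: data pre-split along dp, replicated along tp, pp is free.
print(route_sharded_data(["pp", "dp", "tp"], [2, 2, 2], ["shard0", "shard1"], ["dp"], ["tp"]))
# prints {0: 'shard0', 1: 'shard0', 2: 'shard1', 3: 'shard1'}
```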

get_all_worker_results(
future_bundle: nemo_rl.distributed.worker_groups.MultiWorkerFuture,
return_generators_as_proxies: bool = False,
) -> list[Any]

Get results from all workers, optionally filtering to get just one result per tied worker group.

Parameters:
  • future_bundle – MultiWorkerFuture containing futures and worker information.

  • return_generators_as_proxies – If True and a future in the bundle is an ObjectRefGenerator, return the ObjectRefGenerator itself instead of consuming it.

Returns:

List of results, deduplicated as specified in the future_bundle

shutdown(
cleanup_method: Optional[str] = None,
timeout: Optional[float] = 30.0,
force: bool = False,
) -> bool

Shut down all workers in the worker group.

Parameters:
  • cleanup_method – Optional method name to call on each worker before termination. If provided, this method is called on each worker to allow for graceful cleanup.

  • timeout – Timeout in seconds for graceful shutdown. Only applicable if cleanup_method is provided. If None, wait indefinitely for workers to complete their cleanup.

  • force – If True, forcefully terminate workers with ray.kill() even if cleanup_method is provided. If cleanup_method is None, workers are always forcefully terminated.

Returns:

True if all workers were successfully shut down

Return type:

bool
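
The documented shutdown options can be modeled with a thread pool standing in for the Ray workers. In this hypothetical sketch, terminate(worker_idx, forceful) stands in for ending a worker, with forceful=True playing the role of ray.kill(). Skipping cleanup when force is True, and falling back to a forceful termination when cleanup raises or times out, are assumptions about behavior the documentation does not spell out.

```python
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Callable, Optional


def shutdown_model(
    num_workers: int,
    terminate: Callable[[int, bool], None],
    cleanup: Optional[Callable[[int], None]] = None,
    timeout: Optional[float] = 30.0,
    force: bool = False,
) -> bool:
    """Simplified, hypothetical model of the documented shutdown options.

    terminate(worker_idx, forceful) stands in for ending a worker;
    forceful=True plays the role of ray.kill().
    """
    ok = True
    forceful = force or cleanup is None  # no cleanup method: always forceful
    if cleanup is not None and not force:
        pool = ThreadPoolExecutor(max_workers=max(1, num_workers))
        futures = [pool.submit(cleanup, w) for w in range(num_workers)]
        # timeout=None waits indefinitely for every cleanup call.
        done, not_done = wait(futures, timeout=timeout)
        pool.shutdown(wait=False, cancel_futures=True)
        if not_done or any(f.exception() is not None for f in done):
            ok = False
            forceful = True  # assumed fallback: terminate every worker forcefully
    for w in range(num_workers):
        try:
            terminate(w, forceful)
        except Exception:
            ok = False  # a worker that fails to terminate makes the result False
    return ok
```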