
FullyShardedDataParallel#

Created On: Feb 02, 2022 | Last Updated On: Jun 11, 2025

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source]#

A wrapper for sharding module parameters across data parallel workers.

This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.

Example:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

Using FSDP involves wrapping your module and then initializing your optimizer after. This is required since FSDP changes the parameter variables.

When setting up FSDP, you need to consider the destination CUDA device. If the device has an ID (dev_id), you have three options:

  • Place the module on that device

  • Set the device using torch.cuda.set_device(dev_id)

  • Pass dev_id into the device_id constructor argument.

This ensures that the FSDP instance’s compute device is the destination device. For options 1 and 3, the FSDP initialization always occurs on GPU. For option 2, the FSDP initialization happens on the module’s current device, which may be a CPU.

If you’re using the sync_module_states=True flag, you need to ensure that the module is on a GPU or use the device_id argument to specify a CUDA device that FSDP will move the module to in the FSDP constructor. This is necessary because sync_module_states=True requires GPU communication.
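For example, a minimal setup sketch, assuming a torchrun-style launcher that sets the LOCAL_RANK environment variable, that the default process group is already initialized, and that my_module is a plain nn.Module fully initialized on rank 0:

>>> import os
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> local_rank = int(os.environ["LOCAL_RANK"])  # assumption: set by the launcher
>>> torch.cuda.set_device(local_rank)
>>> sharded_module = FSDP(
>>>     my_module,
>>>     device_id=local_rank,        # option 3: FSDP moves the module to this device
>>>     sync_module_states=True,     # broadcast rank 0's params/buffers to all ranks
>>> )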

FSDP also takes care of moving input tensors passed to the forward method to the GPU compute device, so you don’t need to manually move them from CPU.

For use_orig_params=True, ShardingStrategy.SHARD_GRAD_OP exposes the unsharded parameters, not the sharded parameters, after forward, unlike ShardingStrategy.FULL_SHARD. If you want to inspect the gradients, you can use the summon_full_params method with with_grads=True.
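For instance, a minimal gradient-inspection sketch, assuming sharded_module was constructed with use_orig_params=True and a backward pass has already run:

>>> with FSDP.summon_full_params(sharded_module, with_grads=True):
>>>     for name, param in sharded_module.named_parameters():
>>>         if param.grad is not None:
>>>             print(name, param.grad.norm())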

With limit_all_gathers=True, you may see a gap in the FSDP pre-forward where the CPU thread is not issuing any kernels. This is intentional and shows the rate limiter in effect. Synchronizing the CPU thread in that way prevents over-allocating memory for subsequent all-gathers, and it should not actually delay GPU kernel execution.

FSDP replaces managed modules’ parameters with torch.Tensor views during forward and backward computation for autograd-related reasons. If your module’s forward relies on saved references to the parameters instead of reacquiring the references each iteration, then it will not see FSDP’s newly created views, and autograd will not work correctly.

Finally, when using sharding_strategy=ShardingStrategy.HYBRID_SHARD with the sharding process group being intra-node and the replication process group being inter-node, setting NCCL_CROSS_NIC=1 can help improve the all-reduce times over the replication process group for some cluster setups.

Limitations

There are several limitations to be aware of when using FSDP:

  • FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. This is because FSDP uses the newly-reduced gradient instead of accumulating with any existing gradient, which can lead to incorrect results.

  • FSDP does not support running the forward pass of a submodule that is contained in an FSDP instance. This is because the submodule’s parameters will be sharded, but the submodule itself is not an FSDP instance, so its forward pass will not all-gather the full parameters appropriately.

  • FSDP does not work with double backwards due to the way it registers backward hooks.

  • FSDP has some constraints when freezing parameters. For use_orig_params=False, each FSDP instance must manage parameters that are all frozen or all non-frozen. For use_orig_params=True, FSDP supports mixing frozen and non-frozen parameters, but it’s recommended to avoid doing so to prevent higher-than-expected gradient memory usage.

  • As of PyTorch 1.12, FSDP offers limited support for shared parameters. If enhanced shared parameter support is needed for your use case, please post in this issue.

  • You should avoid modifying the parameters between forward and backward without using the summon_full_params context, as the modifications may not persist.

Parameters:
  • module (nn.Module) – This is the module to be wrapped with FSDP.

  • process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – This is the process group over which the model is sharded and thus the one used for FSDP’s all-gather and reduce-scatter collective communications. If None, then FSDP uses the default process group. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a tuple of process groups, representing the groups over which to shard and replicate, respectively. If None, then FSDP constructs process groups for the user to shard intra-node and replicate inter-node. (Default: None)

  • sharding_strategy (Optional[ShardingStrategy]) – This configures the sharding strategy, which may trade off memory saving and communication overhead. See ShardingStrategy for details, and see the construction sketch after this parameter list. (Default: FULL_SHARD)

  • cpu_offload (Optional[CPUOffload]) – This configures CPU offloading. If this is set to None, then no CPU offloading happens. See CPUOffload for details. (Default: None)

  • auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –

    This specifies a policy to apply FSDP to submodules of module, which is needed for communication and computation overlap and thus affects performance. If None, then FSDP only applies to module, and users should manually apply FSDP to parent modules themselves (proceeding bottom-up). For convenience, this accepts ModuleWrapPolicy directly, which allows users to specify the module classes to wrap (e.g. the transformer block). Otherwise, this should be a callable that takes in three arguments module: nn.Module, recurse: bool, and nonwrapped_numel: int and should return a bool specifying whether the passed-in module should have FSDP applied if recurse=False or if the traversal should continue into the module’s subtree if recurse=True. Users may add additional arguments to the callable. The size_based_auto_wrap_policy in torch.distributed.fsdp.wrap.py gives an example callable that applies FSDP to a module if the parameters in its subtree exceed 100M numel. We recommend printing the model after applying FSDP and adjusting as needed.

    Example:

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     nonwrapped_numel: int,
    >>>     # Additional custom arguments
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return nonwrapped_numel >= min_num_params
    >>> # Configure a custom `min_num_params`
    >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))

  • backward_prefetch (Optional[BackwardPrefetch]) – This configures explicit backward prefetching of all-gathers. If None, then FSDP does not backward prefetch, and there is no communication and computation overlap in the backward pass. See BackwardPrefetch for details. (Default: BACKWARD_PRE)

  • mixed_precision (Optional[MixedPrecision]) – This configures native mixed precision for FSDP. If this is set to None, then no mixed precision is used. Otherwise, parameter, buffer, and gradient reduction dtypes can be set. See MixedPrecision for details. (Default: None)

  • ignored_modules (Optional[Iterable[torch.nn.Module]]) – Modules whose own parameters and child modules’ parameters and buffers are ignored by this instance. None of the modules directly in ignored_modules should be FullyShardedDataParallel instances, and any child modules that are already-constructed FullyShardedDataParallel instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters at module granularity when using an auto_wrap_policy or if parameters’ sharding is not managed by FSDP. (Default: None)

  • param_init_fn (Optional[Callable[[nn.Module],None]]) –

    A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. As of v1.12, FSDP detects modules with parameters or buffers on meta device via is_meta and either applies param_init_fn if specified or calls nn.Module.reset_parameters() otherwise. For both cases, the implementation should only initialize the parameters/buffers of the module, not those of its submodules. This is to avoid re-initialization. In addition, FSDP also supports deferred initialization via torchdistX’s (pytorch/torchdistX) deferred_init() API, where the deferred modules are initialized by calling param_init_fn if specified or torchdistX’s default materialize_module() otherwise. If param_init_fn is specified, then it is applied to all meta-device modules, meaning that it should probably dispatch based on the module type. FSDP calls the initialization function before parameter flattening and sharding.

    Example:

    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module: nn.Module):
    >>>     # E.g. initialize depending on the module type
    >>>     ...
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device)  # current CUDA device
    >>> # With torchdistX
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)

  • device_id (Optional[Union[int, torch.device]]) – An int or torch.device giving the CUDA device on which FSDP initialization takes place, including the module initialization if needed and the parameter sharding. This should be specified to improve initialization speed if module is on CPU. If the default CUDA device was set (e.g. via torch.cuda.set_device), then the user may pass torch.cuda.current_device to this. (Default: None)

  • sync_module_states (bool) – If True, then each FSDP module will broadcast module parameters and buffers from rank 0 to ensure that they are replicated across ranks (adding communication overhead to this constructor). This can help load state_dict checkpoints via load_state_dict in a memory-efficient way. See FullStateDictConfig for an example of this. (Default: False)

  • forward_prefetch (bool) – If True, then FSDP explicitly prefetches the next forward-pass all-gather before the current forward computation. This is only useful for CPU-bound workloads, in which case issuing the next all-gather earlier may improve overlap. This should only be used for static-graph models since the prefetching follows the first iteration’s execution order. (Default: False)

  • limit_all_gathers (bool) – If True, then FSDP explicitly synchronizes the CPU thread to ensure GPU memory usage from only two consecutive FSDP instances (the current instance running computation and the next instance whose all-gather is prefetched). If False, then FSDP allows the CPU thread to issue all-gathers without any extra synchronization. (Default: True) We often refer to this feature as the “rate limiter”. This flag should only be set to False for specific CPU-bound workloads with low memory pressure, in which case the CPU thread can aggressively issue all kernels without concern for the GPU memory usage.

  • use_orig_params (bool) – Setting this to True has FSDP use module’s original parameters. FSDP exposes those original parameters to the user via nn.Module.named_parameters() instead of FSDP’s internal FlatParameter s. This means that the optimizer step runs on the original parameters, enabling per-original-parameter hyperparameters. FSDP preserves the original parameter variables and manipulates their data between unsharded and sharded forms, where they are always views into the underlying unsharded or sharded FlatParameter, respectively. With the current algorithm, the sharded form is always 1D, losing the original tensor structure. An original parameter may have all, some, or none of its data present for a given rank. In the none case, its data will be like a size-0 empty tensor. Users should not author programs relying on what data is present for a given original parameter in its sharded form. True is required to use torch.compile(). Setting this to False exposes FSDP’s internal FlatParameter s to the user via nn.Module.named_parameters(). (Default: False)

  • ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – Ignored parameters or modules that will not be managed by this FSDP instance, meaning that the parameters are not sharded and their gradients are not reduced across ranks. This argument unifies with the existing ignored_modules argument, and we may deprecate ignored_modules soon. For backward compatibility, we keep both ignored_states and ignored_modules, but FSDP only allows one of them to be specified as not None.

  • device_mesh (Optional[DeviceMesh]) – DeviceMesh can be used as an alternative to process_group. When device_mesh is passed, FSDP will use the underlying process groups for all-gather and reduce-scatter collective communications. Therefore, these two arguments need to be mutually exclusive. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a 2D DeviceMesh instead of a tuple of process groups. For 2D FSDP + TP, users are required to pass in device_mesh instead of process_group. For more DeviceMesh info, please visit: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
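Putting several of these arguments together, a minimal construction sketch (TransformerBlock and my_module are placeholder names for user model code; the exact argument values are illustrative, not prescriptive):

>>> import torch
>>> from torch.distributed.fsdp import (
>>>     FullyShardedDataParallel as FSDP,
>>>     MixedPrecision,
>>>     ShardingStrategy,
>>> )
>>> from torch.distributed.fsdp.wrap import ModuleWrapPolicy
>>> sharded_module = FSDP(
>>>     my_module,
>>>     sharding_strategy=ShardingStrategy.FULL_SHARD,
>>>     auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),   # wrap each transformer block
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
>>>     device_id=torch.cuda.current_device(),
>>>     use_orig_params=True,   # needed if you plan to use torch.compile()
>>> )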

apply(fn)[source]#

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also torch.nn.init).

Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.
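A minimal sketch, assuming fsdp_model is an already constructed FSDP instance and the initialization function is illustrative:

>>> import torch.nn as nn
>>> def init_weights(m: nn.Module):
>>>     # Runs on the gathered (full) parameters of each submodule.
>>>     if isinstance(m, nn.Linear):
>>>         nn.init.xavier_uniform_(m.weight)
>>> fsdp_model.apply(init_weights)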

Parameters:

fn (Module -> None) – function to be applied to each submodule

Returns:

self

Return type:

Module

check_is_root()[source]#

Check if this instance is a root FSDP module.

Return type:

bool

clip_grad_norm_(max_norm, norm_type=2.0)[source]#

Clip the gradient norm of all parameters.

The norm is computed over all parameters’ gradients as viewed as a single vector, and the gradients are modified in-place.

Parameters:
  • max_norm (float or int) – max norm of the gradients

  • norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.

Returns:

Total norm of the parameters (viewed as a single vector).

Return type:

Tensor

If every FSDP instance uses NO_SHARD, meaning that no gradients are sharded across ranks, then you may directly use torch.nn.utils.clip_grad_norm_().

If at least some FSDP instance uses a sharded strategy (i.e. one other than NO_SHARD), then you should use this method instead of torch.nn.utils.clip_grad_norm_() since this method handles the fact that gradients are sharded across ranks.

The total norm returned will have the “largest” dtype across all parameters/gradients as defined by PyTorch’s type promotion semantics. For example, if all parameters/gradients use a low precision dtype, then the returned norm’s dtype will be that low precision dtype, but if there exists at least one parameter/gradient using FP32, then the returned norm’s dtype will be FP32.

Warning

This needs to be called on all ranks since it uses collective communications.
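A minimal training-loop sketch (sharded_module, optim, and the input x carry over from the constructor example above and are illustrative):

>>> loss = sharded_module(x).sum()
>>> loss.backward()
>>> total_norm = sharded_module.clip_grad_norm_(max_norm=1.0)  # collective; call on all ranks
>>> optim.step()
>>> optim.zero_grad()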

static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source]#

Flatten a sharded optimizer state-dict.

The API is similar to shard_full_optim_state_dict(). The only difference is that the input sharded_optim_state_dict should be returned from sharded_optim_state_dict(). Therefore, there will be all-gather calls on each rank to gather ShardedTensor s.

Parameters:
  • sharded_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the sharded optimizer state.

  • model (torch.nn.Module) – Refer to shard_full_optim_state_dict().

  • optim (torch.optim.Optimizer) – Optimizer for model’s parameters.
Returns:

Refer to shard_full_optim_state_dict().

Return type:

dict[str,Any]

forward(*args, **kwargs)[source]#

Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.

Return type:

Any

static fsdp_modules(module, root_only=False)[source]#

Return all nested FSDP instances.

This possibly includes module itself and only includes FSDP root modules if root_only=True.

Parameters:
  • module (torch.nn.Module) – Root module, which may or may not be an FSDP module.

  • root_only (bool) – Whether to return only FSDP root modules. (Default: False)

Returns:

FSDP modules that are nested in the input module.

Return type:

List[FullyShardedDataParallel]

static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source]#

Return the full optimizer state-dict.

Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters.

This needs to be called on all ranks since it uses collective communications. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict.

Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs.

Like in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. For best practices, consider saving the returned optimizer state dict immediately, e.g. using torch.save().
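For example, a minimal checkpointing sketch (the checkpoint path and the model/optim variables are illustrative):

>>> import torch.distributed as dist
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # collective; call on all ranks
>>> if dist.get_rank() == 0:
>>>     torch.save(full_osd, "optim_ckpt.pt")  # with rank0_only=True, only rank 0 has the populated dict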

Parameters:
  • model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.

  • optim (torch.optim.Optimizer) – Optimizer for model’s parameters.

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer optim representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)

  • rank0_only (bool) – If True, saves the populated dict only on rank 0; if False, saves it on all ranks. (Default: True)

  • group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None)

Returns:

A dict containing the optimizer state for model’s original unflattened parameters and including keys “state” and “param_groups” following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict.

Return type:

Dict[str, Any]

static get_state_dict_type(module)[source]#

Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at module.

The target module does not have to be an FSDP module.

Returns:

A StateDictSettings containing the state_dict_type and state_dict / optim_state_dict configs that are currently set.

Raises:

AssertionError – if the StateDictSettings for different FSDP submodules differ.

Return type:

StateDictSettings

property module: Module#

Return the wrapped module.

named_buffers(*args, **kwargs)[source]#

Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.

Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager.

Return type:

Iterator[tuple[str,Tensor]]

named_parameters(*args, **kwargs)[source]#

Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.

Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager.

Return type:

Iterator[tuple[str,Parameter]]

no_sync()[source]#

Disable gradient synchronizations across FSDP instances.

Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.

Note

This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.

Note

When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync.
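A minimal gradient-accumulation sketch (sharded_module is assumed to be the root FSDP instance and micro_batches an illustrative list of input tensors):

>>> # Accumulate gradients locally for all but the last micro-batch, then sync.
>>> with sharded_module.no_sync():
>>>     for batch in micro_batches[:-1]:
>>>         sharded_module(batch).sum().backward()      # no gradient reduce-scatter here
>>> sharded_module(micro_batches[-1]).sum().backward()  # gradients are synchronized here
>>> optim.step()
>>> optim.zero_grad()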

Return type:

Generator

static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source]#

Transform the state-dict of an optimizer corresponding to a sharded model.

The given state-dict can be transformed to one of three types: 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.

For the full optimizer state_dict, all states are unflattened and not sharded. Rank0-only and CPU-only can be specified via state_dict_type() to avoid OOM.

For the sharded optimizer state_dict, all states are unflattened but sharded. CPU-only can be specified via state_dict_type() to further save memory.

For the local state_dict, no transformation will be performed. But a state will be converted from torch.Tensor to ShardedTensor to represent its sharding nature (this is not supported yet).

Example:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
Parameters:
  • model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.

  • optim (torch.optim.Optimizer) – Optimizer for model’s parameters.

  • optim_state_dict (Dict[str, Any]) – the target optimizer state_dict to transform. If the value is None, optim.state_dict() will be used. (Default: None)

  • group (dist.ProcessGroup) – Model’s process group across which parameters are sharded, or None if using the default process group. (Default: None)

Returns:

A dict containing the optimizer state for model. The sharding of the optimizer state is based on state_dict_type.

Return type:

Dict[str, Any]

static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source]#

Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.

Given an optim_state_dict that is transformed through optim_state_dict(), it gets converted to the flattened optimizer state_dict that can be loaded to optim, which is the optimizer for model. model must be sharded by FullyShardedDataParallel.

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model,
>>>     optim,
>>>     optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
Parameters:
  • model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.

  • optim (torch.optim.Optimizer) – Optimizer for model’s parameters.

  • optim_state_dict (Dict[str, Any]) – The optimizer states to be loaded.

  • is_named_optimizer (bool) – Whether this optimizer is a NamedOptimizer or KeyedOptimizer. Only set to True if optim is TorchRec’s KeyedOptimizer or torch.distributed’s NamedOptimizer.

  • load_directly (bool) – If this is set to True, this API will also call optim.load_state_dict(result) before returning the result. Otherwise, users are responsible for calling optim.load_state_dict(). (Default: False)

  • group (dist.ProcessGroup) – Model’s process group across which parameters are sharded, or None if using the default process group. (Default: None)

Return type:

dict[str,Any]

register_comm_hook(state,hook)[source]#

Register a communication hook.

This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates gradients across multiple workers. This hook can be used to implement several algorithms like GossipGrad and gradient compression, which involve different communication strategies for parameter syncs while training with FullyShardedDataParallel.

Warning

The FSDP communication hook should be registered before running an initial forward pass and only once.

Parameters:
  • state (object) –

    Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker.

  • hook (Callable) – Callable, which has one of the following signatures: 1) hook: Callable[torch.Tensor] -> None: This function takes in a Python tensor, which represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). It then performs all necessary processing and returns None; 2) hook: Callable[torch.Tensor, torch.Tensor] -> None: This function takes in two Python tensors, the first one represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). The latter represents a pre-sized tensor to store a chunk of a sharded gradient after reduction. In both cases, the callable performs all necessary processing and returns None. Callables with signature 1 are expected to handle gradient communication for a NO_SHARD case. Callables with signature 2 are expected to handle gradient communication for sharded cases.

static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source]#

Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type.

This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without.

To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) to use parameter IDs and be loadable to a non-wrapped model:

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

To re-key a normal optimizer state dict from a non-wrapped model to be loadable to a wrapped model:

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
Returns:

The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type.

Return type:

Dict[str, Any]

static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source]#

Scatter the full optimizer state dict from rank 0 to all other ranks.

Returns the sharded optimizer state dict on each rank. The return value is the same as shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict().

Example:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

Note

Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.

Parameters:
  • full_optim_state_dict (Optional[Dict[str, Any]]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state if on rank 0; the argument is ignored on nonzero ranks.

  • model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)

  • optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)

  • group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None)

Returns:

The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.

Return type:

Dict[str, Any]

static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#

Set the state_dict_type of all the descendant FSDP modules of the target module.

Also takes (optional) configuration for the model’s and optimizer’s state dict. The target module does not have to be an FSDP module. If the target module is an FSDP module, its state_dict_type will also be changed.

Note

This API should be called for only the top-level (root) module.

Note

This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, the following will ensure state_dict is called on all non-FSDP instances, while dispatching into the sharded_state_dict implementation for FSDP:

Example:

>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config=OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
Parameters:
  • module (torch.nn.Module) – Root module.

  • state_dict_type (StateDictType) – the desired state_dict_type to set.

  • state_dict_config (Optional[StateDictConfig]) – the configuration for the target state_dict_type.

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – the configuration for the optimizer state dict.

Returns:

A StateDictSettings that includes the previous state_dict type and configuration for the module.

Return type:

StateDictSettings

static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]#

Shard a full optimizer state-dict.

Remaps the state in full_optim_state_dict to flattened parameters instead of unflattened parameters and restricts it to only this rank’s part of the optimizer state. The first argument should be the return value of full_optim_state_dict().

Example:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

Note

Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.

Parameters:
  • full_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state.

  • model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)

  • optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)

Returns:

The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.

Return type:

Dict[str, Any]

static sharded_optim_state_dict(model, optim, group=None)[source]#

Return the optimizer state-dict in its sharded form.

The API is similar to full_optim_state_dict() but this API chunks all non-zero-dimension states to ShardedTensor to save memory. This API should only be used when the model state_dict is derived within the state_dict_type(SHARDED_STATE_DICT) context manager.

For the detailed usage, refer to full_optim_state_dict().

Warning

The returned state dict contains ShardedTensor and cannot be directly used by the regular optim.load_state_dict.
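A minimal sketch (model and optim are assumed to be an FSDP-wrapped model and its optimizer):

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
>>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
>>>     model_state = model.state_dict()
>>>     optim_state = FSDP.sharded_optim_state_dict(model, optim)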

Return type:

dict[str,Any]

static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#

Set the state_dict_type of all the descendant FSDP modules of the target module.

This context manager has the same functions as set_state_dict_type(). Read the documentation of set_state_dict_type() for details.

Example:

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>> ):
>>>     checkpoint = model.state_dict()
Parameters:
  • module (torch.nn.Module) – Root module.

  • state_dict_type (StateDictType) – the desired state_dict_type to set.

  • state_dict_config (Optional[StateDictConfig]) – the model state_dict configuration for the target state_dict_type.

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – the optimizer state_dict configuration for the target state_dict_type.

Return type:

Generator

static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]#

Expose full params for FSDP instances with this context manager.

Can be useful after forward/backward for a model to get the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument.

Note

This can be used on inner FSDPs.

Note

This cannot be used within a forward or backward pass. Nor can forward and backward be started from within this context.

Note

Parameters will revert to their local shards after the context manager exits; storage behavior is the same as in forward.

Note

The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). In the case where FSDP does not shard the parameters, currently only when world_size == 1 or with the NO_SHARD config, the modification is persisted regardless of writeback.

Note

This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units.

Warning

Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.

Warning

Note that offload_to_cpu and rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which may incur the risk of CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True.
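For example, a minimal inspection sketch that gathers the full parameters on rank 0 only (sharded_module is assumed to be the FSDP-wrapped model):

>>> import torch.distributed as dist
>>> with FSDP.summon_full_params(
>>>     sharded_module,
>>>     rank0_only=True,       # materialize full params on rank 0 only
>>>     writeback=False,       # required together with rank0_only=True
>>>     offload_to_cpu=True,   # keep the gathered copies off the GPU
>>> ):
>>>     if dist.get_rank() == 0:
>>>         print({name: p.shape for name, p in sharded_module.named_parameters()})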

Parameters:
  • recurse (bool, Optional) – recursively summon all params for nested FSDP instances (default: True).

  • writeback (bool, Optional) – if False, modifications to params are discarded after the context manager exits; disabling this can be slightly more efficient (default: True)

  • rank0_only (bool, Optional) – if True, full parameters are materialized on only global rank 0. This means that within the context, only rank 0 will have full parameters and the other ranks will have sharded parameters. Note that setting rank0_only=True with writeback=True is not supported, as model parameter shapes will be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.

  • offload_to_cpu (bool, Optional) – If True, full parameters are offloaded to CPU. Note that this offloading currently only occurs if the parameter is sharded (which is only not the case for world_size = 1 or NO_SHARD config). It is recommended to use offload_to_cpu with rank0_only=True to avoid redundant copies of model parameters being offloaded to the same CPU memory.

  • with_grads (bool, Optional) – If True, gradients are also unsharded with the parameters. Currently, this is only supported when passing use_orig_params=True to the FSDP constructor and offload_to_cpu=False to this method. (Default: False)

Return type:

Generator

class torch.distributed.fsdp.BackwardPrefetch(value)[source]#

This configures explicit backward prefetching, which improves throughput by enabling communication and computation overlap in the backward pass at the cost of slightly increased memory usage.

  • BACKWARD_PRE: This enables the most overlap but increases memory usage the most. This prefetches the next set of parameters before the current set of parameters’ gradient computation. This overlaps the next all-gather and the current gradient computation, and at the peak, it holds the current set of parameters, next set of parameters, and current set of gradients in memory.

  • BACKWARD_POST: This enables less overlap but requires less memory usage. This prefetches the next set of parameters after the current set of parameters’ gradient computation. This overlaps the current reduce-scatter and the next gradient computation, and it frees the current set of parameters before allocating memory for the next set of parameters, only holding the next set of parameters and current set of gradients in memory at the peak.

  • FSDP’s backward_prefetch argument accepts None, which disables the backward prefetching altogether. This has no overlap and does not increase memory usage. In general, we do not recommend this setting since it may degrade throughput significantly.

For more technical context: For a single process group using the NCCL backend, any collectives, even if issued from different streams, contend for the same per-device NCCL stream, which implies that the relative order in which the collectives are issued matters for overlapping. The two backward prefetching values correspond to different issue orders.
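A minimal sketch of selecting a prefetching mode at construction time (my_module is illustrative):

>>> from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
>>> # Trade a bit of extra memory for more overlap in the backward pass (the default).
>>> sharded_module = FSDP(my_module, backward_prefetch=BackwardPrefetch.BACKWARD_PRE)
>>> # Or reduce peak memory at the cost of less overlap:
>>> # sharded_module = FSDP(my_module, backward_prefetch=BackwardPrefetch.BACKWARD_POST)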

class torch.distributed.fsdp.ShardingStrategy(value)[source]#

This specifies the sharding strategy to be used for distributed training by FullyShardedDataParallel (see the sketch after this list).

  • FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards (via all-gather) before the forward, reshards after the forward, unshards before the backward computation, and reshards after the backward computation. For gradients, it synchronizes and shards them (via reduce-scatter) after the backward computation. The sharded optimizer states are updated locally per rank.

  • SHARD_GRAD_OP: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation. For the parameters, this strategy unshards before the forward, does not reshard them after the forward, and only reshards them after the backward computation. The sharded optimizer states are updated locally per rank. Inside no_sync(), the parameters are not resharded after the backward computation.

  • NO_SHARD: Parameters, gradients, and optimizer states are not sharded but instead replicated across ranks similar to PyTorch’s DistributedDataParallel API. For gradients, this strategy synchronizes them (via all-reduce) after the backward computation. The unsharded optimizer states are updated locally per rank.

  • HYBRID_SHARD: Apply FULL_SHARD within a node, and replicate parameters across nodes. This results in reduced communication volume as expensive all-gathers and reduce-scatters are only done within a node, which can be more performant for medium-sized models.

  • _HYBRID_SHARD_ZERO2: Apply SHARD_GRAD_OP within a node, and replicate parameters across nodes. This is like HYBRID_SHARD, except this may provide even higher throughput since the unsharded parameters are not freed after the forward pass, saving the all-gathers in the pre-backward.
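A minimal sketch of choosing a strategy at construction time (my_module is illustrative; for HYBRID_SHARD without an explicit process-group tuple, FSDP constructs the intra-/inter-node groups itself):

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
>>> sharded_module = FSDP(
>>>     my_module,
>>>     # FULL_SHARD ~ ZeRO-3; SHARD_GRAD_OP ~ ZeRO-2; HYBRID_SHARD shards intra-node
>>>     # and replicates inter-node.
>>>     sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
>>> )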

class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>,))[source]#

This configures FSDP-native mixed precision training.

Variables:
  • param_dtype (Optional[torch.dtype]) – This specifies the dtype for model parameters during forward and backward and thus the dtype for forward and backward computation. Outside forward and backward, the sharded parameters are kept in full precision (e.g. for the optimizer step), and for model checkpointing, the parameters are always saved in full precision. (Default: None)

  • reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then this takes on the param_dtype value, still running gradient reduction in low precision. This is permitted to differ from param_dtype, e.g. to force gradient reduction to run in full precision. (Default: None)

  • buffer_dtype (Optional[torch.dtype]) – This specifies the dtype for buffers. FSDP does not shard buffers. Rather, FSDP casts them to buffer_dtype in the first forward pass and keeps them in that dtype thereafter. For model checkpointing, the buffers are saved in full precision except for LOCAL_STATE_DICT. (Default: None)

  • keep_low_precision_grads (bool) – If False, then FSDP upcasts gradients to full precision after the backward pass in preparation for the optimizer step. If True, then FSDP keeps the gradients in the dtype used for gradient reduction, which can save memory if using a custom optimizer that supports running in low precision. (Default: False)

  • cast_forward_inputs (bool) – If True, then this FSDP module casts its forward args and kwargs to param_dtype. This is to ensure that parameter and input dtypes match for forward computation, as required by many ops. This may need to be set to True when only applying mixed precision to some but not all FSDP modules, in which case a mixed-precision FSDP submodule needs to recast its inputs. (Default: False)

  • cast_root_forward_inputs (bool) – If True, then the root FSDP module casts its forward args and kwargs to param_dtype, overriding the value of cast_forward_inputs. For non-root FSDP modules, this does not do anything. (Default: True)

  • _module_classes_to_ignore (Sequence[Type[nn.Module]]) – This specifies module classes to ignore for mixed precision when using an auto_wrap_policy: Modules of these classes will have FSDP applied to them separately with mixed precision disabled (meaning that the final FSDP construction would deviate from the specified policy). If auto_wrap_policy is not specified, then this does not do anything. This API is experimental and subject to change. (Default: (_BatchNorm,))

Note

This API is experimental and subject to change.

Note

Only floating point tensors are cast to their specified dtypes.

Note

In summon_full_params, parameters are forced to full precision, but buffers are not.

Note

Layer norm and batch norm accumulate in float32 even when their inputs are in a low precision like float16 or bfloat16. Disabling FSDP’s mixed precision for those norm modules only means that the affine parameters are kept in float32. However, this incurs separate all-gathers and reduce-scatters for those norm modules, which may be inefficient, so if the workload permits, the user should prefer to still apply mixed precision to those modules.

Note

By default, if the user passes a model with any _BatchNorm modules and specifies an auto_wrap_policy, then the batch norm modules will have FSDP applied to them separately with mixed precision disabled. See the _module_classes_to_ignore argument.

Note

MixedPrecision has cast_root_forward_inputs=True and cast_forward_inputs=False by default. For the root FSDP instance, its cast_root_forward_inputs takes precedence over its cast_forward_inputs. For non-root FSDP instances, their cast_root_forward_inputs values are ignored. The default setting is sufficient for the typical case where each FSDP instance has the same MixedPrecision configuration and only needs to cast inputs to the param_dtype at the beginning of the model’s forward pass.

Note

For nested FSDP instances with different MixedPrecision configurations, we recommend setting individual cast_forward_inputs values to configure casting inputs or not before each instance’s forward. In such a case, since the casts happen before each FSDP instance’s forward, a parent FSDP instance should have its non-FSDP submodules run before its FSDP submodules to avoid the activation dtype being changed due to a different MixedPrecision configuration.

Example:

>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1],
>>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>>     model,
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )

The above shows a working example. On the other hand, if model[1] were replaced with model[0], meaning that the submodule using different MixedPrecision ran its forward first, then model[1] would incorrectly see float16 activations instead of bfloat16 ones.

class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]#

This configures CPU offloading.

Variables:

offload_params (bool) – This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.
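A minimal sketch (my_module is illustrative):

>>> from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
>>> # Offload parameters (and hence gradients) to CPU; the optimizer step runs on CPU.
>>> sharded_module = FSDP(my_module, cpu_offload=CPUOffload(offload_params=True))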

class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source]#

StateDictConfig is the base class for all state_dict configuration classes. Users should instantiate a child class (e.g. FullStateDictConfig) in order to configure settings for the corresponding state_dict type supported by FSDP.

Variables:

offload_to_cpu (bool) – If True, then FSDP offloads the state dict values to CPU, and if False, then FSDP keeps them on GPU. (Default: False)

class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source]#

FullStateDictConfig is a config class meant to be used with StateDictType.FULL_STATE_DICT. We recommend enabling both offload_to_cpu=True and rank0_only=True when saving full state dicts to save GPU memory and CPU memory, respectively. This config class is meant to be used via the state_dict_type() context manager as follows:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>>     state = fsdp.state_dict()
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>>     # Load checkpoint only on rank 0 to avoid memory redundancy
>>>     state_dict = torch.load("my_checkpoint.pt")
>>>     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(
...     model,
...     device_id=torch.cuda.current_device(),
...     auto_wrap_policy=...,
...     sync_module_states=True,
... )
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
Variables:

rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)

class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[source]#

ShardedStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.

Variables:

_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)

Warning

_use_dtensor is a private field of ShardedStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.

class torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False)[source]#
class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[source]#

OptimStateDictConfig is the base class for all optim_state_dict configuration classes. Users should instantiate a child class (e.g. FullOptimStateDictConfig) in order to configure settings for the corresponding optim_state_dict type supported by FSDP.

Variables:

offload_to_cpu (bool) – If True, then FSDP offloads the state dict’s tensor values to CPU, and if False, then FSDP keeps them on the original device (which is GPU unless parameter CPU offloading is enabled). (Default: True)

class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[source]#
Variables:

rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)

class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[source]#

ShardedOptimStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.

Variables:

_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)

Warning

_use_dtensor is a private field of ShardedOptimStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.

class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)[source]#
class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source]#