FullyShardedDataParallel#
Created On: Feb 02, 2022 | Last Updated On: Jun 11, 2025
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source]#
A wrapper for sharding module parameters across data parallel workers.
This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.
Example:
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()
Using FSDP involves wrapping your module and then initializing your optimizer after. This is required since FSDP changes the parameter variables.
When setting up FSDP, you need to consider the destination CUDA device. If the device has an ID (dev_id), you have three options:
1. Place the module on that device
2. Set the device using torch.cuda.set_device(dev_id)
3. Pass dev_id into the device_id constructor argument.
This ensures that the FSDP instance's compute device is the destination device. For options 1 and 3, the FSDP initialization always occurs on GPU. For option 2, the FSDP initialization happens on the module's current device, which may be a CPU.
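For instance, option 3 can be written as follows (a minimal sketch, reusing my_module and dev_id from above; the module may start on CPU, and FSDP moves it to the given device):

>>> # FSDP initializes and shards on the CUDA device `dev_id`
>>> sharded_module = FSDP(my_module, device_id=dev_id)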
If you're using the sync_module_states=True flag, you need to ensure that the module is on a GPU or use the device_id argument to specify a CUDA device that FSDP will move the module to in the FSDP constructor. This is necessary because sync_module_states=True requires GPU communication.

FSDP also takes care of moving input tensors to the forward method to the GPU compute device, so you don't need to manually move them from CPU.

For use_orig_params=True, ShardingStrategy.SHARD_GRAD_OP exposes the unsharded parameters, not the sharded parameters, after forward, unlike ShardingStrategy.FULL_SHARD. If you want to inspect the gradients, you can use the summon_full_params method with with_grads=True.

With limit_all_gathers=True, you may see a gap in the FSDP pre-forward where the CPU thread is not issuing any kernels. This is intentional and shows the rate limiter in effect. Synchronizing the CPU thread in that way prevents over-allocating memory for subsequent all-gathers, and it should not actually delay GPU kernel execution.

FSDP replaces managed modules' parameters with torch.Tensor views during forward and backward computation for autograd-related reasons. If your module's forward relies on saved references to the parameters instead of reacquiring the references each iteration, then it will not see FSDP's newly created views, and autograd will not work correctly.

Finally, when using sharding_strategy=ShardingStrategy.HYBRID_SHARD with the sharding process group being intra-node and the replication process group being inter-node, setting NCCL_CROSS_NIC=1 can help improve the all-reduce times over the replication process group for some cluster setups.

Limitations
There are several limitations to be aware of when using FSDP:
- FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. This is because FSDP uses the newly-reduced gradient instead of accumulating with any existing gradient, which can lead to incorrect results.
- FSDP does not support running the forward pass of a submodule that is contained in an FSDP instance. This is because the submodule's parameters will be sharded, but the submodule itself is not an FSDP instance, so its forward pass will not all-gather the full parameters appropriately.
- FSDP does not work with double backwards due to the way it registers backward hooks.
- FSDP has some constraints when freezing parameters. For use_orig_params=False, each FSDP instance must manage parameters that are all frozen or all non-frozen. For use_orig_params=True, FSDP supports mixing frozen and non-frozen parameters, but it's recommended to avoid doing so to prevent higher than expected gradient memory usage.
- As of PyTorch 1.12, FSDP offers limited support for shared parameters. If enhanced shared parameter support is needed for your use case, please post in this issue.
- You should avoid modifying the parameters between forward and backward without using the summon_full_params context, as the modifications may not persist.
- Parameters:
module (nn.Module) – This is the module to be wrapped with FSDP.
process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – This is the process group over which the model is sharded and thus the one used for FSDP's all-gather and reduce-scatter collective communications. If None, then FSDP uses the default process group. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a tuple of process groups, representing the groups over which to shard and replicate, respectively. If None, then FSDP constructs process groups for the user to shard intra-node and replicate inter-node. (Default: None)

sharding_strategy (Optional[ShardingStrategy]) – This configures the sharding strategy, which may trade off memory saving and communication overhead. See ShardingStrategy for details. (Default: FULL_SHARD)

cpu_offload (Optional[CPUOffload]) – This configures CPU offloading. If this is set to None, then no CPU offloading happens. See CPUOffload for details. (Default: None)

auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) – This specifies a policy to apply FSDP to submodules of module, which is needed for communication and computation overlap and thus affects performance. If None, then FSDP only applies to module, and users should manually apply FSDP to parent modules themselves (proceeding bottom-up). For convenience, this accepts ModuleWrapPolicy directly, which allows users to specify the module classes to wrap (e.g. the transformer block). Otherwise, this should be a callable that takes in three arguments module: nn.Module, recurse: bool, and nonwrapped_numel: int and should return a bool specifying whether the passed-in module should have FSDP applied if recurse=False or if the traversal should continue into the module's subtree if recurse=True. Users may add additional arguments to the callable. The size_based_auto_wrap_policy in torch.distributed.fsdp.wrap.py gives an example callable that applies FSDP to a module if the parameters in its subtree exceed 100M numel. We recommend printing the model after applying FSDP and adjusting as needed.

Example:
>>> def custom_auto_wrap_policy(
>>>     module: nn.Module,
>>>     recurse: bool,
>>>     nonwrapped_numel: int,
>>>     # Additional custom arguments
>>>     min_num_params: int = int(1e8),
>>> ) -> bool:
>>>     return nonwrapped_numel >= min_num_params
>>> # Configure a custom `min_num_params`
>>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
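Alternatively, a minimal sketch of the ModuleWrapPolicy convenience path, where MyTransformerBlock stands in for whatever block class your model actually uses:

>>> from torch.distributed.fsdp.wrap import ModuleWrapPolicy
>>> # Wrap every instance of the (hypothetical) MyTransformerBlock in its own FSDP unit
>>> my_auto_wrap_policy = ModuleWrapPolicy({MyTransformerBlock})
>>> fsdp_model = FSDP(my_module, auto_wrap_policy=my_auto_wrap_policy)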
backward_prefetch (Optional[BackwardPrefetch]) – This configures explicit backward prefetching of all-gathers. If None, then FSDP does not backward prefetch, and there is no communication and computation overlap in the backward pass. See BackwardPrefetch for details. (Default: BACKWARD_PRE)

mixed_precision (Optional[MixedPrecision]) – This configures native mixed precision for FSDP. If this is set to None, then no mixed precision is used. Otherwise, parameter, buffer, and gradient reduction dtypes can be set. See MixedPrecision for details. (Default: None)

ignored_modules (Optional[Iterable[torch.nn.Module]]) – Modules whose own parameters and child modules' parameters and buffers are ignored by this instance. None of the modules directly in ignored_modules should be FullyShardedDataParallel instances, and any child modules that are already-constructed FullyShardedDataParallel instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters at module granularity when using an auto_wrap_policy or if parameters' sharding is not managed by FSDP. (Default: None)

param_init_fn (Optional[Callable[[nn.Module], None]]) – A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. As of v1.12, FSDP detects modules with parameters or buffers on meta device via is_meta and either applies param_init_fn if specified or calls nn.Module.reset_parameters() otherwise. For both cases, the implementation should only initialize the parameters/buffers of the module, not those of its submodules. This is to avoid re-initialization. In addition, FSDP also supports deferred initialization via torchdistX's (pytorch/torchdistX) deferred_init() API, where the deferred modules are initialized by calling param_init_fn if specified or torchdistX's default materialize_module() otherwise. If param_init_fn is specified, then it is applied to all meta-device modules, meaning that it should probably case on the module type. FSDP calls the initialization function before parameter flattening and sharding.

Example:
>>> module = MyModule(device="meta")
>>> def my_init_fn(module: nn.Module):
>>>     # E.g. initialize depending on the module type
>>>     ...
>>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>> print(next(fsdp_model.parameters()).device)  # current CUDA device
>>> # With torchdistX
>>> module = deferred_init.deferred_init(MyModule, device="cuda")
>>> # Will initialize via deferred_init.materialize_module().
>>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]) – An int or torch.device giving the CUDA device on which FSDP initialization takes place, including the module initialization if needed and the parameter sharding. This should be specified to improve initialization speed if module is on CPU. If the default CUDA device was set (e.g. via torch.cuda.set_device), then the user may pass torch.cuda.current_device to this. (Default: None)

sync_module_states (bool) – If True, then each FSDP module will broadcast module parameters and buffers from rank 0 to ensure that they are replicated across ranks (adding communication overhead to this constructor). This can help load state_dict checkpoints via load_state_dict in a memory efficient way. See FullStateDictConfig for an example of this. (Default: False)

forward_prefetch (bool) – If True, then FSDP explicitly prefetches the next forward-pass all-gather before the current forward computation. This is only useful for CPU-bound workloads, in which case issuing the next all-gather earlier may improve overlap. This should only be used for static-graph models since the prefetching follows the first iteration's execution order. (Default: False)

limit_all_gathers (bool) – If True, then FSDP explicitly synchronizes the CPU thread to ensure GPU memory usage from only two consecutive FSDP instances (the current instance running computation and the next instance whose all-gather is prefetched). If False, then FSDP allows the CPU thread to issue all-gathers without any extra synchronization. (Default: True) We often refer to this feature as the "rate limiter". This flag should only be set to False for specific CPU-bound workloads with low memory pressure, in which case the CPU thread can aggressively issue all kernels without concern for GPU memory usage.

use_orig_params (bool) – Setting this to True has FSDP use module's original parameters. FSDP exposes those original parameters to the user via nn.Module.named_parameters() instead of FSDP's internal FlatParameters. This means that the optimizer step runs on the original parameters, enabling per-original-parameter hyperparameters. FSDP preserves the original parameter variables and manipulates their data between unsharded and sharded forms, where they are always views into the underlying unsharded or sharded FlatParameter, respectively. With the current algorithm, the sharded form is always 1D, losing the original tensor structure. An original parameter may have all, some, or none of its data present for a given rank. In the none case, its data will be like a size-0 empty tensor. Users should not author programs relying on what data is present for a given original parameter in its sharded form. True is required to use torch.compile(). Setting this to False exposes FSDP's internal FlatParameters to the user via nn.Module.named_parameters(). (Default: False)

ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – Ignored parameters or modules that will not be managed by this FSDP instance, meaning that the parameters are not sharded and their gradients are not reduced across ranks. This argument unifies with the existing ignored_modules argument, and we may deprecate ignored_modules soon. For backward compatibility, we keep both ignored_states and ignored_modules, but FSDP only allows one of them to be specified as not None.

device_mesh (Optional[DeviceMesh]) – DeviceMesh can be used as an alternative to process_group. When device_mesh is passed, FSDP will use the underlying process groups for all-gather and reduce-scatter collective communications. Therefore, these two args need to be mutually exclusive. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a 2D DeviceMesh instead of a tuple of process groups. For 2D FSDP + TP, users are required to pass in device_mesh instead of process_group. For more DeviceMesh info, please visit: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
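For example, a minimal sketch of passing a 2D DeviceMesh for hybrid sharding (the 2 x 4 layout is hypothetical and should match your number of nodes and GPUs per node):

>>> from torch.distributed.device_mesh import init_device_mesh
>>> from torch.distributed.fsdp import ShardingStrategy
>>> # Hypothetical cluster of 2 nodes with 4 GPUs each, expressed as a 2D mesh
>>> mesh_2d = init_device_mesh("cuda", (2, 4))
>>> fsdp_model = FSDP(
>>>     my_module,
>>>     device_mesh=mesh_2d,
>>>     sharding_strategy=ShardingStrategy.HYBRID_SHARD,
>>> )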
- apply(fn)[source]#
Apply fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).

Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.
- Parameters:
fn (Module -> None) – function to be applied to each submodule
- Returns:
self
- Return type:
- Module
- clip_grad_norm_(max_norm, norm_type=2.0)[source]#
Clip the gradient norm of all parameters.
The norm is computed over all parameters' gradients as viewed as a single vector, and the gradients are modified in-place.
- Parameters:
- Returns:
Total norm of the parameters (viewed as a single vector).
- Return type:
If every FSDP instance uses NO_SHARD, meaning that no gradients are sharded across ranks, then you may directly use torch.nn.utils.clip_grad_norm_().

If at least some FSDP instance uses a sharded strategy (i.e. one other than NO_SHARD), then you should use this method instead of torch.nn.utils.clip_grad_norm_() since this method handles the fact that gradients are sharded across ranks.

The total norm returned will have the "largest" dtype across all parameters/gradients as defined by PyTorch's type promotion semantics. For example, if all parameters/gradients use a low precision dtype, then the returned norm's dtype will be that low precision dtype, but if there exists at least one parameter/gradient using FP32, then the returned norm's dtype will be FP32.
Warning
This needs to be called on all ranks since it uses collective communications.
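For example, a minimal training-step sketch (loosely reusing sharded_module and optim from the constructor example above):

>>> loss = sharded_module(x).sum()
>>> loss.backward()
>>> # Clips the sharded gradients so that the global grad norm is at most 1.0
>>> total_norm = sharded_module.clip_grad_norm_(max_norm=1.0)
>>> optim.step()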
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source]#
Flatten a sharded optimizer state-dict.
The API is similar to shard_full_optim_state_dict(). The only difference is that the input sharded_optim_state_dict should be returned from sharded_optim_state_dict(). Therefore, there will be all-gather calls on each rank to gather ShardedTensors.
- Parameters:
sharded_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the sharded optimizer state.

model (torch.nn.Module) – Refer to shard_full_optim_state_dict().

optim (torch.optim.Optimizer) – Optimizer for model's parameters.
- Returns:
Refer to shard_full_optim_state_dict().
- Return type:
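A minimal usage sketch of flatten_sharded_optim_state_dict (assuming model is an FSDP-wrapped module and optim its optimizer; the checkpoint save/reload in between is elided):

>>> sharded_osd = FSDP.sharded_optim_state_dict(model, optim)
>>> # ... save `sharded_osd`, and later reload it on each rank ...
>>> flattened_osd = FSDP.flatten_sharded_optim_state_dict(sharded_osd, model, optim)
>>> optim.load_state_dict(flattened_osd)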
- forward(*args, **kwargs)[source]#
Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.
- Return type:
- static fsdp_modules(module, root_only=False)[source]#
Return all nested FSDP instances.
This possibly includes module itself and only includes FSDP root modules if root_only=True.
- Parameters:
module (torch.nn.Module) – Root module, which may or may not be an FSDP module.

root_only (bool) – Whether to return only FSDP root modules. (Default: False)
- Returns:
FSDP modules that are nested in the input module.
- Return type:
List[FullyShardedDataParallel]
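For example, a minimal sketch that compares all nested FSDP instances against only the root ones (model is assumed to be any module containing FSDP instances):

>>> all_fsdp = FSDP.fsdp_modules(model)
>>> root_fsdp = FSDP.fsdp_modules(model, root_only=True)
>>> print(f"{len(all_fsdp)} FSDP instances, {len(root_fsdp)} of which are roots")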
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source]#
Return the full optimizer state-dict.
Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters.

This needs to be called on all ranks since it uses collective communications. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict.

Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs.

Like in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. For best practices, consider saving the returned optimizer state dict immediately, e.g. using torch.save().
- Parameters:
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.

optim (torch.optim.Optimizer) – Optimizer for model's parameters.

optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer optim representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)

rank0_only (bool) – If True, saves the populated dict only on rank 0; if False, saves it on all ranks. (Default: True)

group (dist.ProcessGroup) – Model's process group or None if using the default process group. (Default: None)
- Returns:
A dict containing the optimizer state for model's original unflattened parameters and including keys "state" and "param_groups" following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict.
- Return type:
Dict[str, Any]
- static get_state_dict_type(module)[source]#
Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at module.

The target module does not have to be an FSDP module.
- Returns:
A StateDictSettings containing the state_dict_type and state_dict / optim_state_dict configs that are currently set.
- Raises:
AssertionError – if the StateDictSettings for different FSDP submodules differ.
- Return type:
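For example, a minimal sketch to inspect the currently active settings (model is assumed to contain FSDP modules):

>>> settings = FSDP.get_state_dict_type(model)
>>> print(settings.state_dict_type)
>>> print(settings.state_dict_config)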
- named_buffers(*args, **kwargs)[source]#
Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager.
- named_parameters(*args, **kwargs)[source]#
Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager.
- no_sync()[source]#
Disable gradient synchronizations across FSDP instances.
Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.
Note
This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.
Note
When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync.
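For example, a minimal gradient-accumulation sketch (batches, accum_steps, sharded_module, and optim are assumed to be defined by the caller):

>>> for step, batch in enumerate(batches):
>>>     if (step + 1) % accum_steps != 0:
>>>         # Accumulate local gradients with no inter-rank communication
>>>         with sharded_module.no_sync():
>>>             sharded_module(batch).sum().backward()
>>>     else:
>>>         # This backward pass triggers the deferred gradient synchronization
>>>         sharded_module(batch).sum().backward()
>>>         optim.step()
>>>         optim.zero_grad()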
- Return type:
- static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source]#
Transform the state-dict of an optimizer corresponding to a sharded model.
The given state-dict can be transformed to one of three types: 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.

For full optimizer state_dict, all states are unflattened and not sharded. Rank0 only and CPU only can be specified via state_dict_type() to avoid OOM.

For sharded optimizer state_dict, all states are unflattened but sharded. CPU only can be specified via state_dict_type() to further save memory.

For local state_dict, no transformation will be performed. But a state will be converted from nn.Tensor to ShardedTensor to represent its sharding nature (this is not supported yet).
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
- Parameters:
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.

optim (torch.optim.Optimizer) – Optimizer for model's parameters.

optim_state_dict (Dict[str, Any]) – the target optimizer state_dict to transform. If the value is None, optim.state_dict() will be used. (Default: None)

group (dist.ProcessGroup) – Model's process group across which parameters are sharded or None if using the default process group. (Default: None)
- Returns:
A dict containing the optimizer state for model. The sharding of the optimizer state is based on state_dict_type.
- Return type:
Dict[str, Any]
- static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source]#
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
Given an optim_state_dict that is transformed through optim_state_dict(), it gets converted to the flattened optimizer state_dict that can be loaded to optim, which is the optimizer for model. model must be sharded by FullyShardedDataParallel.

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model,
>>>     optim,
>>>     optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
- Parameters:
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.

optim (torch.optim.Optimizer) – Optimizer for model's parameters.

optim_state_dict (Dict[str, Any]) – The optimizer states to be loaded.

is_named_optimizer (bool) – Is this optimizer a NamedOptimizer or KeyedOptimizer. Only set to True if optim is TorchRec's KeyedOptimizer or torch.distributed's NamedOptimizer.

load_directly (bool) – If this is set to True, this API will also call optim.load_state_dict(result) before returning the result. Otherwise, users are responsible to call optim.load_state_dict() (Default: False)

group (dist.ProcessGroup) – Model's process group across which parameters are sharded or None if using the default process group. (Default: None)
- Return type:
- register_comm_hook(state, hook)[source]#
Register a communication hook.
This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates gradients across multiple workers. This hook can be used to implement several algorithms like GossipGrad and gradient compression, which involve different communication strategies for parameter syncs while training with FullyShardedDataParallel.

Warning

FSDP communication hook should be registered before running an initial forward pass and only once.
- Parameters:
state (object) – Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker.

hook (Callable) – Callable, which has one of the following signatures: 1) hook: Callable[torch.Tensor] -> None: This function takes in a Python tensor, which represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). It then performs all necessary processing and returns None; 2) hook: Callable[torch.Tensor, torch.Tensor] -> None: This function takes in two Python tensors, the first one represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). The latter represents a pre-sized tensor to store a chunk of a sharded gradient after reduction. In both cases, the callable performs all necessary processing and returns None. Callables with signature 1 are expected to handle gradient communication for a NO_SHARD case. Callables with signature 2 are expected to handle gradient communication for sharded cases.
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source]#
Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type. This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without.

To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) to use parameter IDs and be loadable to a non-wrapped model:

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)
To re-key a normal optimizer state dict from a non-wrapped model to beloadable to a wrapped model:
>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
- Returns:
The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type.
- Return type:
Dict[str, Any]
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source]#
Scatter the full optimizer state dict from rank 0 to all other ranks.
Returns the sharded optimizer state dict on each rank. The return value is the same as shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict().

Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)
Note
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
- Parameters:
full_optim_state_dict (Optional[Dict[str, Any]]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state if on rank 0; the argument is ignored on nonzero ranks.

model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.

optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)

optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)

group (dist.ProcessGroup) – Model's process group or None if using the default process group. (Default: None)
- Returns:
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank's part of the optimizer state.
- Return type:
Dict[str, Any]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#
Set the state_dict_type of all the descendant FSDP modules of the target module.

Also takes (optional) configuration for the model's and optimizer's state dict. The target module does not have to be an FSDP module. If the target module is an FSDP module, its state_dict_type will also be changed.

Note

This API should be called for only the top-level (root) module.

Note

This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, the following will ensure state_dict is called on all non-FSDP instances, while dispatching into the sharded_state_dict implementation for FSDP:

Example:
>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config=OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
- Parameters:
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.

state_dict_config (Optional[StateDictConfig]) – the configuration for the target state_dict_type.

optim_state_dict_config (Optional[OptimStateDictConfig]) – the configuration for the optimizer state dict.
- Returns:
A StateDictSettings that includes the previous state_dict type and configuration for the module.
- Return type:
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]#
Shard a full optimizer state-dict.
Remaps the state in full_optim_state_dict to flattened parameters instead of unflattened parameters and restricts to only this rank's part of the optimizer state. The first argument should be the return value of full_optim_state_dict().

Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)
Note
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
- Parameters:
full_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state.

model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.

optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)

optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
- Returns:
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank's part of the optimizer state.
- Return type:
Dict[str, Any]
- static sharded_optim_state_dict(model, optim, group=None)[source]#
Return the optimizer state-dict in its sharded form.
The API is similar to full_optim_state_dict() but this API chunks all non-zero-dimension states to ShardedTensor to save memory. This API should only be used when the model state_dict is derived within the context manager with state_dict_type(SHARDED_STATE_DICT):.

For the detailed usage, refer to full_optim_state_dict().

Warning

The returned state dict contains ShardedTensor and cannot be directly used by the regular optim.load_state_dict.
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#
Set the state_dict_type of all the descendant FSDP modules of the target module.

This context manager has the same functions as set_state_dict_type(). Read the document of set_state_dict_type() for the detail.

Example:
>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>> ):
>>>     checkpoint = model.state_dict()
- Parameters:
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.

state_dict_config (Optional[StateDictConfig]) – the model state_dict configuration for the target state_dict_type.

optim_state_dict_config (Optional[OptimStateDictConfig]) – the optimizer state_dict configuration for the target state_dict_type.
- Return type:
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]#
Expose full params for FSDP instances with this context manager.
Can be useful after forward/backward for a model to get the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument.

Note

This can be used on inner FSDPs.

Note

This cannot be used within a forward or backward pass. Nor can forward and backward be started from within this context.

Note

Parameters will revert to their local shards after the context manager exits; storage behavior is the same as forward.

Note

The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). In the case where FSDP does not shard the parameters, currently only when world_size == 1, or NO_SHARD config, the modification is persisted regardless of writeback.

Note

This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units.

Warning

Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.

Warning

Note that offload_to_cpu and rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which may incur the risk of CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True.
- Parameters:
recurse (bool, Optional) – recursively summon all params for nested FSDP instances (default: True).

writeback (bool, Optional) – if False, modifications to params are discarded after the context manager exits; disabling this can be slightly more efficient (default: True)

rank0_only (bool, Optional) – if True, full parameters are materialized on only global rank 0. This means that within the context, only rank 0 will have full parameters and the other ranks will have sharded parameters. Note that setting rank0_only=True with writeback=True is not supported, as model parameter shapes will be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.

offload_to_cpu (bool, Optional) – If True, full parameters are offloaded to CPU. Note that this offloading currently only occurs if the parameter is sharded (which is only not the case for world_size = 1 or NO_SHARD config). It is recommended to use offload_to_cpu with rank0_only=True to avoid redundant copies of model parameters being offloaded to the same CPU memory.

with_grads (bool, Optional) – If True, gradients are also unsharded with the parameters. Currently, this is only supported when passing use_orig_params=True to the FSDP constructor and offload_to_cpu=False to this method. (Default: False)
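For example, a minimal sketch that inspects the full (unsharded) parameters without writing any changes back (model is assumed to be an FSDP-wrapped module):

>>> with FSDP.summon_full_params(model, writeback=False):
>>>     full_numel = sum(p.numel() for p in model.parameters())
>>> print(full_numel)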
- Return type:
- class torch.distributed.fsdp.BackwardPrefetch(value)[source]#
This configures explicit backward prefetching, which improves throughput by enabling communication and computation overlap in the backward pass at the cost of slightly increased memory usage.
- BACKWARD_PRE: This enables the most overlap but increases memory usage the most. This prefetches the next set of parameters before the current set of parameters' gradient computation. This overlaps the next all-gather and the current gradient computation, and at the peak, it holds the current set of parameters, next set of parameters, and current set of gradients in memory.
- BACKWARD_POST: This enables less overlap but requires less memory usage. This prefetches the next set of parameters after the current set of parameters' gradient computation. This overlaps the current reduce-scatter and the next gradient computation, and it frees the current set of parameters before allocating memory for the next set of parameters, only holding the next set of parameters and current set of gradients in memory at the peak.

FSDP's backward_prefetch argument accepts None, which disables the backward prefetching altogether. This has no overlap and does not increase memory usage. In general, we do not recommend this setting since it may degrade throughput significantly.

For more technical context: For a single process group using NCCL backend, any collectives, even if issued from different streams, contend for the same per-device NCCL stream, which implies that the relative order in which the collectives are issued matters for overlapping. The two backward prefetching values correspond to different issue orders.
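For example, a minimal sketch trading some overlap for lower memory usage (my_module is assumed to be the module being wrapped):

>>> from torch.distributed.fsdp import BackwardPrefetch
>>> fsdp_model = FSDP(my_module, backward_prefetch=BackwardPrefetch.BACKWARD_POST)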
- class torch.distributed.fsdp.ShardingStrategy(value)[source]#
This specifies the sharding strategy to be used for distributed training by FullyShardedDataParallel.
- FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards (via all-gather) before the forward, reshards after the forward, unshards before the backward computation, and reshards after the backward computation. For gradients, it synchronizes and shards them (via reduce-scatter) after the backward computation. The sharded optimizer states are updated locally per rank.
- SHARD_GRAD_OP: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation. For the parameters, this strategy unshards before the forward, does not reshard them after the forward, and only reshards them after the backward computation. The sharded optimizer states are updated locally per rank. Inside no_sync(), the parameters are not resharded after the backward computation.
- NO_SHARD: Parameters, gradients, and optimizer states are not sharded but instead replicated across ranks similar to PyTorch's DistributedDataParallel API. For gradients, this strategy synchronizes them (via all-reduce) after the backward computation. The unsharded optimizer states are updated locally per rank.
- HYBRID_SHARD: Apply FULL_SHARD within a node, and replicate parameters across nodes. This results in reduced communication volume as expensive all-gathers and reduce-scatters are only done within a node, which can be more performant for medium-sized models.
- _HYBRID_SHARD_ZERO2: Apply SHARD_GRAD_OP within a node, and replicate parameters across nodes. This is like HYBRID_SHARD, except this may provide even higher throughput since the unsharded parameters are not freed after the forward pass, saving the all-gathers in the pre-backward.
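For example, a minimal sketch selecting a ZeRO-2-style strategy (my_module is assumed to be the module being wrapped):

>>> from torch.distributed.fsdp import ShardingStrategy
>>> fsdp_model = FSDP(my_module, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)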
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>,))[source]#
This configures FSDP-native mixed precision training.
- Variables:
param_dtype (Optional[torch.dtype]) – This specifies the dtype for model parameters during forward and backward and thus the dtype for forward and backward computation. Outside forward and backward, the sharded parameters are kept in full precision (e.g. for the optimizer step), and for model checkpointing, the parameters are always saved in full precision. (Default: None)

reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then this takes on the param_dtype value, still running gradient reduction in low precision. This is permitted to differ from param_dtype, e.g. to force gradient reduction to run in full precision. (Default: None)

buffer_dtype (Optional[torch.dtype]) – This specifies the dtype for buffers. FSDP does not shard buffers. Rather, FSDP casts them to buffer_dtype in the first forward pass and keeps them in that dtype thereafter. For model checkpointing, the buffers are saved in full precision except for LOCAL_STATE_DICT. (Default: None)

keep_low_precision_grads (bool) – If False, then FSDP upcasts gradients to full precision after the backward pass in preparation for the optimizer step. If True, then FSDP keeps the gradients in the dtype used for gradient reduction, which can save memory if using a custom optimizer that supports running in low precision. (Default: False)

cast_forward_inputs (bool) – If True, then this FSDP module casts its forward args and kwargs to param_dtype. This is to ensure that parameter and input dtypes match for forward computation, as required by many ops. This may need to be set to True when only applying mixed precision to some but not all FSDP modules, in which case a mixed-precision FSDP submodule needs to recast its inputs. (Default: False)

cast_root_forward_inputs (bool) – If True, then the root FSDP module casts its forward args and kwargs to param_dtype, overriding the value of cast_forward_inputs. For non-root FSDP modules, this does not do anything. (Default: True)

_module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): This specifies module classes to ignore for mixed precision when using an auto_wrap_policy: Modules of these classes will have FSDP applied to them separately with mixed precision disabled (meaning that the final FSDP construction would deviate from the specified policy). If auto_wrap_policy is not specified, then this does not do anything. This API is experimental and subject to change. (Default: (_BatchNorm,))
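For example, a minimal sketch of a bfloat16 compute policy that keeps gradient reduction in full precision (my_module is assumed to be the module being wrapped):

>>> from torch.distributed.fsdp import MixedPrecision
>>> bf16_policy = MixedPrecision(
>>>     param_dtype=torch.bfloat16,
>>>     reduce_dtype=torch.float32,   # gradient reduce-scatter stays in fp32
>>>     buffer_dtype=torch.bfloat16,
>>> )
>>> fsdp_model = FSDP(my_module, mixed_precision=bf16_policy)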
Note
This API is experimental and subject to change.
Note
Only floating point tensors are cast to their specified dtypes.
Note
In summon_full_params, parameters are forced to full precision, but buffers are not.

Note

Layer norm and batch norm accumulate in float32 even when their inputs are in a low precision like float16 or bfloat16. Disabling FSDP's mixed precision for those norm modules only means that the affine parameters are kept in float32. However, this incurs separate all-gathers and reduce-scatters for those norm modules, which may be inefficient, so if the workload permits, the user should prefer to still apply mixed precision to those modules.

Note

By default, if the user passes a model with any _BatchNorm modules and specifies an auto_wrap_policy, then the batch norm modules will have FSDP applied to them separately with mixed precision disabled. See the _module_classes_to_ignore argument.

Note

MixedPrecision has cast_root_forward_inputs=True and cast_forward_inputs=False by default. For the root FSDP instance, its cast_root_forward_inputs takes precedence over its cast_forward_inputs. For non-root FSDP instances, their cast_root_forward_inputs values are ignored. The default setting is sufficient for the typical case where each FSDP instance has the same MixedPrecision configuration and only needs to cast inputs to the param_dtype at the beginning of the model's forward pass.

Note

For nested FSDP instances with different MixedPrecision configurations, we recommend setting individual cast_forward_inputs values to configure casting inputs or not before each instance's forward. In such a case, since the casts happen before each FSDP instance's forward, a parent FSDP instance should have its non-FSDP submodules run before its FSDP submodules to avoid the activation dtype being changed due to a different MixedPrecision configuration.

Example:
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1],
>>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>>     model,
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )
The above shows a working example. On the other hand, if model[1] were replaced with model[0], meaning that the submodule using different MixedPrecision ran its forward first, then model[1] would incorrectly see float16 activations instead of bfloat16 ones.
- class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]#
This configures CPU offloading.
- Variables:
offload_params (bool) – This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.
- class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source]#
StateDictConfig is the base class for all state_dict configuration classes. Users should instantiate a child class (e.g. FullStateDictConfig) in order to configure settings for the corresponding state_dict type supported by FSDP.
- Variables:
offload_to_cpu (bool) – If True, then FSDP offloads the state dict values to CPU, and if False, then FSDP keeps them on GPU. (Default: False)
- class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source]#
FullStateDictConfig is a config class meant to be used with StateDictType.FULL_STATE_DICT. We recommend enabling both offload_to_cpu=True and rank0_only=True when saving full state dicts to save GPU memory and CPU memory, respectively. This config class is meant to be used via the state_dict_type() context manager as follows:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>>     state = fsdp.state_dict()
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>>     # Load checkpoint only on rank 0 to avoid memory redundancy
>>>     state_dict = torch.load("my_checkpoint.pt")
>>>     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(
...     model,
...     device_id=torch.cuda.current_device(),
...     auto_wrap_policy=...,
...     sync_module_states=True,
... )
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
- Variables:
rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)
- class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[source]#
ShardedStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.
- Variables:
_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)
Warning
_use_dtensor is a private field of ShardedStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.
- class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[source]#
OptimStateDictConfig is the base class for all optim_state_dict configuration classes. Users should instantiate a child class (e.g. FullOptimStateDictConfig) in order to configure settings for the corresponding optim_state_dict type supported by FSDP.
- Variables:
offload_to_cpu (bool) – If True, then FSDP offloads the state dict's tensor values to CPU, and if False, then FSDP keeps them on the original device (which is GPU unless parameter CPU offloading is enabled). (Default: True)
- class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[source]#
- Variables:
rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)
- class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[source]#
ShardedOptimStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.
- Variables:
_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)
Warning
_use_dtensor is a private field of ShardedOptimStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.
- class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source]#