Distributed Checkpoint - torch.distributed.checkpoint
Created On: Nov 16, 2022 | Last Updated On: Oct 08, 2025
Distributed Checkpoint (DCP) supports loading and saving models from multiple ranks in parallel. It handles load-time resharding, which enables saving in one cluster topology and loading into another.
DCP is different from torch.save and torch.load in a few significant ways:
It produces multiple files per checkpoint, with at least one per rank.
It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
The entrypoints to load and save a checkpoint are the following:
- class torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType(value)
Enum for async checkpointer type.
- class torch.distributed.checkpoint.state_dict_saver.AsyncSaveResponse(staging_completion, upload_completion)
This class contains futures for staging and upload completion. It is returned by async_save(). staging_completion is a future that indicates when the local copy of the state_dict is complete. upload_completion is a future that indicates when the checkpoint has finished saving.
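A minimal, stdlib-only sketch of why there are two separate futures. `ToyAsyncSaveResponse` and `toy_async_save` are hypothetical stand-ins (not DCP APIs), and a deep copy of a plain dict plays the role of staging: training may mutate the live state once `staging_completion` resolves, while the checkpoint only counts as saved once `upload_completion` resolves.

```python
import copy
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class ToyAsyncSaveResponse:
    # Hypothetical stand-in mirroring the two futures of AsyncSaveResponse.
    staging_completion: Future
    upload_completion: Future


def toy_async_save(state, storage, executor):
    """Stage a private copy of `state`, then 'upload' it on a worker thread."""
    staging_completion = executor.submit(copy.deepcopy, state)

    def upload():
        # Wait for the staged copy, then write it to the in-memory "storage".
        storage["ckpt"] = staging_completion.result()
        return "saved"

    upload_completion = executor.submit(upload)
    return ToyAsyncSaveResponse(staging_completion, upload_completion)


storage = {}
state = {"weights": [1.0, 2.0, 3.0]}
with ThreadPoolExecutor(max_workers=2) as executor:
    response = toy_async_save(state, storage, executor)
    # Mutating the live state is safe only after staging has finished.
    response.staging_completion.result()
    state["weights"][0] = 99.0
    # Block until the checkpoint is fully written.
    response.upload_completion.result()

print(storage["ckpt"])  # {'weights': [1.0, 2.0, 3.0]}
```

The staged snapshot is unaffected by the update made after `staging_completion` resolved, which is the guarantee the real staging step is meant to provide.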
- torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, no_dist=False, use_collectives=True)
Save a distributed model in SPMD style.
This function is different from torch.save() as it handles ShardedTensor and DTensor by having each rank save only its local shards.
For each Stateful object (having both a state_dict and a load_state_dict), save will call state_dict before serialization.
Warning
There is no guarantee of backwards compatibility across PyTorch versions for saved state_dicts.
Warning
If using the process_group argument, make sure that only its ranks call save_state_dict and that all data in state_dict belong to it.
Note
When saving a checkpoint for FSDP’s ShardingStrategy.HYBRID_SHARD, only one of the shard_groups should call save_state_dict, and the corresponding process group must be passed in.
Note
- If no process group is available, this function assumes the intention is to save the state_dict in the local process.
- Parameters:
state_dict (Dict[str,Any]) – The state_dict to save.
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform writes. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)
planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None)
process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)
no_dist (bool) – If True, this function will assume the intent is to save a checkpoint on a single rank/process. (Default: False)
use_collectives (bool) – If False, this function will assume the intent is to save a checkpoint without using cross-rank synchronization. (Default: True) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible.
- Returns:
Metadata object for the saved checkpoint.
- Return type:
Metadata
Example
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.save(
...     state_dict=state_dict,
...     storage_writer=fs_storage_writer,
... )
Note
save_state_dict uses collectives to coordinate writes across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
- torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, async_checkpointer_type=AsyncCheckpointerType.THREAD, async_stager=None, no_dist=False, use_collectives=True)
Asynchronous version of save. This code first stages the state_dict onto the staging storage (defaults to CPU memory), and then calls save in a separate thread.
Warning
This feature is experimental and subject to change. MUST CALL CLOSE AFTER LAST CHECKPOINT IS SAVED.
- Parameters:
state_dict (Dict[str,Any]) – The state_dict to save.
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform ‘stage’ and ‘save’. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)
planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None)
process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)
async_checkpointer_type (AsyncCheckpointerType) – Whether to do the checkpoint in a separate thread or process. (Default: AsyncCheckpointerType.THREAD)
async_stager (AsyncStager) – Provides the staging implementation. If storage_writer implements AsyncStager and async_stager is provided, async_stager will be used for staging.
no_dist (bool) – If True, this function will assume the intent is to save a checkpoint on a single rank/process. (Default: False)
use_collectives (bool) – If False, save the checkpoint without rank coordination. (Default: True) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible.
- Returns:
A future holding the resultant Metadata object from save.
- Return type:
Example
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
...     state_dict=state_dict,
...     storage_writer=fs_storage_writer,
... )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()
- torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)
This method is deprecated. Please switch to ‘save’.
- Return type:
Metadata
- torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None, no_dist=False)
Load a checkpoint into a distributed state dict in SPMD style.
Each rank must have the same keys in their state_dict provided to this API. Mismatched keys may result in hangs or errors. If unsure, you can use the utils._assert_same_keys API to check (but may incur communication costs).
Each rank will try to read the least amount of data necessary to fulfill the requested state_dict. When loading ShardedTensor or DTensor instances, each rank only reads data for their local shards.
For each Stateful object (having both a state_dict and a load_state_dict), load will first call state_dict before attempting deserialization, followed by load_state_dict once the deserialization is complete. For each non-Stateful object, load will deserialize the object, and then replace it in the state_dict with the deserialized object.
Warning
All tensors in state_dict must be allocated on their destination device prior to calling this function.
All non-tensor data is loaded using torch.load() and modified in place on state_dict.
Warning
Users must call load_state_dict on the root module to ensure load post-processing and non-tensor data properly propagate.
- Parameters:
state_dict (Dict[str,Any]) – The state_dict to load the checkpoint into.
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
storage_reader (Optional[StorageReader]) – Instance of StorageReader used to perform reads. If this is not specified, DCP will automatically infer the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)
planner (Optional[LoadPlanner]) – Instance of LoadPlanner. If this is not specified, the default planner will be used. (Default: None)
process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)
no_dist (bool) – If True, this function will assume the intent is to load a checkpoint without using cross-rank synchronization. (Default: False)
- Returns:
None.
- Return type:
None
- Examples
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
...     "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.load(
...     state_dict=model_state_dict,
...     storage_reader=fs_storage_reader,
... )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)
Note
load_state_dict uses collectives to coordinate reads across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
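The in-place contract described in the load warnings can be illustrated without PyTorch. In this sketch, assuming `bytearray` buffers as stand-ins for preallocated tensors, `ToyModule` and `toy_load` are hypothetical names: buffer values are filled through the aliases that `state_dict()` returned, whereas non-tensor values are replaced in the dict and only reach the module once `load_state_dict` is called.

```python
class ToyModule:
    """Hypothetical module whose state_dict() returns aliases of its buffers."""

    def __init__(self):
        self.weight = bytearray(4)  # preallocated "storage" (all zeros)
        self.step = 0               # non-tensor state

    def state_dict(self):
        return {"weight": self.weight, "step": self.step}

    def load_state_dict(self, sd):
        # Post-processing: non-tensor values only reach the module here.
        self.step = sd["step"]


def toy_load(state_dict, checkpoint):
    """Fill preallocated buffers in place; replace all other values."""
    for key, value in checkpoint.items():
        if isinstance(state_dict[key], bytearray):
            state_dict[key][:] = value  # in-place copy into existing storage
        else:
            state_dict[key] = value     # non-tensor data is replaced in the dict


checkpoint = {"weight": bytes([1, 2, 3, 4]), "step": 7}
model = ToyModule()
sd = model.state_dict()
toy_load(sd, checkpoint)
print(list(model.weight), model.step)  # [1, 2, 3, 4] 0
model.load_state_dict(sd)
print(model.step)  # 7
```

Skipping the final `load_state_dict` call leaves the module with stale non-tensor state, which is the failure mode the warning guards against.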
- torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)
This method is deprecated. Please switch to ‘load’.
The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (torch.distributed.checkpoint.async_save):
- class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)
This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users to customize how data is staged before executing the usual dcp.save path in parallel. The expected order of operations (concretely defined in torch.distributed.checkpoint.state_dict_saver.async_save) is the following:
- AsyncStager.stage(state_dict):
This call gives the AsyncStager the opportunity to ‘stage’ the state_dict. The expectation and purpose of staging in this context is to create a “training-safe” representation of the state dict, meaning that any updates to module data after staging is complete should not be reflected in the state dict returned from this method. For example, in the default case a copy of the entire state dict is created on CPU RAM and returned here, allowing users to continue training without risking changes to data which is being serialized.
- dcp.save is called in parallel on the state_dict returned from stage. This call is responsible for serializing the state_dict and writing it to storage.
- If AsyncStager.should_synchronize_after_execute is True, AsyncStager.synchronize_staging will be called immediately after the serialization thread starts and before returning from dcp.async_save. If this is set to False, the assumption is that the user has defined a custom synchronization point for the purpose of further optimizing save latency in the training loop (for example, by overlapping staging with the forward/backward pass), and it is the responsibility of the user to call AsyncStager.synchronize_staging at the appropriate time.
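The order of operations for an AsyncStager can be sketched with the standard library alone. `StagerLike`, `CopyStager`, and `toy_async_save` are hypothetical names; the staging method is called `stage` here, as on DefaultStager, and a deep copy stands in for the CPU copy of the default case. This illustrates the protocol's sequencing and is not DCP's implementation of async_save.

```python
import copy
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Protocol


class StagerLike(Protocol):
    """Hypothetical mirror of the AsyncStager hooks."""

    @property
    def should_synchronize_after_execute(self) -> bool: ...

    def stage(self, state_dict: dict) -> dict: ...

    def synchronize_staging(self) -> None: ...


class CopyStager:
    """Stages with a blocking deep copy, so synchronization has nothing to wait for."""

    def __init__(self, synchronize: bool = True) -> None:
        self._synchronize = synchronize
        self.calls: list = []

    @property
    def should_synchronize_after_execute(self) -> bool:
        return self._synchronize

    def stage(self, state_dict: dict) -> dict:
        self.calls.append("stage")
        return copy.deepcopy(state_dict)

    def synchronize_staging(self) -> None:
        self.calls.append("synchronize_staging")


def toy_async_save(state_dict: dict, stager: StagerLike,
                   save_fn: Callable[[dict], None], pool: ThreadPoolExecutor) -> Future:
    staged = stager.stage(state_dict)            # 1. training-safe snapshot
    future = pool.submit(save_fn, staged)        # 2. serialize in parallel
    if stager.should_synchronize_after_execute:  # 3. default sync point
        stager.synchronize_staging()
    return future


written = []
stager = CopyStager()
state = {"step": 1}
with ThreadPoolExecutor(max_workers=1) as pool:
    fut = toy_async_save(state, stager, written.append, pool)
    state["step"] = 2  # later updates do not leak into the staged copy
    fut.result()
print(written, stager.calls)  # [{'step': 1}] ['stage', 'synchronize_staging']
```

Constructing `CopyStager(synchronize=False)` would skip step 3, leaving the caller responsible for invoking `synchronize_staging` at a point of its choosing.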
- class torch.distributed.checkpoint.staging.DefaultStager(config=StagingOptions(use_pinned_memory=True, use_shared_memory=True, use_async_staging=True, use_non_blocking_copy=True))
DefaultStager provides a full-featured staging implementation that combines multiple optimization techniques for efficient checkpoint preparation.
The staging process works as follows:
1. State dictionary is submitted for staging (sync or async)
2. Tensors are copied from GPU to optimized CPU storage
3. CUDA operations are synchronized if non-blocking copies are used
4. Staged state dictionary is returned or made available via Future
- Usage Patterns:
# Synchronous staging
stager = DefaultStager(StagingOptions(use_async_staging=False))
staged_dict = stager.stage(state_dict)
stager.close()
# Asynchronous staging
stager = DefaultStager(StagingOptions(use_async_staging=True))
future = stager.stage(state_dict)
# … do other work …
staged_dict = future.result()
stager.close()
# Context manager pattern (recommended)
stager = DefaultStager(config)
with stager:
    result = stager.stage(state_dict)
- Performance Considerations:
Async staging provides best performance when model computation can overlap with staging operations
Pinned memory improves CPU-GPU transfer speeds but uses more memory
Shared memory allows efficient IPC to checkpoint process
Non-blocking copies reduce GPU idle time during memory transfers
- Thread Safety:
DefaultStager is not thread-safe. Each thread should use its own instance, or external synchronization should be provided.
- close()
Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor used for async staging operations and cleans up the underlying StateDictStager’s cached storages. Should be called when the stager is no longer needed to prevent resource leaks, especially in long-running applications. After calling close(), the stager should not be used for further staging operations.
- Example Usage:
stager = DefaultStager(StagingOptions(use_async_staging=True))
future = stager.stage(state_dict)
result = future.result()
stager.close()  # Clean up all resources
- stage(state_dict, **kwargs)
This function is responsible for staging the state_dict. See class docstring for more details on staging. If use_async_staging is True, it will return a Future object that will be fulfilled when staging is complete. If use_async_staging is False, it will return the fully staged state_dict.
- class torch.distributed.checkpoint.staging.StagingOptions(use_pinned_memory=True, use_shared_memory=True, use_async_staging=True, use_non_blocking_copy=True)
Configuration options for checkpoint staging behavior.
- Variables:
use_pinned_memory (bool) – Enable pinned memory allocation for faster CPU-GPU transfers. Requires CUDA to be available. Default: True
use_shared_memory (bool) – Enable shared memory for multi-process scenarios. Useful when multiple processes need access to the same staged data. Default: True
use_async_staging (bool) – Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True
use_non_blocking_copy (bool) – Use non-blocking device memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True
Note
CUDA-dependent features will raise an exception if CUDA is not available.
- class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)
An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. This implementation also provides an option to optimize stage latency using pinned memory.
N.B. synchronize_staging is a no-op in this case.
In addition to the above entrypoints, Stateful objects, as described below, provide additional customization during saving/loading.
- class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)
Stateful protocol for objects that can be checkpointed and restored.
- state_dict()
Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().
Warning
Because of the inplace nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.
- Returns:
The object’s state dict
- Return type:
Dict
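A sketch of an object that satisfies the two-method Stateful contract. To stay runnable without PyTorch it defines a local look-alike protocol (`StatefulLike`, hypothetical) instead of importing torch.distributed.checkpoint.stateful.Stateful; `StepCounter` is likewise illustrative. In real code such an object would be placed in the state_dict passed to save/load, and the `template.update(snapshot)` line below merely stands in for deserialization.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    """Hypothetical local mirror of the Stateful protocol's two methods."""

    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


class StepCounter:
    """A non-tensor object that can take part in a checkpoint."""

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"step": self.step}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict["step"]


saved = StepCounter()
saved.step = 42
snapshot = saved.state_dict()     # what a save would serialize

restored = StepCounter()
template = restored.state_dict()  # load also calls state_dict() first
template.update(snapshot)         # stand-in for deserialization
restored.load_state_dict(template)
print(isinstance(restored, StatefulLike), restored.step)  # True 42
```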
This example shows how to use PyTorch Distributed Checkpoint to save an FSDP model.
The following types define the IO interface used during checkpoint:
- class torch.distributed.checkpoint.StorageReader
Interface used by load_state_dict to read from storage.
One StorageReader instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.
A subclass should expect the following sequence of calls by load_state_dict:
(all ranks) set checkpoint_id if users pass a valid checkpoint_id.
(all ranks) read_metadata()
(all ranks) set_up_storage_reader()
(all ranks) prepare_local_plan()
(coordinator) prepare_global_plan()
(all ranks) read_data()
- abstract prepare_global_plan(plans)
Perform centralized planning of storage loading.
This method is only called on the coordinator instance.
While this method can produce a completely different plan, the preferred way is to store storage specific data in LoadPlan::storage_data.
- abstract prepare_local_plan(plan)
Perform storage-specific local planning.
While this method can produce a completely different plan, the recommended way is to store storage specific data in LoadPlan::storage_data.
- abstract read_data(plan, planner)
Read all items from plan using planner to resolve the data.
A subclass should call LoadPlanner::load_bytes to deserialize a BytesIO object into the right place.
A subclass should call LoadPlanner::resolve_tensor to get access to the tensors that it should load data into.
It’s the storage layer’s responsibility to properly schedule any cross-device copies required.
- Parameters:
plan (LoadPlan) – The local plan to execute on
planner (LoadPlanner) – The planner object to use to resolve items.
- Returns:
A future that completes once all reads are finished.
- Return type:
Future[None]
- abstract read_metadata(*args, **kwargs)
Read the checkpoint metadata.
- Returns:
The metadata object associated with the checkpoint being loaded.
- Return type:
Metadata
- abstract reset(checkpoint_id=None)
Called to indicate that a brand new checkpoint read is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.
- Parameters:
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is more like a key-value store. (Default: None)
- class torch.distributed.checkpoint.StorageWriter
Interface used by save_state_dict to write to storage.
One StorageWriter instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.
A subclass should expect the following sequence of calls.
(all ranks) set checkpoint_id if users pass a valid checkpoint_id.
(all ranks) set_up_storage_writer()
(all ranks) prepare_local_plan()
(coordinator) prepare_global_plan()
(all ranks) write_data()
(coordinator) finish()
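The StorageWriter call sequence can be traced with a toy writer. `RecordingWriter` and `simulate_save` are hypothetical; the ranks are simulated sequentially in one process (real ranks run concurrently and exchange plans through collectives), and we assume that the "set checkpoint_id" step corresponds to calling reset(checkpoint_id).

```python
class RecordingWriter:
    """Hypothetical StorageWriter stand-in that logs the calls it receives."""

    def __init__(self, rank, log):
        self.rank = rank
        self.log = log
        self.is_coordinator = False

    def _record(self, name):
        self.log.append((self.rank, name))

    def reset(self, checkpoint_id=None):
        self._record("reset")

    def set_up_storage_writer(self, is_coordinator):
        self.is_coordinator = is_coordinator
        self._record("set_up_storage_writer")

    def prepare_local_plan(self, plan):
        self._record("prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans):
        self._record("prepare_global_plan")
        return plans

    def write_data(self, plan, planner):
        self._record("write_data")
        return [f"write-result-from-{self.rank}"]

    def finish(self, metadata, results):
        self._record("finish")


def simulate_save(world_size, coordinator=0, checkpoint_id="ckpt-1"):
    """Drive one writer per simulated rank, sequentially, in the documented order."""
    log = []
    writers = [RecordingWriter(rank, log) for rank in range(world_size)]
    for w in writers:  # (all ranks)
        w.reset(checkpoint_id)
        w.set_up_storage_writer(is_coordinator=(w.rank == coordinator))
    local_plans = [w.prepare_local_plan(f"plan-{w.rank}") for w in writers]
    global_plans = writers[coordinator].prepare_global_plan(local_plans)  # (coordinator)
    results = [w.write_data(p, planner=None) for w, p in zip(writers, global_plans)]
    writers[coordinator].finish(metadata="metadata", results=results)  # (coordinator)
    return log


for rank, call in simulate_save(world_size=2):
    print(rank, call)
```

Filtering the log by rank shows that followers never see prepare_global_plan or finish.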
- abstract finish(metadata, results)
Write the metadata and mark the current checkpoint as successful.
The actual format/schema used for serializing metadata is an implementation detail. The only requirement is that it’s recoverable into the same object graph.
- abstract prepare_global_plan(plans)
Perform centralized planning of storage.
This method is only called on the coordinator instance.
While this method can produce a completely different plan, the preferred way is to store storage specific data in SavePlan::storage_data.
- abstract prepare_local_plan(plan)
Perform storage-specific local planning.
While this method can produce a completely different plan, the recommended way is to store storage specific data in SavePlan::storage_data.
- abstract reset(checkpoint_id=None)
Called to indicate that a brand new checkpoint write is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint write. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.
- Parameters:
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
- abstract set_up_storage_writer(is_coordinator, *args, **kwargs)
Initialize this instance.
- Parameters:
is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint.
- storage_meta()
Return the storage-specific metadata. This is used to store additional information in a checkpoint that can be useful for providing request-level observability. StorageMeta is passed to the SavePlanner during save calls. Returns None by default.
- Return type:
StorageMeta | None
- abstract classmethod validate_checkpoint_id(checkpoint_id)
Check if the given checkpoint_id is supported by the storage. This allows us to enable automatic storage selection.
- Return type:
bool
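A sketch of how a classmethod like validate_checkpoint_id can drive automatic storage selection. `LocalDirWriter`, `ToyKVWriter`, the `kv://` scheme, and `infer_writer` are all hypothetical; this page does not describe how DCP itself chooses among writers, so the loop below only illustrates the idea.

```python
import os
from typing import Union


class LocalDirWriter:
    """Hypothetical writer that accepts plain filesystem paths."""

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return "://" not in os.fspath(checkpoint_id)


class ToyKVWriter:
    """Hypothetical writer that accepts keys of the form 'kv://<key>'."""

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return os.fspath(checkpoint_id).startswith("kv://")


def infer_writer(checkpoint_id, candidates=(ToyKVWriter, LocalDirWriter)):
    """Return the first writer class that claims the checkpoint_id."""
    for writer_cls in candidates:
        if writer_cls.validate_checkpoint_id(checkpoint_id):
            return writer_cls
    raise RuntimeError(f"No storage writer supports {checkpoint_id!r}")


print(infer_writer("/checkpoint/1").__name__)       # LocalDirWriter
print(infer_writer("kv://run-1/step-10").__name__)  # ToyKVWriter
```

Making the check a classmethod means candidates can be probed before any writer instance (and any storage connection) is created.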
- abstract write_data(plan, planner)
Write all items from plan using planner to resolve the data.
A subclass should call SavePlanner::resolve_data on each item from the plan to get access to the underlying object to write.
Subclasses should lazily call resolve_data as it can allocate memory. In case of tensors, make the following assumptions:
They might be on any device, including not matching the one on WriteItem::tensor_data.
They might be views or not contiguous. Only the projection needs to be saved.
- Parameters:
plan (SavePlan) – The save plan to execute.
planner (SavePlanner) – Planner object to be used to resolve items to data.
- Returns:
A future that completes to a list of WriteResult
- Return type:
Future[list[WriteResult]]
The following types define the planner interface used during checkpoint:
- class torch.distributed.checkpoint.LoadPlanner
Abstract class defining the protocol used by load_state_dict to plan the load process.
LoadPlanners are stateful objects that can be used to customize the whole load process.
LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.
A planner subclass can expect the following sequence of calls during load_state_dict:
- set_up_planner - called on all ranks.
Signals the start of loading a checkpoint.
- create_local_plan - called on all ranks.
Process the state_dict and produce a LoadPlan that will be sent for global planning.
- create_global_plan - called on the coordinator rank only.
Takes the LoadPlan from all ranks and makes any global decisions.
- load_bytes - called multiple times on each rank
This is called once per non-tensor value in state_dict.
- resolve_tensor and commit_tensor - called multiple times on each rank
They are called in pairs for each Tensor value in state_dict.
Users are recommended to extend DefaultLoadPlanner instead of this interface directly as most changes can be expressed by changes in a single method.
There are two usual patterns of extension:
Rewriting state_dict. This is the simplest way to extend the load process as it doesn’t require understanding the intricacies of how LoadPlan works. Because load happens in place, we need to keep a reference to the original state_dict so that the loaded values can be written back into it.
>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         metadata: Metadata,
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)
Modifying resolve_tensor and commit_tensor to handle load time transformation.
>>> class MetaModelMaterialize(DefaultLoadPlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
- abstract commit_tensor(read_item, tensor)
Called once the StorageReader has finished loading data into tensor.
The provided tensor is the same one returned by the call to resolve_tensor. This method is only needed if this LoadPlanner needs to post-process tensor prior to copying it back to the one in the state_dict.
The contents of tensor will follow its device synchronization model.
- abstract create_global_plan(global_plan)
Compute the global load plan and return plans for each rank.
N.B. This is called on the coordinator rank only.
- abstract create_local_plan()
Create a LoadPlan based on state_dict and metadata provided by set_up_planner.
N.B. This is called on every rank.
- Return type:
LoadPlan
- abstract finish_plan(central_plan)
Accept the plan from coordinator and return final LoadPlan.
- Return type:
LoadPlan
- abstract load_bytes(read_item, value)
Load the item described by read_item and value.
This method is expected to modify the underlying state_dict in place.
The contents of value are defined by the SavePlanner used to produce the checkpoint being loaded.
- resolve_bytes(read_item)
Return the BytesIO to be used by the StorageReader to load read_item.
The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents.
- Return type:
BytesIO
- abstract resolve_tensor(read_item)
Return the tensor described by read_item to be used by the StorageReader to load read_item.
The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. If, for any reason, that’s not possible, the planner can use the commit_tensor method to copy the data back to the one in state_dict.
- Return type:
Tensor
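The difference between handing out an alias from resolve_tensor and copying back in commit_tensor can be shown with `bytearray` buffers standing in for tensors. Both planners and `toy_read` are hypothetical, and they take a plain fqn string where the real methods take a ReadItem.

```python
class AliasingPlanner:
    """Hypothetical planner: hands out the real buffer, so commit has nothing to do."""

    def __init__(self, state_dict):
        self.state_dict = state_dict

    def resolve_tensor(self, fqn):
        return self.state_dict[fqn]  # alias of the destination

    def commit_tensor(self, fqn, buffer):
        pass  # data already landed in place


class StagingPlanner(AliasingPlanner):
    """Hypothetical planner: hands out scratch space, then copies it back."""

    def resolve_tensor(self, fqn):
        return bytearray(len(self.state_dict[fqn]))  # not an alias

    def commit_tensor(self, fqn, buffer):
        self.state_dict[fqn][:] = buffer  # copy back into the state_dict


def toy_read(planner, fqn, payload):
    """Stand-in for a StorageReader: fill the resolved buffer, then commit."""
    buffer = planner.resolve_tensor(fqn)
    buffer[:] = payload
    planner.commit_tensor(fqn, buffer)


for planner_cls in (AliasingPlanner, StagingPlanner):
    sd = {"w": bytearray(3)}
    toy_read(planner_cls(sd), "w", b"\x01\x02\x03")
    print(planner_cls.__name__, list(sd["w"]))  # both print [1, 2, 3]
```

Either route ends with the data in the state_dict; the aliasing route avoids the extra copy, which is why the docstring prefers it.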
- class torch.distributed.checkpoint.LoadPlan(items: list[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)
- class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)
- class torch.distributed.checkpoint.SavePlanner
Abstract class defining the protocol used by save_state_dict to plan the save process.
SavePlanners are stateful objects that can be used to customize the whole save process.
SavePlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.
A planner subclass can expect the following sequence of calls during save_state_dict:
- set_up_planner - called on all ranks.
Signals the start of a checkpoint save.
- create_local_plan - called on all ranks.
Process the state_dict and produce a SavePlan that will be sent for global planning.
- create_global_plan - called on the coordinator rank only.
Takes the SavePlan from all ranks and makes any global decisions.
- finish_plan - called on all ranks.
This gives each rank a chance to adjust to global planning decisions.
- resolve_data - called multiple times on each rank
Looks up a value in the state_dict for the storage layer to write.
Users are recommended to extend DefaultSavePlanner instead of this interface directly as most changes can be expressed by changes in a single method.
There are 3 usual patterns of extension:
Rewriting state_dict. This is the simplest way to extend the save process as it doesn’t require understanding the intricacies of how SavePlan works:
>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         storage_meta: Optional[StorageMeta],
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         # prefix all keys with `foo_`
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted is needed.
>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan.items:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
Using the global planning step to make central decisions that can’t be made individually by each rank
>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         iters = [iter(all_plans[0].items)] * len(all_plans)
>>>         items_per_rank = [
>>>             [item for item in items if item is not None]
>>>             for items in zip(*zip_longest(*iters), strict=True)
>>>         ]
>>>         all_plans = [
>>>             replace(plan, items=items)
>>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
>>>         ]
>>>         return super().create_global_plan(all_plans)
Finally, some planners need to save additional metadata in the checkpoint. This is accomplished by having each rank contribute their data items in the local plan and the global planner aggregate them:
>>> from dataclasses import replace
>>> from typing import List, Tuple
>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
- abstract create_global_plan(all_plans)
Compute the global checkpoint plan and return the local plan of each rank.
This is called on the coordinator rank only.
- abstract create_local_plan()
Compute the save plan for the current rank.
This will be aggregated and passed to create_global_plan. Planner specific data can be passed through SavePlan::planner_data.
This is called on all ranks.
- Return type:
SavePlan
- abstract finish_plan(new_plan)
Merge the plan created by create_local_plan and the result of create_global_plan.
This is called on all ranks.
- Return type:
SavePlan
- abstract resolve_data(write_item)
Transform and prepare write_item from state_dict for storage, ensuring idempotency and thread-safety.
Look up the object associated with write_item in state_dict and apply any transformation (such as serialization) prior to the storage layer consuming it.
Called on each rank multiple times, at least once per WriteItem in the final SavePlan.
This method should be idempotent and thread-safe. StorageWriter implementations are free to call it as frequently as they need.
Any transformation that allocates memory should be lazily done when this method is called in order to reduce peak memory required by checkpointing.
When returning tensors, they can be on any device or format, they can be views too. It’s the storage layer’s responsibility to figure out how to save them.
- Return type:
Tensor | BytesIO
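The laziness and idempotency requirements above can be illustrated with a dependency-free sketch. LazyBytesResolver is a hypothetical class, and pickle stands in for whatever serialization a real planner applies. The point is that nothing is serialized until the storage layer asks, and every call builds a fresh buffer from the unchanged state_dict, so repeated or concurrent calls return equivalent data.

```python
import io
import pickle
from typing import Any, Dict


class LazyBytesResolver:
    """Hypothetical sketch of resolve_data-style behavior for non-tensor items."""

    def __init__(self, state_dict: Dict[str, Any]) -> None:
        # Only a reference is kept; nothing is serialized up front.
        self.state_dict = state_dict

    def resolve_data(self, fqn: str) -> io.BytesIO:
        # Look up the object only when the storage layer asks for it (lazy).
        obj = self.state_dict[fqn]
        # Build a fresh buffer on every call and never mutate shared state, so
        # repeated or concurrent calls return equivalent, independent buffers.
        return io.BytesIO(pickle.dumps(obj))


resolver = LazyBytesResolver({"step": 10, "config": {"lr": 0.1}})
first = resolver.resolve_data("config")
second = resolver.resolve_data("config")
print(first is second, first.getvalue() == second.getvalue())  # False True
```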
- class torch.distributed.checkpoint.SavePlan(items: list[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None, usable: bool = True)
- class torch.distributed.checkpoint.planner.WriteItem(index, type, bytes_io_data=None, tensor_data=None)
Dataclass which holds information about what needs to be written to storage.
We provide a filesystem-based storage layer:
- class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True, _extensions=None, serialization_format=SerializationFormat.TORCH_SAVE)
Basic implementation of StorageWriter using file IO.
This implementation makes the following assumptions and simplifications:
The checkpoint path is an empty or non-existing directory.
File creation is atomic.
The checkpoint consists of one file per write request, plus a global .metadata file with the serialized metadata if rank coordination is enabled, or a rank-local __{rank}.metadata file with the serialized metadata if rank coordination is NOT enabled.
We also provide other storage layers, including ones to interact with HuggingFace safetensors:
- class torch.distributed.checkpoint.HuggingFaceStorageReader
- class torch.distributed.checkpoint.HuggingFaceStorageWriter
- class torch.distributed.checkpoint.QuantizedHuggingFaceStorageReader
We provide default implementations of LoadPlanner and SavePlanner that can handle all torch.distributed constructs, such as FSDP, DDP, ShardedTensor and DistributedTensor.
- class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False, enable_plan_caching=False)
- class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)
DefaultLoadPlanner that adds multiple features on top of LoadPlanner.
In particular it adds the following:
flatten_state_dict: Handle state_dict with nested dicts.
flatten_sharded_tensors: For FSDP in 2D parallel mode.
allow_partial_load: If False, raise a runtime error if a key is present in state_dict but not in the checkpoint.
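A simplified sketch of what flattening a nested state_dict means: nested dict keys are joined with dots into a single flat level, and a mapping from each flat key back to its original path is kept so the nesting can be restored after loading. The function names are hypothetical, and this sketch handles plain dicts only.

```python
from typing import Any, Dict, Tuple

Path = Tuple[str, ...]


def flatten_state_dict(sd: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Path]]:
    """Flatten nested dicts into dotted keys; also return flat key -> original path."""
    flat: Dict[str, Any] = {}
    mapping: Dict[str, Path] = {}

    def walk(node: Dict[str, Any], path: Path) -> None:
        for key, value in node.items():
            new_path = path + (key,)
            if isinstance(value, dict):
                walk(value, new_path)  # recurse into nested dicts
            else:
                fqn = ".".join(new_path)
                flat[fqn] = value
                mapping[fqn] = new_path

    walk(sd, ())
    return flat, mapping


def unflatten_state_dict(flat: Dict[str, Any], mapping: Dict[str, Path]) -> Dict[str, Any]:
    """Rebuild the nesting from the recorded paths (not by splitting on '.')."""
    nested: Dict[str, Any] = {}
    for fqn, value in flat.items():
        *parents, leaf = mapping[fqn]
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


sd = {"model": {"layer1": {"weight": 1, "bias": 2}}, "step": 3}
flat, mapping = flatten_state_dict(sd)
print(flat)  # {'model.layer1.weight': 1, 'model.layer1.bias': 2, 'step': 3}
```

Keeping the path mapping, instead of splitting flat keys on ".", matters because an original key may itself contain a dot.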
Due to legacy design decisions, the state dictionaries of FSDP and DDP may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical. Moreover, FSDP offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries employ parameter IDs instead of fully qualified names to identify parameters, potentially causing issues when parallelisms are used (e.g., pipeline parallelism).
To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. get_model_state_dict() returns a model state dictionary with keys consistent with those returned by the unparallelized model state dictionary. Similarly, get_optimizer_state_dict() provides the optimizer state dictionary with keys uniform across all parallelisms applied. To achieve this consistency, get_optimizer_state_dict() converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary.
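The parameter-ID-to-FQN conversion performed by get_optimizer_state_dict() can be sketched on a plain torch.optim-style dict. This is a simplified illustration, not the library's implementation: optim_state_ids_to_fqns is a hypothetical helper, and it assumes the caller supplies param_fqns such that param_fqns[i] is the FQN of the parameter that received ID i (torch.optim assigns these IDs sequentially in param_groups order).

```python
from typing import Any, Dict, List


def optim_state_ids_to_fqns(optim_sd: Dict[str, Any], param_fqns: List[str]) -> Dict[str, Any]:
    """Rekey a torch.optim-style state_dict from integer parameter IDs to FQNs.

    Assumes param_fqns[i] is the FQN of the parameter that was given ID i.
    """
    id_to_fqn = dict(enumerate(param_fqns))
    # Per-parameter state (e.g. Adam moments) is keyed by parameter ID.
    state = {id_to_fqn[pid]: s for pid, s in optim_sd["state"].items()}
    # Each param group lists its parameters by ID as well.
    param_groups = [
        {**group, "params": [id_to_fqn[pid] for pid in group["params"]]}
        for group in optim_sd["param_groups"]
    ]
    return {"state": state, "param_groups": param_groups}


optim_sd = {
    "state": {0: {"step": 1}, 1: {"step": 1}},
    "param_groups": [{"lr": 0.1, "params": [0, 1]}],
}
converted = optim_state_ids_to_fqns(optim_sd, ["layer1.weight", "layer1.bias"])
print(converted["param_groups"][0]["params"])  # ['layer1.weight', 'layer1.bias']
```

Because FQNs come from the model structure rather than from the order in which a particular wrapper registers parameters, the converted keys stay the same across parallelisms.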
Note that results returned by these APIs can be used directly with the torch.distributed.checkpoint.save() and torch.distributed.checkpoint.load() methods without requiring any additional conversions.
set_model_state_dict() and set_optimizer_state_dict() are provided to load the model and optimizer state_dict generated by their respective getter APIs.
Note that set_optimizer_state_dict() can only be called before backward() or after step() is called on optimizers.
Note that this feature is experimental, and API signatures might change in the future.
- torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)
Return the model state_dict and optimizers state_dict.
get_state_dict can process any module that is parallelized by PyTorch FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any combination of these parallelisms. The main functions of get_state_dict are:
1) returning a model and optimizer state_dict that can be resharded with a different number of trainers and/or different parallelisms;
2) hiding the parallelism-specific state_dict APIs, so users do not have to call these APIs;
3) sanity checking the resulting state_dict.
The keys of the resulting state dictionary are the canonical FQNs (Fully Qualified Names). A canonical FQN refers to the FQN based on a parameter's position in an nn.Module hierarchy. More specifically, a canonical FQN to a parameter is the FQN returned by module.named_parameters() or module.named_buffers() when the module is not distributed by any parallelisms. Since the optimizer internally uses parameter IDs to represent a parameter, there will be a conversion from the parameter IDs to the canonical FQNs when calling this API.
get_state_dict can also process a module that is not parallelized. In such a case, get_state_dict only performs one function: converting the optimizer parameter IDs to the canonical FQNs.
Example
>>> import copy
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> # Assumes `model` is an nn.Module and the default process group is initialized.
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
...     fsdp_model, fsdp_optim
... )
>>> # Both results are keyed by canonical FQNs. With ddp_model.state_dict() and
>>> # fsdp_model.state_dict(), these checks would fail.
>>> assert ddp_state_dict.keys() == fsdp_state_dict.keys()
>>> assert ddp_optim_state_dict["state"].keys() == fsdp_optim_state_dict["state"].keys()
- Parameters:
model (nn.Module) – the model whose state_dict is returned.
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – the optimizers that are used to optimize model.
submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.
- Returns:
A tuple that contains the model state_dict and the optimizer state_dict.
- Return type:
Tuple[Dict[str, ValueType], OptimizerStateType]
- torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)
Return the model state_dict of model.
See get_state_dict for detailed usage.
- Parameters:
model (nn.Module) – the model whose state_dict is returned.
submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.
- Returns:
The state_dict for model.
- Return type:
Dict[str, ValueType]
- torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)
Return the combined state_dict for optimizers.
See get_state_dict for detailed usage.
- Parameters:
model (nn.Module) – the model whose parameters the optimizers update.
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – the optimizers that are used to optimize model.
submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.
- Returns:
The state_dict for optimizers.
- Return type:
OptimizerStateType
- torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)
Load the model state_dict and the optimizer state_dict.
The counterpart of get_state_dict to set the state_dict to the model and optimizers. The given model_state_dict and optim_state_dict do not have to be returned by get_state_dict but must meet the following requirements:
1) all FQNs are canonical FQNs as defined in get_state_dict;
2) if a tensor is sharded, it must be either a ShardedTensor or DTensor;
3) the optimizer state_dict cannot contain the parameter IDs; the keys should be the canonical FQNs.
- WARN:
set_state_dict can only be called before backward() or after step() is called on the optimizers. Otherwise, the optimizer states won't be initialized correctly.
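A pre-flight check for two of the requirements listed above can be written over plain dicts. check_set_state_dict_inputs is a hypothetical helper, not part of torch.distributed.checkpoint: it approximates requirement 1 by comparing keys against a caller-supplied set of canonical FQNs, and checks requirement 3 by flagging non-string optimizer keys. Requirement 2 (ShardedTensor or DTensor for sharded tensors) needs real tensors and is not covered.

```python
from typing import Any, Dict, List, Set


def check_set_state_dict_inputs(
    canonical_fqns: Set[str], model_sd: Dict[str, Any], optim_sd: Dict[str, Any]
) -> List[str]:
    """Return a list of problems; an empty list means both checks passed."""
    problems: List[str] = []
    # Requirement 1 (approximated): every model key is a known canonical FQN.
    for key in model_sd:
        if key not in canonical_fqns:
            problems.append(f"model key {key!r} is not a canonical FQN")
    # Requirement 3: optimizer state must be keyed by FQNs, not parameter IDs.
    for key in optim_sd.get("state", {}):
        if not isinstance(key, str):
            problems.append(f"optimizer state key {key!r} looks like a parameter ID")
    for group in optim_sd.get("param_groups", []):
        for param in group.get("params", []):
            if not isinstance(param, str):
                problems.append(f"param_groups entry {param!r} looks like a parameter ID")
    return problems


fqns = {"layer1.weight", "layer1.bias"}
bad = check_set_state_dict_inputs(
    fqns, {"module.layer1.weight": 0}, {"state": {0: {}}, "param_groups": [{"params": [0]}]}
)
print(len(bad))  # 3
```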
- Parameters:
model (nn.Module) – the model to load the state_dict into.
optimizers (Union[Optimizer, Iterable[Optimizer]]) – the optimizers that are used to optimize model.
model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): the model state_dict to load. If a key of model_state_dict is an nn.Module, the key is a submodule of model and the value should be the state_dict of the submodule. When loading the state_dict, the prefix of the submodule is prepended to the keys of its state_dict.
optim_state_dict (OptimizerStateType) – the optimizer state_dict to load.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.
- Returns:
missing_keys is a list of str containing the missing keys of the model state_dict.
unexpected_keys is a list of str containing the unexpected keys of the model state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
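The submodule-keyed form of model_state_dict can be illustrated by expanding it into ordinary FQN keys. In this hypothetical sketch, plain Python objects stand in for nn.Module submodules and prefix_of plays the role of the names that model.named_modules() would report; it shows the prefixing rule described above, not the library's code.

```python
from typing import Any, Dict


def expand_submodule_keys(
    model_sd: Dict[Any, Dict[str, Any]], prefix_of: Dict[Any, str]
) -> Dict[str, Any]:
    """Turn {submodule: {local_key: value}} into {"<prefix>.<local_key>": value}.

    prefix_of maps each submodule object to its FQN inside the parent model.
    """
    flat: Dict[str, Any] = {}
    for submodule, sub_sd in model_sd.items():
        prefix = prefix_of[submodule]
        for local_key, value in sub_sd.items():
            flat[f"{prefix}.{local_key}"] = value
    return flat


class FakeModule:
    """Hypothetical placeholder for an nn.Module submodule."""


encoder, head = FakeModule(), FakeModule()
expanded = expand_submodule_keys(
    {encoder: {"weight": 1, "bias": 2}, head: {"weight": 3}},
    {encoder: "encoder", head: "head"},
)
print(expanded)  # {'encoder.weight': 1, 'encoder.bias': 2, 'head.weight': 3}
```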
- torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)
Load the model state_dict.
The counterpart of get_model_state_dict to set the state_dict to the model. See set_state_dict for detailed usage.
- Parameters:
model (nn.Module) – the model to load the state_dict into.
model_state_dict (Dict[str, ValueType]) – the model state_dict to load. If a key of model_state_dict is an nn.Module, the key is a submodule of model and the value should be the state_dict of the submodule. When loading the state_dict, the prefix of the submodule is prepended to the keys of its state_dict.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.
- Returns:
missing_keys is a list of str containing the missing keys.
unexpected_keys is a list of str containing the unexpected keys.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
- torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)
Load the optimizer state_dict.
The counterpart of get_optimizer_state_dict to set the state_dict to the optimizers. See set_state_dict for detailed usage.
- WARN:
set_optimizer_state_dict can only be called before backward() or after step() is called on the optimizers. Otherwise, the optimizer states won't be initialized correctly.
- Parameters:
model (nn.Module) – the model whose parameters the optimizers update.
optimizers (Union[Optimizer, Iterable[Optimizer]]) – the optimizers that are used to optimize model.
optim_state_dict (OptimizerStateType) – the optimizer state_dict to load.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.
- Returns:
None
- Return type:
None
- class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False, dsd_fqn_modifiers='_fqn_modifiers')
This dataclass specifies how get_state_dict/set_state_dict will work.
full_state_dict: if this is set to True, all the tensors in the returned state_dict will be gathered. No ShardedTensor or DTensor will be in the returned state_dict.
cpu_offload: offload all the tensors to CPU. To prevent CPU OOM, if full_state_dict is also True, then only rank0 will get the state_dict and all other ranks will get an empty state_dict.
ignore_frozen_params: if the value is True, the returned state_dict won't contain any frozen parameters, i.e., parameters whose requires_grad is False. The default value is False.
keep_submodule_prefixes (deprecated): when submodules is not None, this option indicates whether to keep the submodule prefixes in the state_dict keys. For example, if the submodule is module.pretrain and the full FQN of the parameter is pretrain.layer1.weight, then when this option is True, the parameter's key in the returned state_dict will be pretrain.layer1.weight; if the option is False, the key will be layer1.weight. Note that if keep_submodule_prefixes is False, there may be conflicting FQNs, hence there should be only one submodule in submodules.
strict: the strict option when set_state_dict calls model.load_state_dict().
broadcast_from_rank0: when the option is True, rank0 should receive a full state_dict and will broadcast the tensors in the state_dict/optim_state_dict one by one to other ranks. Other ranks will receive the tensors and shard them according to the local shards in the model and optimizer. full_state_dict must be set to True when using this option. This option currently only supports DTensor, not the legacy ShardedTensor.
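The keep_submodule_prefixes behavior, and why stripping prefixes calls for a single submodule, can be seen with a small dict-based sketch. submodule_state_dict is hypothetical; raising ValueError on a collision is a choice made here to make the conflict visible, not documented library behavior.

```python
from typing import Any, Dict, List


def submodule_state_dict(
    full_sd: Dict[str, Any], submodule_fqns: List[str], keep_prefixes: bool
) -> Dict[str, Any]:
    """Select entries that belong to the given submodules, optionally dropping the prefix."""
    out: Dict[str, Any] = {}
    for fqn, value in full_sd.items():
        for sub in submodule_fqns:
            if fqn.startswith(sub + "."):
                key = fqn if keep_prefixes else fqn[len(sub) + 1:]
                if key in out:
                    # Stripped keys from two submodules landed on the same name.
                    raise ValueError(f"conflicting FQN {key!r}")
                out[key] = value
                break
    return out


full = {"pretrain.layer1.weight": 1, "head.layer1.weight": 2, "other.w": 3}
print(submodule_state_dict(full, ["pretrain"], keep_prefixes=True))   # {'pretrain.layer1.weight': 1}
print(submodule_state_dict(full, ["pretrain"], keep_prefixes=False))  # {'layer1.weight': 1}
```

With two submodules and keep_prefixes=False, both pretrain.layer1.weight and head.layer1.weight reduce to layer1.weight, which is exactly the conflict the option description warns about.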
For users who are used to using and sharing models in the torch.save format, the following methods provide offline utilities for converting between formats.
- torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)
Given a directory containing a DCP checkpoint, this function converts it into a torch.save file.
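- Parameters:
dcp_checkpoint_dir – directory containing the DCP checkpoint.
torch_save_path – path of the torch.save file to write.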
Warning
To avoid OOM, it’s recommended to only run this function on a single rank.
- torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)
Given the location of a torch save file, converts it into a DCP checkpoint.
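- Parameters:
torch_save_path – path of the torch.save file to convert.
dcp_checkpoint_dir – directory in which to store the DCP checkpoint.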
Warning
To avoid OOM, it’s recommended to only run this function on a single rank.
The following classes can also be utilized for online loading and resharding of models from the torch.save format.
- class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)
StorageReader for reading a torch.save file. This reader reads the entire checkpoint on the coordinator rank, and then broadcasts and shards each tensor to all ranks.
N.B. Intended to be used with DynamicMetaLoadPlanner.
Warning
Current implementation only supports loading Tensors.
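>>> import torch.distributed.checkpoint as dcp
>>> from torch.distributed.checkpoint.format_utils import (
...     BroadcastingTorchSaveReader,
...     DynamicMetaLoadPlanner,
... )
>>> sd = {"model": model}  # `model` is an existing nn.Module
>>> dcp.load(
...     sd,
...     storage_reader=BroadcastingTorchSaveReader(),
...     planner=DynamicMetaLoadPlanner(),
...     checkpoint_id="path_to_model.pt",
... )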
- read_data(plan, planner)
Reads torch.save data on the coordinator rank and broadcasts it afterwards. This incurs a communication cost, but avoids having to load the entire checkpoint on each rank, hopefully preventing OOM issues.
- Return type:
Future[None]
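The read-on-coordinator-then-broadcast strategy can be simulated in one process with lists in place of tensors. broadcast_and_shard is a hypothetical helper: the list copy stands in for the broadcast collective, and each rank keeps a contiguous chunk using ceil-division sizing (the same split rule torch.chunk uses). The real reader shards according to the layout the load plan requests, which this sketch does not model.

```python
from typing import Dict, List


def broadcast_and_shard(
    full_sd_on_coordinator: Dict[str, List[float]], world_size: int
) -> List[Dict[str, List[float]]]:
    """Simulate: coordinator holds full tensors; ranks receive one at a time and keep a chunk."""
    per_rank: List[Dict[str, List[float]]] = [{} for _ in range(world_size)]
    for fqn, full in full_sd_on_coordinator.items():
        # Stand-in for a broadcast: every rank briefly sees this one full tensor.
        received = [list(full) for _ in range(world_size)]
        chunk = -(-len(full) // world_size)  # ceil division, as torch.chunk sizes chunks
        for rank in range(world_size):
            # Keep only this rank's slice; the full copy can then be dropped.
            per_rank[rank][fqn] = received[rank][rank * chunk:(rank + 1) * chunk]
    return per_rank


shards = broadcast_and_shard({"w": [1.0, 2.0, 3.0, 4.0, 5.0]}, world_size=2)
print(shards)  # [{'w': [1.0, 2.0, 3.0]}, {'w': [4.0, 5.0]}]
```

Because tensors are broadcast one at a time, a non-coordinator rank holds at most one full tensor plus its own shards, rather than the entire checkpoint.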
- class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)
Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed-in state dict, avoiding the need to read metadata from disk. This is useful when reading formats which don't have a metadata file, like torch.save files.
N.B. Intended to be used with BroadcastingTorchSaveReader.
Warning
Current implementation only supports loading Tensors.
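>>> import torch.distributed.checkpoint as dcp
>>> from torch.distributed.checkpoint.format_utils import (
...     BroadcastingTorchSaveReader,
...     DynamicMetaLoadPlanner,
... )
>>> sd = {"model": model}  # `model` is an existing nn.Module
>>> dcp.load(
...     sd,
...     storage_reader=BroadcastingTorchSaveReader(),
...     planner=DynamicMetaLoadPlanner(),
...     checkpoint_id="path_to_model.pt",
... )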
The following experimental interfaces are provided for improved observability in production environments: