
Distributed Checkpoint - torch.distributed.checkpoint

Created On: Nov 16, 2022 | Last Updated On: Oct 08, 2025

Distributed Checkpoint (DCP) supports loading and saving models from multiple ranks in parallel. It handles load-time resharding, which enables saving in one cluster topology and loading into another.

DCP differs from torch.save and torch.load in a few significant ways:

  • It produces multiple files per checkpoint, with at least one per rank.

  • It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
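The in-place contract can be illustrated without PyTorch. The following is a minimal sketch, assuming plain Python lists stand in for pre-allocated tensors; `load_in_place` is a hypothetical helper, not part of DCP. It overwrites the contents of buffers the caller already allocated instead of returning new objects, so every other reference to those buffers sees the loaded values.

```python
from typing import Dict, List


def load_in_place(state_dict: Dict[str, List[float]], saved: Dict[str, List[float]]) -> None:
    """Hypothetical toy loader: copy saved values into buffers that already exist."""
    for key, buffer in state_dict.items():
        source = saved[key]
        if len(source) != len(buffer):
            # The destination must be allocated with the right shape up front.
            raise ValueError(f"{key}: expected {len(buffer)} values, got {len(source)}")
        buffer[:] = source  # overwrite contents; the list object itself is reused


# Usage: the "model" allocates first, then loading fills that storage.
weights = [0.0, 0.0, 0.0]
alias = weights  # e.g. another component holding a reference to the same buffer
state_dict = {"weights": weights}
load_in_place(state_dict, {"weights": [1.0, 2.0, 3.0]})
print(alias)  # [1.0, 2.0, 3.0]
```

With real tensors, the analogous operation would be an in-place copy such as Tensor.copy_, which is why the destination has to exist with the right shape before loading.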

The entrypoints to load and save a checkpoint are the following:

Additional resources:

class torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType(value)

Enum for async checkpointer type.

class torch.distributed.checkpoint.state_dict_saver.AsyncSaveResponse(staging_completion, upload_completion)

This class contains futures for staging and upload completion. It is returned by async_save(). staging_completion is a future that indicates when the local copy of state_dict is complete. upload_completion is a future that indicates when the checkpoint has finished saving.

torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, no_dist=False, use_collectives=True)

Save a distributed model in SPMD style.

This function differs from torch.save() in that it handles ShardedTensor and DTensor by having each rank save only its local shards.

For each Stateful object (having both a state_dict and a load_state_dict), save will call state_dict before serialization.

Warning

There are no guarantees of backward compatibility across PyTorch versions for saved state_dicts.

Warning

If using the process_group argument, make sure that only its ranks call save_state_dict and that all data in state_dict belongs to it.

Note

When saving a checkpoint for FSDP’s ShardingStrategy.HYBRID_SHARD, only one of the shard_groups should call save_state_dict, and the corresponding process group needs to be passed in.

Note

If no process group is available, this function assumes the intention is to save the state_dict in the local process.

Parameters:
  • state_dict (Dict[str,Any]) – The state_dict to save.

  • checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)

  • storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform writes. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)

  • planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None)

  • process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)

  • no_dist (bool) – If True, this function will assume the intent is to save a checkpoint on a single rank/process. (Default: False)

  • use_collectives (bool) – If False, this function will assume the intent is to save a checkpoint without using cross-rank synchronization. (Default: True) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible.
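When storage_writer is omitted, DCP infers a writer from checkpoint_id. The sketch below shows one plausible shape for that inference: it asks each candidate class's validate_checkpoint_id classmethod (the same hook StorageWriter declares) in turn. `PathWriter`, `KVWriter`, `WRITER_CLASSES`, `infer_writer`, the `kv://` scheme, and the choice of RuntimeError are all hypothetical. The real selection logic and error type may differ.

```python
import os
from typing import Any, List, Optional, Type


class PathWriter:
    """Hypothetical writer that accepts local filesystem paths."""

    def __init__(self, checkpoint_id: Any) -> None:
        self.checkpoint_id = checkpoint_id

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Any) -> bool:
        if isinstance(checkpoint_id, os.PathLike):
            return True
        return isinstance(checkpoint_id, str) and "://" not in checkpoint_id


class KVWriter:
    """Hypothetical writer that accepts keys such as 'kv://bucket/step-10'."""

    def __init__(self, checkpoint_id: Any) -> None:
        self.checkpoint_id = checkpoint_id

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Any) -> bool:
        return isinstance(checkpoint_id, str) and checkpoint_id.startswith("kv://")


# Candidates are tried in order; the first one that accepts the ID wins.
WRITER_CLASSES: List[Type] = [KVWriter, PathWriter]


def infer_writer(checkpoint_id: Any = None, storage_writer: Optional[Any] = None) -> Any:
    if storage_writer is not None:
        return storage_writer  # an explicit writer is used as-is
    if checkpoint_id is None:
        raise RuntimeError("Either checkpoint_id or storage_writer must be provided.")
    for writer_cls in WRITER_CLASSES:
        if writer_cls.validate_checkpoint_id(checkpoint_id):
            return writer_cls(checkpoint_id)
    raise RuntimeError(f"No storage writer accepts checkpoint_id={checkpoint_id!r}.")


print(type(infer_writer("/checkpoint/1")).__name__)  # PathWriter
print(type(infer_writer("kv://bucket/1")).__name__)  # KVWriter
```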

Returns:

Metadata object for the saved checkpoint.

Return type:

Metadata

Example

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.save(
...     state_dict=state_dict,
...     storage_writer=fs_storage_writer,
... )

Note

save_state_dict uses collectives to coordinate writes across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, async_checkpointer_type=AsyncCheckpointerType.THREAD, async_stager=None, no_dist=False, use_collectives=True)

Asynchronous version of save. This function first stages the state_dict onto the staging storage (CPU memory by default), and then calls save in a separate thread.

Warning

This feature is experimental and subject to change. You MUST call close after the last checkpoint is saved.

Parameters:
  • state_dict (Dict[str,Any]) – The state_dict to save.

  • checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)

  • storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform ‘stage’ and ‘save’. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)

  • planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None)

  • process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)

  • async_checkpointer_type (AsyncCheckpointerType) – Whether to checkpoint in a separate thread or process. (Default: AsyncCheckpointerType.THREAD)

  • async_stager (AsyncStager) – Provides the staging implementation. If storage_writer implements AsyncStager and async_stager is provided, async_stager will be used for staging.

  • no_dist (bool) – If True, this function will assume the intent is to save a checkpoint on a single rank/process. (Default: False)

  • use_collectives (bool) – If False, saves the checkpoint without rank coordination. (Default: True) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible.

Returns:

A future holding the resultant Metadata object from save.

Return type:

Future

Example

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
...     state_dict=state_dict,
...     storage_writer=fs_storage_writer,
... )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()

torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)

This method is deprecated. Please switch to ‘save’.

Return type:

Metadata

torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None, no_dist=False)

Load a checkpoint into a distributed state dict in SPMD style.

Each rank must have the same keys in their state_dict provided to this API. Mismatched keys may result in hangs or errors. If unsure, you can use the utils._assert_same_keys API to check (but this may incur communication costs).
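In a real job each rank only sees its own state_dict, so the key sets would first have to be gathered across ranks, which is presumably where the communication cost comes from. The sketch below skips that step and assumes the per-rank key sets are already collected in a list; `find_key_mismatches` is a hypothetical helper, not the utils._assert_same_keys implementation.

```python
from typing import Dict, List, Set


def find_key_mismatches(per_rank_keys: List[Set[str]]) -> Dict[int, Set[str]]:
    """Map each rank to the keys it holds that at least one other rank lacks."""
    common = set.intersection(*per_rank_keys)  # keys present on every rank
    return {
        rank: keys - common
        for rank, keys in enumerate(per_rank_keys)
        if keys != common
    }


# Rank 1 is missing "optim", so rank 0 reports it as an extra key.
print(find_key_mismatches([{"model", "optim"}, {"model"}]))  # {0: {'optim'}}
```

An empty result means every rank holds the same keys. Any non-empty entry points at the rank whose state_dict would need fixing before calling load.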

Each rank will try to read the least amount of data necessary to fulfill the requested state_dict. When loading ShardedTensor or DTensor instances, each rank only reads data for its local shards.

For each Stateful object (having both a state_dict and a load_state_dict), load will first call state_dict before attempting deserialization, followed by load_state_dict once the deserialization is complete. For each non-Stateful object, load will deserialize the object, and then replace it in the state_dict with the deserialized object.
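The order of calls described above can be mimicked with plain Python objects. This is a minimal sketch, assuming a checkpoint that has already been deserialized into a dict; `Counter`, `is_stateful`, and `toy_load` are hypothetical and only model the Stateful versus non-Stateful branches, not DCP's resharding or I/O.

```python
from typing import Any, Dict


class Counter:
    """Toy Stateful object: it exposes both state_dict() and load_state_dict()."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.count = state["count"]


def is_stateful(obj: Any) -> bool:
    return callable(getattr(obj, "state_dict", None)) and callable(
        getattr(obj, "load_state_dict", None)
    )


def toy_load(state_dict: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
    for key, value in list(state_dict.items()):
        if is_stateful(value):
            inner = value.state_dict()  # 1. ask the object for its current state
            inner.update(checkpoint[key])  # 2. "deserialize" saved values into it
            value.load_state_dict(inner)  # 3. hand the result back to the object
        else:
            state_dict[key] = checkpoint[key]  # non-Stateful: replace the entry


counter = Counter()
state_dict = {"counter": counter, "epoch": 0}
toy_load(state_dict, {"counter": {"count": 7}, "epoch": 3})
print(state_dict["counter"] is counter, counter.count, state_dict["epoch"])  # True 7 3
```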

Warning

All tensors in state_dict must be allocated on their destination device prior to calling this function.

All non-tensor data is loaded using torch.load() and modified in place on state_dict.

Warning

Users must call load_state_dict on the root module to ensure load post-processing and non-tensor data properly propagate.

Parameters:
  • state_dict (Dict[str,Any]) – The state_dict to load the checkpoint into.

  • checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)

  • storage_reader (Optional[StorageReader]) – Instance of StorageReader used to perform reads. If this is not specified, DCP will automatically infer the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)

  • planner (Optional[LoadPlanner]) – Instance of LoadPlanner. If this is not specified, the default planner will be used. (Default: None)

  • process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)

  • no_dist (bool) – If True, this function will assume the intent is to load a checkpoint without using cross-rank synchronization. (Default: False)

Returns:

None.

Return type:

None

Examples
>>> from torch.optim import Adagrad
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
...     "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.load(
...     state_dict=model_state_dict,
...     storage_reader=fs_storage_reader,
... )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

Note

load_state_dict uses collectives to coordinate reads across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)

This method is deprecated. Please switch to ‘load’.

The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (torch.distributed.checkpoint.async_save):

class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)

This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users to customize how data is staged prior to executing the usual dcp.save path in parallel. The expected order of operations (concretely defined in torch.distributed.checkpoint.state_dict_saver.async_save) is the following:

  1. AsyncStager.stage(state_dict):

    This call gives the AsyncStager the opportunity to ‘stage’ the state_dict. The expectation and purpose of staging in this context is to create a “training-safe” representation of the state dict, meaning that any updates to module data after staging is complete should not be reflected in the state dict returned from this method. For example, in the default case a copy of the entire state dict is created on CPU RAM and returned here, allowing users to continue training without risking changes to data which is being serialized.

  2. dcp.save is called in parallel on the state_dict returned from stage. This call is responsible for serializing the state_dict and writing it to storage.

  3. If AsyncStager.should_synchronize_after_execute is True, AsyncStager.synchronize_staging will be called immediately after the serialization thread starts and before returning from dcp.async_save. If this is set to False, the assumption is that the user has defined a custom synchronization point for the purpose of further optimizing save latency in the training loop (for example, by overlapping staging with the forward/backward pass), and it is the responsibility of the user to call AsyncStager.synchronize_staging at the appropriate time.
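The three steps can be traced with a toy stager built on concurrent.futures. This is a minimal sketch under stated assumptions: `ToyStager` and `toy_async_save_with_stager` are hypothetical, a deep copy stands in for the CPU copy, and appending to a list stands in for dcp.save. With `sync_after_execute=False`, the caller would have to call synchronize_staging itself before mutating the original state_dict.

```python
import copy
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, List, Optional


class ToyStager:
    """Hypothetical stager with the stage / synchronize_staging / close shape."""

    def __init__(self, sync_after_execute: bool = True) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: Optional[Future] = None
        self._sync_after_execute = sync_after_execute

    @property
    def should_synchronize_after_execute(self) -> bool:
        return self._sync_after_execute

    def stage(self, state_dict: Dict[str, Any]) -> Future:
        # Copy in the background so the caller is not blocked.
        self._pending = self._pool.submit(copy.deepcopy, state_dict)
        return self._pending

    def synchronize_staging(self) -> None:
        # Block until the most recent staging copy has finished.
        if self._pending is not None:
            self._pending.result()

    def close(self) -> None:
        self._pool.shutdown(wait=True)


def toy_async_save_with_stager(
    state_dict: Dict[str, Any], stager: ToyStager, sink: List[Dict[str, Any]]
) -> Future:
    staged = stager.stage(state_dict)  # step 1: stage
    saver = ThreadPoolExecutor(max_workers=1)
    # step 2: "save" the staged copy on a separate thread
    done = saver.submit(lambda: sink.append(staged.result()))
    saver.shutdown(wait=False)  # the already-submitted task still runs to completion
    if stager.should_synchronize_after_execute:  # step 3: optional sync point
        stager.synchronize_staging()
    return done


sink: List[Dict[str, Any]] = []
stager = ToyStager(sync_after_execute=True)
state = {"w": [1.0]}
future = toy_async_save_with_stager(state, stager, sink)
state["w"][0] = 2.0  # safe: staging was synchronized before returning
future.result()
stager.close()
print(sink)  # [{'w': [1.0]}]
```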

close()

Clean up all resources used by the stager.

property should_synchronize_after_execute: bool

Whether to synchronize after executing the stage.

stage(state_dict)

Returns a “staged” copy of state_dict. The expectation of the staged copy is that it is insulated from any updates incurred after the stage call is complete.

Return type:

Future[dict[str, StatefulT | Any]] | dict[str, StatefulT | Any]

synchronize_staging()

In case stage is asynchronous in some way, this method should be called to ensure staging is complete and it is safe to begin modifying the original state_dict.

class torch.distributed.checkpoint.staging.DefaultStager(config=StagingOptions(use_pinned_memory=True, use_shared_memory=True, use_async_staging=True, use_non_blocking_copy=True))

DefaultStager provides a full-featured staging implementation that combines multiple optimization techniques for efficient checkpoint preparation.

The staging process works as follows:

  1. State dictionary is submitted for staging (sync or async).

  2. Tensors are copied from GPU to optimized CPU storage.

  3. CUDA operations are synchronized if non-blocking copies are used.

  4. Staged state dictionary is returned or made available via Future.

Usage Patterns:

# Synchronous staging
stager = DefaultStager(StagingOptions(use_async_staging=False))
staged_dict = stager.stage(state_dict)
stager.close()

# Asynchronous staging
stager = DefaultStager(StagingOptions(use_async_staging=True))
future = stager.stage(state_dict)
# … do other work …
staged_dict = future.result()
stager.close()

# Context manager pattern (recommended)
stager = DefaultStager(config)
with stager:
    result = stager.stage(state_dict)

Performance Considerations:
  • Async staging provides the best performance when model computation can overlap with staging operations

  • Pinned memory improves CPU-GPU transfer speeds but uses more memory

  • Shared memory allows efficient IPC to checkpoint process

  • Non-blocking copies reduce GPU idle time during memory transfers

Thread Safety:

DefaultStager is not thread-safe. Each thread should use its own instance, or external synchronization should be provided.

close()

Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor used for async staging operations and cleans up the underlying StateDictStager’s cached storages. Should be called when the stager is no longer needed to prevent resource leaks, especially in long-running applications. After calling close(), the stager should not be used for further staging operations.

Example Usage:

stager = DefaultStager(StagingOptions(use_async_staging=True))
future = stager.stage(state_dict)
result = future.result()
stager.close()  # Clean up all resources

stage(state_dict, **kwargs)

This function is responsible for staging the state_dict. See the class docstring for more details on staging. If use_async_staging is True, it will return a Future object that will be fulfilled when staging is complete. If use_async_staging is False, it will return the fully staged state_dict.

Parameters:

state_dict (STATE_DICT_TYPE) – The state_dict to be staged.

Return type:

dict[str, StatefulT | Any] | Future[dict[str, StatefulT | Any]]

synchronize_staging()

When use_async_staging is True, this method will wait until staging is complete. If use_async_staging is False, this method is a no-op.

class torch.distributed.checkpoint.staging.StagingOptions(use_pinned_memory=True, use_shared_memory=True, use_async_staging=True, use_non_blocking_copy=True)

Configuration options for checkpoint staging behavior.

Variables:
  • use_pinned_memory (bool) – Enable pinned memory allocation for faster CPU-GPU transfers. Requires CUDA to be available. Default: True

  • use_shared_memory (bool) – Enable shared memory for multi-process scenarios. Useful when multiple processes need access to the same staged data. Default: True

  • use_async_staging (bool) – Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True

  • use_non_blocking_copy (bool) – Use non-blocking device memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True

Note

CUDA-dependent features will raise an exception if CUDA is not available.

class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)

An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. This implementation also provides an option to optimize stage latency using pinned memory.

N.B. synchronize_staging is a no-op in this case.

stage(state_dict)

Returns a copy of state_dict on the CPU.

Return type:

dict[str, StatefulT | Any]

synchronize_staging()

No-op function, since staging is blocking.

In addition to the above entrypoints, Stateful objects, as described below, provide additional customization during saving/loading.

class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)

Stateful protocol for objects that can be checkpointed and restored.

load_state_dict(state_dict)

Restore the object’s state from the provided state_dict.

Parameters:

state_dict (dict[str, Any]) – The state dict to restore from

state_dict()

Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().

Warning

Because of the in-place nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.

Returns:

The object’s state dict

Return type:

Dict
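Because Stateful is described as a protocol of two methods, any object that defines state_dict and load_state_dict fits, with no base class required. The sketch below mirrors that with typing.Protocol; `StatefulLike` and `TrainingProgress` are hypothetical names, and `StatefulLike` is a local mirror rather than an import of the real class.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    """Hypothetical mirror of the two methods the Stateful protocol names."""

    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


class TrainingProgress:
    """Toy object that satisfies the protocol without inheriting from it."""

    def __init__(self) -> None:
        self.step = 0
        self.best_loss = float("inf")

    def state_dict(self) -> Dict[str, Any]:
        return {"step": self.step, "best_loss": self.best_loss}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict["step"]
        self.best_loss = state_dict["best_loss"]


saved = TrainingProgress()
saved.step, saved.best_loss = 500, 0.25
restored = TrainingProgress()
restored.load_state_dict(saved.state_dict())
print(isinstance(restored, StatefulLike), restored.step)  # True 500
```

An object like this could sit next to a model in the state_dict passed to save or load. The runtime check only verifies that both methods exist, not their signatures.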

This example shows how to use PyTorch Distributed Checkpoint to save an FSDP model.

The following types define the IO interface used during checkpoint:

class torch.distributed.checkpoint.StorageReader

Interface used by load_state_dict to read from storage.

One StorageReader instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.

A subclass should expect the following sequence of calls by load_state_dict:

  1. (all ranks) set checkpoint_id if users pass a valid checkpoint_id.

  2. (all ranks) read_metadata()

  3. (all ranks) set_up_storage_reader()

  4. (all ranks) prepare_local_plan()

  5. (coordinator) prepare_global_plan()

  6. (all ranks) read_data()
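The six-step sequence can be made concrete with a reader that only logs calls. This sketch is hypothetical throughout (`RecordingReader`, `drive_load`). It runs all ranks in one process in a plain loop, treats rank 0 as the coordinator, and assumes step 1 happens through the reset hook documented below. Real DCP runs each rank in its own process and coordinates them with collectives.

```python
from typing import Any, List, Optional, Tuple

Log = List[Tuple[int, str]]


class RecordingReader:
    """Hypothetical reader that only records which hook ran on which rank."""

    def __init__(self, rank: int, log: Log) -> None:
        self.rank = rank
        self.log = log

    def _record(self, name: str) -> None:
        self.log.append((self.rank, name))

    def reset(self, checkpoint_id: Optional[str] = None) -> None:
        self._record("reset")

    def read_metadata(self) -> dict:
        self._record("read_metadata")
        return {"keys": ["w"]}

    def set_up_storage_reader(self, metadata: dict, is_coordinator: bool) -> None:
        self._record("set_up_storage_reader")

    def prepare_local_plan(self, plan: Any) -> Any:
        self._record("prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[Any]) -> List[Any]:
        self._record("prepare_global_plan")
        return plans

    def read_data(self, plan: Any, planner: Any) -> None:
        self._record("read_data")


def drive_load(world_size: int, checkpoint_id: Optional[str] = None) -> Log:
    log: Log = []
    readers = [RecordingReader(rank, log) for rank in range(world_size)]
    if checkpoint_id is not None:
        for reader in readers:
            reader.reset(checkpoint_id)  # 1. (all ranks) set checkpoint_id
    metadata = [reader.read_metadata() for reader in readers]  # 2. (all ranks)
    for rank, reader in enumerate(readers):
        reader.set_up_storage_reader(metadata[rank], is_coordinator=(rank == 0))  # 3.
    local_plans = [r.prepare_local_plan(f"plan-{i}") for i, r in enumerate(readers)]  # 4.
    global_plans = readers[0].prepare_global_plan(local_plans)  # 5. coordinator only
    for rank, reader in enumerate(readers):
        reader.read_data(global_plans[rank], planner=None)  # 6. (all ranks)
    return log


for entry in drive_load(world_size=2, checkpoint_id="ckpt"):
    print(entry)
```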

abstract prepare_global_plan(plans)

Perform centralized planning of storage loading.

This method is only called on the coordinator instance.

While this method can produce a completely different plan, the preferred way is to store storage-specific data in LoadPlan::storage_data.

Parameters:

plans (list[LoadPlan]) – A list of LoadPlan instances, one for each rank.

Returns:

A list of transformed LoadPlan after storage global planning

Return type:

list[LoadPlan]

abstract prepare_local_plan(plan)

Perform storage-specific local planning.

While this method can produce a completely different plan, the recommended way is to store storage-specific data in LoadPlan::storage_data.

Parameters:

plan (LoadPlan) – The local plan from the LoadPlanner in use.

Returns:

A transformed LoadPlan after storage local planning

Return type:

LoadPlan

abstract read_data(plan, planner)

Read all items from plan using planner to resolve the data.

A subclass should call LoadPlanner::load_bytes to deserialize a BytesIO object into the right place.

A subclass should call LoadPlanner::resolve_tensor to get access to the tensors that it should load data into.

It’s the storage layer’s responsibility to properly schedule any cross-device copies required.

Parameters:
  • plan (LoadPlan) – The local plan to execute on

  • planner (LoadPlanner) – The planner object to use to resolve items.

Returns:

A future that completes once all reads are finished.

Return type:

Future[None]

abstract read_metadata(*args, **kwargs)

Read the checkpoint metadata.

Returns:

The metadata object associated with the checkpoint being loaded.

Return type:

Metadata

abstract reset(checkpoint_id=None)

Called to indicate that a brand new checkpoint read is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.

Parameters:

checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is more like a key-value store. (Default: None)

abstract set_up_storage_reader(metadata, is_coordinator, *args, **kwargs)

Initialize this instance.

Parameters:
  • metadata (Metadata) – The metadata schema to use.

  • is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint.

abstract classmethod validate_checkpoint_id(checkpoint_id)

Check if the given checkpoint_id is supported by the storage. This allows us to enable automatic storage selection.

Return type:

bool

class torch.distributed.checkpoint.StorageWriter

Interface used by save_state_dict to write to storage.

One StorageWriter instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.

A subclass should expect the following sequence of calls.

  1. (all ranks) set checkpoint_id if users pass a valid checkpoint_id.

  2. (all ranks) set_up_storage_writer()

  3. (all ranks) prepare_local_plan()

  4. (coordinator) prepare_global_plan()

  5. (all ranks) write_data()

  6. (coordinator) finish()
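The following sketch walks an in-memory writer through the same six steps. It assumes plans are plain dicts with an `items` list of keys and that `planner` is reduced to a key-lookup function. `STORE`, `InMemoryWriter`, `drive_save`, and the `rank{N}.data` file names are hypothetical. Real plans are SavePlan objects and real results are WriteResult objects.

```python
from concurrent.futures import Future
from typing import Any, Callable, Dict, List, Optional, Tuple

STORE: Dict[str, Any] = {}  # hypothetical storage shared by all "ranks"


class InMemoryWriter:
    """Hypothetical writer that walks through the six steps with plain dicts."""

    def __init__(self) -> None:
        self.prefix = ""
        self.is_coordinator = False

    def reset(self, checkpoint_id: Optional[str] = None) -> None:
        self.prefix = f"{checkpoint_id}/" if checkpoint_id else ""

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self.is_coordinator = is_coordinator

    def prepare_local_plan(self, plan: Dict[str, Any]) -> Dict[str, Any]:
        return plan  # nothing storage-specific to add locally

    def prepare_global_plan(self, plans: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Coordinator only: record in storage_data which file each rank writes.
        return [dict(plan, storage_data=f"rank{r}.data") for r, plan in enumerate(plans)]

    def write_data(self, plan: Dict[str, Any], resolve: Callable[[str], Any]) -> Future:
        file_name = plan["storage_data"]
        STORE[self.prefix + file_name] = {key: resolve(key) for key in plan["items"]}
        done: Future = Future()
        done.set_result([(key, file_name) for key in plan["items"]])  # toy "WriteResults"
        return done

    def finish(self, metadata: Dict[str, Any], results: List[List[Tuple[str, str]]]) -> None:
        # Coordinator only: write the index last, marking the checkpoint complete.
        index = {key: name for rank_results in results for key, name in rank_results}
        STORE[self.prefix + ".metadata"] = {"metadata": metadata, "index": index}


def drive_save(per_rank_state: List[Dict[str, Any]], checkpoint_id: str) -> None:
    writers = [InMemoryWriter() for _ in per_rank_state]
    for w in writers:
        w.reset(checkpoint_id)  # 1. (all ranks)
    for rank, w in enumerate(writers):
        w.set_up_storage_writer(is_coordinator=(rank == 0))  # 2. (all ranks)
    local = [w.prepare_local_plan({"items": sorted(sd)}) for w, sd in zip(writers, per_rank_state)]  # 3.
    global_plans = writers[0].prepare_global_plan(local)  # 4. (coordinator)
    results = [
        w.write_data(plan, sd.__getitem__).result()  # 5. (all ranks)
        for w, plan, sd in zip(writers, global_plans, per_rank_state)
    ]
    writers[0].finish({"world_size": len(writers)}, results)  # 6. (coordinator)


drive_save([{"a": 1}, {"b": 2}], "ckpt")
print(STORE["ckpt/.metadata"]["index"])  # {'a': 'rank0.data', 'b': 'rank1.data'}
```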

abstract finish(metadata, results)

Write the metadata and mark the current checkpoint as successful.

The actual format/schema used for serializing metadata is an implementation detail. The only requirement is that it’s recoverable into the same object graph.

Parameters:
  • metadata (Metadata) – metadata for the new checkpoint

  • results (list[list[WriteResult]]) – A list of WriteResults from all ranks.

Returns:

None

Return type:

None

abstract prepare_global_plan(plans)

Perform centralized planning of storage.

This method is only called on the coordinator instance.

While this method can produce a completely different plan, the preferred way is to store storage-specific data in SavePlan::storage_data.

Parameters:

plans (list[SavePlan]) – A list of SavePlan instances, one for each rank.

Returns:

A list of transformed SavePlan after storage global planning

Return type:

list[SavePlan]

abstract prepare_local_plan(plan)

Perform storage-specific local planning.

While this method can produce a completely different plan, the recommended way is to store storage-specific data in SavePlan::storage_data.

Parameters:

plan (SavePlan) – The local plan from the SavePlanner in use.

Returns:

A transformed SavePlan after storage local planning

Return type:

SavePlan

abstract reset(checkpoint_id=None)

Called to indicate that a brand new checkpoint write is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint write. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.

Parameters:

checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)

abstract set_up_storage_writer(is_coordinator, *args, **kwargs)

Initialize this instance.

Parameters:

is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint.

storage_meta()

Return the storage-specific metadata. This is used to store additional information in a checkpoint that can be useful for providing request-level observability. StorageMeta is passed to the SavePlanner during save calls. Returns None by default.


Return type:

StorageMeta | None

abstract classmethod validate_checkpoint_id(checkpoint_id)

Check if the given checkpoint_id is supported by the storage. This allows us to enable automatic storage selection.

Return type:

bool

abstract write_data(plan, planner)

Write all items from plan using planner to resolve the data.

A subclass should call SavePlanner::resolve_data on each item from the plan to get access to the underlying object to write.

Subclasses should lazily call resolve_data, as it can allocate memory. In the case of tensors, make the following assumptions:

  • They might be on any device, including not matching the one on WriteItem::tensor_data.

  • They might be views or not contiguous. Only the projection needs to be saved.

Parameters:
  • plan (SavePlan) – The save plan to execute.

  • planner (SavePlanner) – Planner object to be used to resolve items to data.

Returns:

A future that completes to a list of WriteResult

Return type:

Future[list[WriteResult]]

The following types define the planner interface used during checkpoint:

class torch.distributed.checkpoint.LoadPlanner

Abstract class defining the protocol used by load_state_dict to plan the load process.

LoadPlanners are stateful objects that can be used to customize the whole load process.

LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.

A planner subclass can expect the following sequence of calls during load_state_dict:

  1. set_up_planner - called on all ranks.

    Signals the start of loading a checkpoint.

  2. create_local_plan - called on all ranks.

    Processes the state_dict and produces a LoadPlan that will be sent for global planning.

  3. create_global_plan - called on the coordinator rank only.

    Takes the LoadPlan from all ranks and makes any global decisions.

  4. load_bytes - called multiple times on each rank

    This is called once per non-tensor value in state_dict.

  5. resolve_tensor and commit_tensor - called multiple times on each rank

    They are called in pairs for each Tensor value in state_dict.

Users are recommended to extend DefaultLoadPlanner instead of this interface directly, as most changes can be expressed by changes in a single method.

There are two usual patterns of extension:

Rewriting state_dict. This is the simplest way to extend the load process as it doesn’t require understanding the intricacies of how LoadPlan works. We need to keep a reference to the original state_dict, because load happens in place and we need to be able to perform it in place.

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         metadata: Metadata,
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

Modifying resolve_tensor and commit_tensor to handle load time transformation.

>>> class MetaModelMaterialize(DefaultLoadPlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor

abstract commit_tensor(read_item, tensor)

Called once the StorageReader has finished loading data into tensor.

The provided tensor is the same one returned by the call to resolve_tensor. This method is only needed if this LoadPlanner needs to post-process tensor prior to copying it back to the one in the state_dict.

The contents of tensor will follow its device synchronization model.

abstract create_global_plan(global_plan)

Compute the global load plan and return plans for each rank.

N.B. This is called on the coordinator rank only.

Return type:

list[LoadPlan]

abstract create_local_plan()

Create a LoadPlan based on state_dict and metadata provided by set_up_planner.

N.B. This is called on every rank.

Return type:

LoadPlan

abstract finish_plan(central_plan)

Accept the plan from coordinator and return final LoadPlan.

Return type:

LoadPlan

abstract load_bytes(read_item, value)

Load the item described by read_item and value.

This method is expected to modify in-place the underlying state_dict.

The contents of value are defined by the SavePlanner used to produce the checkpoint being loaded.

resolve_bytes(read_item)

Return the BytesIO to be used by the StorageReader to load read_item.

The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents.

Return type:

BytesIO

abstract resolve_tensor(read_item)

Return the tensor described by read_item to be used by the StorageReader to load read_item.

The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. If, for any reason, that’s not possible, the planner can use the commit_tensor method to copy the data back to the one in state_dict.

Return type:

Tensor
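The resolve/commit pairing matters when the planner cannot hand out an alias. This is a minimal sketch, assuming lists of floats stand in for tensors. `ToyLoadPlanner` and `toy_read` are hypothetical. The reader fills a scratch buffer, and commit_tensor then copies it back into the state_dict's own storage.

```python
from typing import Dict, List


class ToyLoadPlanner:
    """Hypothetical planner that cannot hand out an alias to the real buffer."""

    def __init__(self, state_dict: Dict[str, List[float]]) -> None:
        self.state_dict = state_dict

    def resolve_tensor(self, key: str) -> List[float]:
        # Return a scratch buffer instead of the state_dict's own storage.
        return [0.0] * len(self.state_dict[key])

    def commit_tensor(self, key: str, buffer: List[float]) -> None:
        # Copy the loaded values back into the buffer the state_dict owns.
        self.state_dict[key][:] = buffer


def toy_read(planner: ToyLoadPlanner, saved: Dict[str, List[float]]) -> None:
    """Hypothetical reader loop: resolve, fill, then commit, once per entry."""
    for key, values in saved.items():
        buffer = planner.resolve_tensor(key)  # where should the data go?
        buffer[:] = values  # the reader fills that buffer
        planner.commit_tensor(key, buffer)  # the planner copies it back


weights = [0.0, 0.0]
toy_read(ToyLoadPlanner({"weights": weights}), {"weights": [1.5, 2.5]})
print(weights)  # [1.5, 2.5]
```

Without the commit step, the loaded values would stay in the scratch buffer and the state_dict would keep its old contents.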

abstract set_up_planner(state_dict, metadata=None, is_coordinator=False)

Initialize this instance to load data into state_dict.

N.B. This is called on every rank.

class torch.distributed.checkpoint.LoadPlan(items: list[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)

class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)

class torch.distributed.checkpoint.SavePlanner

Abstract class defining the protocol used by save_state_dict to plan the save process.

SavePlanners are stateful objects that can be used to customize the whole save process.

SavePlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.

A planner subclass can expect the following sequence of calls during save_state_dict:

  1. set_up_planner - called on all ranks.

    Signals the start of a checkpoint save.

  2. create_local_plan - called on all ranks.

    Processes the state_dict and produces a SavePlan that will be sent for global planning.

  3. create_global_plan - called on the coordinator rank only.

    Takes the SavePlan from all ranks and makes any global decisions.

  4. finish_plan - called on all ranks.

    This gives each rank a chance to adjust to global planning decisions.

  5. resolve_data - called multiple times on each rank

    Lookups a value on thestate_dict for the storage layer to write.

Users are recommended to extend DefaultSavePlanner instead of this interface directly asmost changes can be expressed by changes in a single method.

There are 3 usual patterns of extension:

Rewriting state_dict. This is the simplest way to extend the save process as itdoesn’t requite understanding the intrincacies of how SavePlan works:

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         storage_meta: Optional[StorageMeta],
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         # prefix all keys with `foo_`
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

Modifying local plan and lookup in tandem. This is useful when fine control over how data is persisted is needed:

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         # SavePlan is not iterable; its write items live in plan.items
>>>         for p in plan.items:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

Using the global planning step to make central decisions that can’t be made individually by each rank:

>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         iters = [iter(all_plans[0].items)] * len(all_plans)
>>>         items_per_rank = [
>>>             [item for item in items if item is not None]
>>>             for items in zip(*zip_longest(*iters), strict=True)
>>>         ]
>>>         all_plans = [
>>>             replace(plan, items=items)
>>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
>>>         ]
>>>         return super().create_global_plan(all_plans)
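The zip_longest idiom in that planner is compact but easy to misread, so here it is in isolation on plain Python lists. round_robin is a hypothetical helper name; like the planner, it assumes no item is itself None, because None is used as padding and filtered out.

```python
from itertools import zip_longest


def round_robin(items: list, world_size: int) -> list[list]:
    """Assign items[i] to rank i % world_size, the split the planner above computes."""
    if not items:
        # zip(*...) below would yield no columns at all for an empty input.
        return [[] for _ in range(world_size)]
    # world_size references to ONE iterator: zip_longest pulls consecutive items
    # into rows of length world_size, padding the last row with None.
    iters = [iter(items)] * world_size
    rows = zip_longest(*iters)
    # Transposing the rows gives column r = items r, r + world_size, ...
    return [[x for x in column if x is not None] for column in zip(*rows)]
```

For instance, round_robin(["a", "b", "c", "d", "e"], 2) gives [["a", "c", "e"], ["b", "d"]], so every rank ends up with at most one more write than any other.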

Finally, some planners need to save additional metadata in the checkpoint. This is accomplished by having each rank contribute its data items in the local plan and having the global planner aggregate them:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
abstract create_global_plan(all_plans)

Compute the global checkpoint plan and return the local plan of each rank.

This is called on the coordinator rank only.

Return type:

tuple[list[SavePlan], Metadata]

abstract create_local_plan()

Compute the save plan for the current rank.

This will be aggregated and passed to create_global_plan. Planner-specific data can be passed through SavePlan.planner_data.

This is called on all ranks.

Return type:

SavePlan

abstract finish_plan(new_plan)

Merge the plan created by create_local_plan and the result of create_global_plan.

This is called on all ranks.

Return type:

SavePlan

abstract resolve_data(write_item)

Transform and prepare write_item from state_dict for storage, ensuring idempotency and thread-safety.

Look up the object associated with write_item in state_dict and apply any transformation (such as serialization) prior to the storage layer consuming it.

Called on each rank multiple times, at least once per WriteItem in the final SavePlan.

This method should be idempotent and thread-safe. StorageWriter implementations are free to call it as frequently as they need.

Any transformation that allocates memory should be done lazily, when this method is called, in order to reduce the peak memory required by checkpointing.

Returned tensors can be on any device and in any format, and they can be views. It is the storage layer’s responsibility to figure out how to save them.

Return type:

Tensor | BytesIO
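A minimal, library-free sketch of the resolve_data rules above (lazy work, idempotency, thread-safety). ToyLazyPlanner and resolve_concurrently are hypothetical, and pickle stands in for whatever serialization a real planner performs for non-tensor values.

```python
import io
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Any


class ToyLazyPlanner:
    """Hypothetical planner illustrating the resolve_data contract; not DCP code."""

    def __init__(self, state_dict: dict[str, Any]):
        # Keep references only: nothing is copied or serialized up front.
        self._state_dict = state_dict

    def resolve_data(self, key: str) -> io.BytesIO:
        # Serialize on demand, so no serialized copy exists until the storage
        # layer asks for it. Nothing shared is mutated, which makes repeated or
        # concurrent calls safe and keeps the result identical every time.
        buf = io.BytesIO()
        pickle.dump(self._state_dict[key], buf)
        buf.seek(0)
        return buf


def resolve_concurrently(planner: ToyLazyPlanner, key: str, times: int) -> list[bytes]:
    # Mimic a storage writer that calls resolve_data from several threads.
    with ThreadPoolExecutor(max_workers=4) as pool:
        return list(pool.map(lambda _: planner.resolve_data(key).getvalue(), range(times)))
```

Because each call builds a fresh buffer, a storage writer can call resolve_data again (for example, on retry) without affecting earlier results.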

abstract set_up_planner(state_dict, storage_meta=None, is_coordinator=False)

Initialize this planner to save state_dict.

Implementations should save those values, as they won’t be provided later in the save process.

This is called on all ranks.

class torch.distributed.checkpoint.SavePlan(items: list[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None, usable: bool = True)

class torch.distributed.checkpoint.planner.WriteItem(index, type, bytes_io_data=None, tensor_data=None)

Dataclass which holds information about what needs to be written to storage.

tensor_storage_size()

Calculates the storage size of the underlying tensor, or None if this is not a tensor write.

Returns:

Optional[int]: the storage size, in bytes, of the underlying tensor, if any.

Return type:

int | None

class torch.distributed.checkpoint.planner.BytesIOWriteData(nbytes: int)

We provide a filesystem based storage layer:

class torch.distributed.checkpoint.FileSystemReader(path, _extension_registry=None)

property checkpoint_id: str | PathLike

Return the checkpoint_id that will be used to load the checkpoint.

class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True, _extensions=None, serialization_format=SerializationFormat.TORCH_SAVE)

Basic implementation of StorageWriter using file IO.

This implementation makes the following assumptions and simplifications:

  • The checkpoint path is an empty or non-existing directory.

  • File creation is atomic.

The checkpoint consists of one file per write request, plus either a global .metadata file with the serialized metadata if rank coordination is enabled, or a rank-local __{rank}.metadata file with the serialized metadata if rank coordination is NOT enabled.

stage(state_dict)

Override of AsyncStager.stage

Return type:

dict[str, StatefulT | Any]

We also provide other storage layers, including ones to interact with HuggingFace safetensors:

class torch.distributed.checkpoint.HuggingFaceStorageReader

class torch.distributed.checkpoint.HuggingFaceStorageWriter

class torch.distributed.checkpoint.QuantizedHuggingFaceStorageReader

We provide default implementations of LoadPlanner and SavePlanner that can handle all of the torch.distributed constructs, such as FSDP, DDP, ShardedTensor and DistributedTensor.

class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False, enable_plan_caching=False)

lookup_object(index)

Extension from the planner interface to make it easy to extend the default planner.

Return type:

Any

transform_object(write_item, object)

Extension from the planner interface to make it easy to extend the default planner.

class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)

DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

In particular it adds the following:

  • flatten_state_dict: Handle state_dict with nested dicts.

  • flatten_sharded_tensors: For FSDP in 2D parallel mode.

  • allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint.
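To illustrate what flattening means, here is a simplified, library-free sketch that flattens nested dicts into dotted keys and records each key's original path, so the nesting can be restored even when a leaf key already contains dots. flatten_nested and unflatten are hypothetical helpers; DCP's own flattening covers cases this sketch ignores, such as containers other than dicts (and this sketch silently drops empty sub-dicts).

```python
from typing import Any


def flatten_nested(state_dict: dict[str, Any]) -> tuple[dict[str, Any], dict[str, tuple[str, ...]]]:
    """Hypothetical sketch: nested dicts -> dotted keys, plus key -> original path."""
    flat: dict[str, Any] = {}
    mapping: dict[str, tuple[str, ...]] = {}

    def visit(path: tuple[str, ...], value: Any) -> None:
        if isinstance(value, dict):
            for k, v in value.items():
                visit(path + (str(k),), v)
        else:
            key = ".".join(path)
            flat[key] = value
            mapping[key] = path  # needed because a dotted key alone is ambiguous

    for k, v in state_dict.items():
        visit((str(k),), v)
    return flat, mapping


def unflatten(flat: dict[str, Any], mapping: dict[str, tuple[str, ...]]) -> dict[str, Any]:
    """Rebuild the nested structure from the recorded paths."""
    nested: dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = mapping[key]
        node = nested
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = value
    return nested
```

With {"model": {"layer": {"weight": 1}}, "step": 5}, the flat keys are "model.layer.weight" and "step", which is the kind of key a checkpoint stores when flatten_state_dict is enabled.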

lookup_tensor(index)

Extension from the planner interface to make it easy to extend the default planner.

Return type:

Tensor

transform_tensor(read_item, tensor)

Extension from the planner interface to make it easy to extend the default planner.

Due to legacy design decisions, the state dictionaries of FSDP and DDP may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical. Moreover, FSDP offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries employ parameter IDs instead of fully qualified names to identify parameters, potentially causing issues when parallelisms are used (e.g., pipeline parallelism).

To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. get_model_state_dict() returns a model state dictionary with keys consistent with those returned by the unparallelized model state dictionary. Similarly, get_optimizer_state_dict() provides the optimizer state dictionary with keys uniform across all parallelisms applied. To achieve this consistency, get_optimizer_state_dict() converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary.
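The idea behind canonical FQNs can be sketched as stripping the name segments that wrappers insert. This is illustrative only: canonical_fqn is a hypothetical helper, and it assumes just two wrapper conventions, DDP's leading "module." prefix and the "_fsdp_wrapped_module" segment that can appear in FSDP (v1) module and parameter names. The real APIs handle many more cases, so in practice prefer get_model_state_dict() over string manipulation.

```python
def canonical_fqn(name: str) -> str:
    """Hypothetical sketch: drop wrapper-inserted segments from a parameter name."""
    parts = name.split(".")
    # Assumption: DDP wraps the whole model once, so "module" is stripped only
    # when it is the first segment; a submodule named "module" deeper in the
    # hierarchy is left alone.
    if parts and parts[0] == "module":
        parts = parts[1:]
    # Assumption: FSDP (v1) inserts this segment wherever it wraps a submodule.
    return ".".join(p for p in parts if p != "_fsdp_wrapped_module")
```

Under these assumptions, "module.layer1.weight" (DDP) and "_fsdp_wrapped_module.layer1.weight" (FSDP) both map to the unparallelized name "layer1.weight".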

Note that results returned by these APIs can be used directly with the torch.distributed.checkpoint.save() and torch.distributed.checkpoint.load() methods without requiring any additional conversions.

set_model_state_dict() and set_optimizer_state_dict() are provided to load the model and optimizer state_dict generated by their respective getter APIs.

Note that set_optimizer_state_dict() can only be called before backward() or after step() is called on optimizers.

Note that this feature is experimental, and API signatures might change in the future.

torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)

Return the model state_dict and optimizers state_dict.

get_state_dict can process any module that is parallelized by PyTorch FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any combination of these parallelisms. The main functions of get_state_dict are: 1.) returning a model and optimizer state_dict that can be resharded with a different number of trainers and/or different parallelisms; 2.) hiding the parallelism-specific state_dict APIs, so users don’t have to call these APIs; 3.) sanity checking the resulting state_dict.

The keys of the resulting state dictionary are the canonical FQNs (Fully Qualified Names). A canonical FQN refers to the FQN based on a parameter’s position in an nn.Module hierarchy. More specifically, a canonical FQN to a parameter is the FQN returned by module.named_parameters() or module.named_buffers() when the module is not distributed by any parallelisms. Since the optimizer internally uses parameter IDs to represent a parameter, there will be a conversion from the parameter IDs to the canonical FQNs when calling this API.

get_state_dict can also process a module that is not parallelized. In such a case, get_state_dict performs only one function: converting the optimizer parameter IDs to the canonical FQNs.

Example

>>> import copy
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict

>>> # `model` is an existing nn.Module and the process group is already initialized.
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)

>>> # If we simply called ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the keys would differ and these asserts would fail.
>>> assert ddp_state_dict.keys() == fsdp_state_dict.keys()
>>> assert ddp_optim_state_dict["state"].keys() == fsdp_optim_state_dict["state"].keys()
Parameters:
  • model (nn.Module) – the model whose state_dict is returned.

  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.

  • submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.

Returns:

Tuple that contains the model state_dict and the optimizer state_dict.

Return type:

Tuple[Dict[str, ValueType], OptimizerStateType]
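The parameter-ID-to-FQN conversion that get_state_dict performs can be sketched on a plain torch.optim-style state_dict, which stores per-parameter state under integer IDs and lists those IDs in each param_group. optim_ids_to_fqns is a hypothetical helper, and it assumes you already know which FQN each ID refers to; get_state_dict derives that mapping from the model and optimizers itself.

```python
from typing import Any


def optim_ids_to_fqns(optim_sd: dict[str, Any], ordered_fqns: list[str]) -> dict[str, Any]:
    """Hypothetical sketch: rekey a torch.optim-style state_dict by FQN instead of ID."""
    # Assumption: ordered_fqns[i] is the canonical FQN of the parameter the
    # optimizer numbered i (torch.optim numbers parameters in param_group order).
    id_to_fqn = dict(enumerate(ordered_fqns))
    state = {id_to_fqn[pid]: s for pid, s in optim_sd["state"].items()}
    param_groups = [
        # Copy every hyperparameter, replacing only the list of parameter IDs.
        {**group, "params": [id_to_fqn[pid] for pid in group["params"]]}
        for group in optim_sd["param_groups"]
    ]
    return {"state": state, "param_groups": param_groups}
```

Keyed this way, the optimizer state no longer depends on the order in which a particular parallelism happens to register parameters, which is what allows resharding across different setups.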

torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)

Return the model state_dict of model.

See get_state_dict for detailed usage.

Parameters:
  • model (nn.Module) – the model whose state_dict is returned.

  • submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.

Returns:

The state_dict for model.

Return type:

Dict[str, ValueType]

torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)

Return the combined state_dict for optimizers.

See get_state_dict for detailed usage.

Parameters:
  • model (nn.Module) – the model whose parameters the optimizers update.

  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.

  • submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.

Returns:

The state_dict for optimizers.

Return type:

OptimizerStateType

torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)

Load the model state_dict and optimizers state_dict.

The counterpart of get_state_dict to set the state_dict to the model and optimizers. The given model_state_dict and optim_state_dict do not have to be returned by get_state_dict but must meet the following requirements: 1) all FQNs are canonical FQNs as defined in get_state_dict; 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor; 3) the optimizer state_dict cannot contain the parameter IDs; the keys should be the canonical FQNs.

WARN: set_state_dict can only be called before backward() or after step() is called on the optimizers. Otherwise, the optimizer states won’t be initialized correctly.

Parameters:
  • model (nn.Module) – the model to load the state_dict into.

  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.

  • model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): the model state_dict to load. If a key of model_state_dict is an nn.Module, the key is a submodule of model and the value should be the state_dict of that submodule. When loading the state_dict, the prefix of the submodule will be prepended to the state_dict keys.

  • optim_state_dict (OptimizerStateType) – the optimizer state_dict to load.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.

Returns:

  • missing_keys is a list of str containing the missing keys of the model state_dict.

  • unexpected_keys is a list of str containing the unexpected keys of the model state_dict.

Return type:

NamedTuple with missing_keys and unexpected_keys fields
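When model_state_dict is keyed by submodules, each submodule's entries end up under that submodule's prefix. The sketch below mimics only that key transformation with plain dicts. merge_submodule_state_dicts is hypothetical, it uses each submodule's FQN string (for example "encoder") as a stand-in for the nn.Module key that set_state_dict actually accepts, and rejecting colliding keys is a design choice of this sketch.

```python
from typing import Any


def merge_submodule_state_dicts(per_submodule: dict[str, dict[str, Any]]) -> dict[str, Any]:
    """Hypothetical sketch: prepend each submodule's FQN to its state_dict keys."""
    merged: dict[str, Any] = {}
    for prefix, sub_state_dict in per_submodule.items():
        for key, value in sub_state_dict.items():
            # An empty prefix stands for the root module: keys are kept as-is.
            full_key = f"{prefix}.{key}" if prefix else key
            if full_key in merged:
                raise KeyError(f"duplicate key after prefixing: {full_key}")
            merged[full_key] = value
    return merged
```

For instance, {"encoder": {"weight": ...}, "decoder.proj": {"bias": ...}} becomes {"encoder.weight": ..., "decoder.proj.bias": ...}, i.e. the same keys model.state_dict() would use for the full model.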

torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)

Load the model state_dict.

The counterpart of get_model_state_dict to set the state_dict to the model. See set_state_dict for detailed usage.

Parameters:
  • model (nn.Module) – the model to load the state_dict into.

  • model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): the model state_dict to load. If a key of model_state_dict is an nn.Module, the key is a submodule of model and the value should be the state_dict of that submodule. When loading the state_dict, the prefix of the submodule will be prepended to the state_dict keys.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)

Load the optimizers state_dict.

The counterpart of get_optimizer_state_dict to set the state_dict to the optimizers. See set_state_dict for detailed usage.

WARN: set_optimizer_state_dict can only be called before backward() or after step() is called on the optimizers. Otherwise, the optimizer states won’t be initialized correctly.

Parameters:
  • model (nn.Module) – the model whose parameters the optimizers update.

  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.

  • optim_state_dict (OptimizerStateType) – the optimizer state_dict to load.

  • options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.

Returns:

None

Return type:

None

class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False, dsd_fqn_modifiers='_fqn_modifiers')

This dataclass specifies how get_state_dict/set_state_dict will work.

  • full_state_dict: if this is set to True, all the tensors in the returned state_dict will be gathered. No ShardedTensor and DTensor will be in the returned state_dict.

  • cpu_offload: offload all the tensors to CPU. To prevent CPU OOM, if full_state_dict is also True, then only rank0 will get the state_dict and all other ranks will get an empty state_dict.

  • ignore_frozen_params: if the value is True, the returned state_dict won’t contain any frozen parameters, i.e., parameters whose requires_grad is False. The default value is False.

  • keep_submodule_prefixes (deprecated): when submodules is not None, this option indicates whether to keep the submodule prefixes in the state_dict keys. For example, if the submodule is module.pretrain and the full FQN of the parameter is pretrain.layer1.weight, then when this option is True, the parameter’s key in the returned state_dict will be pretrain.layer1.weight; when the option is False, the key will be layer1.weight. Note that if keep_submodule_prefixes is False, there may be conflicting FQNs, hence there should be only one submodule in submodules.

  • strict: the strict option used when set_state_dict calls model.load_state_dict().

  • broadcast_from_rank0: when the option is True, rank0 should receive a full state_dict and will broadcast the tensors in the state_dict/optim_state_dict one by one to other ranks. Other ranks will receive the tensors and shard them according to the local shards in the model and optimizer. full_state_dict must be set to True when using this option. This option currently only supports DTensor, not the legacy ShardedTensor.

For users who are used to using and sharing models in the torch.save format, the following methods provide offline utilities for converting between formats.

torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)

Given a directory containing a DCP checkpoint, this function will convert it into a Torch save file.

Parameters:
  • dcp_checkpoint_dir (str | PathLike) – Directory containing the DCP checkpoint.

  • torch_save_path (str | PathLike) – Filename to store the converted Torch save file.

Warning

To avoid OOM, it’s recommended to only run this function on a single rank.

torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)

Given the location of a torch save file, converts it into a DCP checkpoint.

Parameters:
  • torch_save_path (str | PathLike) – Filename of the Torch save file.

  • dcp_checkpoint_dir (str | PathLike) – Directory to store the DCP checkpoint.

Warning

To avoid OOM, it’s recommended to only run this function on a single rank.

The following classes can also be utilized for online loading and resharding of models from the torch.save format.

class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)

StorageReader for reading a Torch Save file. This reader will read the entire checkpoint on the coordinator rank, and then broadcast and shard each tensor to all ranks.

N.B. Intended to be used with DynamicMetaLoadPlanner.

Warning

Current implementation only supports loading Tensors.

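>>> import torch.distributed.checkpoint as dcp
>>> from torch.distributed.checkpoint.format_utils import BroadcastingTorchSaveReader, DynamicMetaLoadPlanner
>>> sd = {"model": model}
>>> dcp.load(
>>>     sd,
>>>     storage_reader=BroadcastingTorchSaveReader(),
>>>     planner=DynamicMetaLoadPlanner(),
>>>     checkpoint_id="path_to_model.pt",
>>> )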
prepare_global_plan(global_plan)

Implementation of the StorageReader method

Return type:

list[LoadPlan]

prepare_local_plan(plan)

Implementation of the StorageReader method

Return type:

LoadPlan

read_data(plan, planner)

Reads torch save data on the coordinator rank and broadcasts it afterwards. This incurs a communication cost, but avoids having to load the entire checkpoint on each rank, hopefully preventing OOM issues.

Return type:

Future[None]

read_metadata()

Extends the default StorageReader to support building the metadata file

Return type:

Metadata

reset(checkpoint_id=None)

Implementation of the StorageReader method

set_up_storage_reader(metadata, is_coordinator)

Implementation of the StorageReader method

classmethod validate_checkpoint_id(checkpoint_id)

Implementation of the StorageReader method

Return type:

bool

class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)

Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed-in state dict, avoiding the need to read metadata from disk. This is useful when reading formats which don’t have a metadata file, like Torch Save files.

N.B. Intended to be used with BroadcastingTorchSaveReader.

Warning

Current implementation only supports loading Tensors.

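>>> import torch.distributed.checkpoint as dcp
>>> from torch.distributed.checkpoint.format_utils import BroadcastingTorchSaveReader, DynamicMetaLoadPlanner
>>> sd = {"model": model}
>>> dcp.load(
>>>     sd,
>>>     storage_reader=BroadcastingTorchSaveReader(),
>>>     planner=DynamicMetaLoadPlanner(),
>>>     checkpoint_id="path_to_model.pt",
>>> )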
set_up_planner(state_dict, metadata=None, is_coordinator=False)

Set up the planner, extending the default behavior by creating the Metadata object from the state dict.

The following experimental interfaces are provided for improved observability in production environments:
