nemo_rl.utils.native_checkpoint#

Checkpoint management utilities for HF models.

Module Contents#

Classes#

ModelState

Helper class for tracking model state in distributed checkpointing.

OptimizerState

Helper class for tracking optimizer state in distributed checkpointing.

Functions#

save_checkpoint

Save a checkpoint of the model and optionally optimizer state.

load_checkpoint

Load model weights and optionally optimizer state.

convert_dcp_to_hf

Convert a Torch DCP checkpoint to a Hugging Face checkpoint.

API#

class nemo_rl.utils.native_checkpoint.ModelState(model: torch.nn.Module)#

Bases: torch.distributed.checkpoint.stateful.Stateful

Helper class for tracking model state in distributed checkpointing.

This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.

Parameters:

model – The PyTorch model to track.

Initialization

state_dict() → dict[str, Any]#

Get the model’s state dictionary.

Returns:

Dictionary containing the model’s state dict with CPU offloading enabled.

Return type:

dict

load_state_dict(state_dict: dict[str, Any]) → None#

Load the state dictionary into the model.

Parameters:

state_dict (dict) – State dictionary to load.
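
Because ModelState implements the Stateful protocol, it can be placed directly in the state dict passed to dcp.save, and DCP will invoke state_dict() for you. A minimal sketch, assuming a distributed process group is already initialized (e.g. via torchrun); the "model" key, the toy module, and the checkpoint path are illustrative:

```python
import torch
import torch.distributed.checkpoint as dcp

from nemo_rl.utils.native_checkpoint import ModelState

model = torch.nn.Linear(8, 8)  # stands in for a real HF model

# DCP detects the Stateful protocol and calls ModelState.state_dict() itself.
dcp.save({"model": ModelState(model)}, checkpoint_id="ckpt/weights")
```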

class nemo_rl.utils.native_checkpoint.OptimizerState(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[Any] = None,
)#

Bases: torch.distributed.checkpoint.stateful.Stateful

Helper class for tracking optimizer state in distributed checkpointing.

This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.

Parameters:
  • model – The PyTorch model associated with the optimizer.

  • optimizer – The optimizer to track.

  • scheduler – Optional learning rate scheduler.

Initialization

state_dict() → dict[str, Any]#

Get the optimizer and scheduler state dictionaries.

Returns:

Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled.

Return type:

dict

load_state_dict(state_dict: dict[str, Any]) → None#

Load the state dictionaries into the optimizer and scheduler.

Parameters:

state_dict (dict) – State dictionary containing optimizer and scheduler states to load.
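 
Loading follows the same protocol: wrap the live objects, and DCP restores their state in place by calling load_state_dict() on each Stateful entry. A sketch under the same assumptions, with an illustrative "optim" key and checkpoint path:

```python
import torch
import torch.distributed.checkpoint as dcp

from nemo_rl.utils.native_checkpoint import OptimizerState

model = torch.nn.Linear(8, 8)  # stands in for a real HF model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# DCP calls OptimizerState.load_state_dict(), restoring state in place.
state = {"optim": OptimizerState(model, optimizer, scheduler)}
dcp.load(state, checkpoint_id="ckpt/optim")
```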

nemo_rl.utils.native_checkpoint.save_checkpoint(
model: torch.nn.Module,
weights_path: str,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[Any] = None,
optimizer_path: Optional[str] = None,
tokenizer: Optional[Any] = None,
tokenizer_path: Optional[str] = None,
) → None#

Save a checkpoint of the model and optionally optimizer state.

Parameters:
  • model – The PyTorch model to save

  • weights_path – Path to save model weights

  • optimizer – Optional optimizer to save

  • scheduler – Optional scheduler to save

  • optimizer_path – Path to save optimizer state (required if optimizer provided)

  • tokenizer – Optional tokenizer to save

  • tokenizer_path – Path to save tokenizer state (required if tokenizer provided)
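
A usage sketch, assuming a torch.distributed setup (the native format is Torch DCP); the model, hyperparameters, and paths are illustrative. Omit the optimizer/tokenizer arguments to save weights only:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from nemo_rl.utils.native_checkpoint import save_checkpoint

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

# optimizer_path / tokenizer_path are required when the matching object is passed.
save_checkpoint(
    model,
    weights_path="ckpt/weights",
    optimizer=optimizer,
    scheduler=scheduler,
    optimizer_path="ckpt/optim",
    tokenizer=tokenizer,
    tokenizer_path="ckpt/tokenizer",
)
```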

nemo_rl.utils.native_checkpoint.load_checkpoint(
model: torch.nn.Module,
weights_path: str,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[Any] = None,
optimizer_path: Optional[str] = None,
) → None#

Load model weights and optionally optimizer state.

Parameters:
  • model – The PyTorch model whose weights to update

  • weights_path – Path to load model weights from

  • optimizer – Optional optimizer to load state into

  • scheduler – Optional scheduler to load state into

  • optimizer_path – Path to load optimizer state from (required if optimizer provided)
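
To resume, rebuild the model, optimizer, and scheduler first, then restore their state in place; paths must match those used at save time (illustrative here, mirroring the save_checkpoint sketch above):

```python
import torch
from transformers import AutoModelForCausalLM

from nemo_rl.utils.native_checkpoint import load_checkpoint

# Recreate the training objects, then load state into them in place.
model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

load_checkpoint(
    model,
    weights_path="ckpt/weights",
    optimizer=optimizer,
    scheduler=scheduler,
    optimizer_path="ckpt/optim",
)
```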

nemo_rl.utils.native_checkpoint.convert_dcp_to_hf(
dcp_ckpt_path: str,
hf_ckpt_path: str,
model_name_or_path: str,
tokenizer_name_or_path: str,
overwrite: bool = False,
hf_overrides: Optional[dict[str, Any]] = {},
) → str#

Convert a Torch DCP checkpoint to a Hugging Face checkpoint.

This is not an optimized utility. If the checkpoint is too large, consider saving in DCP format during training and using this utility to convert to HF format afterwards.

Parameters:
  • dcp_ckpt_path (str) – Path to DCP checkpoint

  • hf_ckpt_path (str) – Path to save HF checkpoint

  • model_name_or_path (str) – Model name or path for config

  • tokenizer_name_or_path (str, optional) – Tokenizer name or path. Defaults to model_name_or_path if None.

  • overwrite (bool, optional) – Whether to overwrite an existing checkpoint. Defaults to False.

  • hf_overrides (dict, optional) – Optional overrides applied to the Hugging Face config. Defaults to {}.

Returns:

Path to the saved HF checkpoint

Return type:

str

Raises:

FileExistsError – If the HF checkpoint already exists and overwrite is False.
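
A usage sketch with illustrative paths and model identifier:

```python
from nemo_rl.utils.native_checkpoint import convert_dcp_to_hf

# overwrite=True replaces a previous conversion instead of raising FileExistsError.
hf_path = convert_dcp_to_hf(
    dcp_ckpt_path="ckpt/weights",
    hf_ckpt_path="ckpt/hf",
    model_name_or_path="gpt2",
    tokenizer_name_or_path="gpt2",
    overwrite=True,
)
print(f"HF checkpoint written to {hf_path}")
```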