nemo_rl.utils.native_checkpoint
Checkpoint management utilities for HF models.
Module Contents
Classes
- ModelState – Helper class for tracking model state in distributed checkpointing.
- OptimizerState – Helper class for tracking optimizer state in distributed checkpointing.
Functions
- save_checkpoint – Save a checkpoint of the model and optionally optimizer state.
- load_checkpoint – Load model weights and optionally optimizer state.
- convert_dcp_to_hf – Convert a Torch DCP checkpoint to a Hugging Face checkpoint.
API
- class nemo_rl.utils.native_checkpoint.ModelState(model: torch.nn.Module)
Bases:
torch.distributed.checkpoint.stateful.Stateful
Helper class for tracking model state in distributed checkpointing.
This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
- Parameters:
model – The PyTorch model to track.
Initialization
- state_dict() → dict[str, Any]
Get the model’s state dictionary.
- Returns:
Dictionary containing the model’s state dict with CPU offloading enabled.
- Return type:
dict
- load_state_dict(state_dict: dict[str, Any]) → None
Load the state dictionary into the model.
- Parameters:
state_dict (dict) – State dictionary to load.
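The Stateful protocol that ModelState implements boils down to these two methods: DCP calls state_dict() when saving and load_state_dict() when restoring. A minimal pure-Python sketch of the same pattern (the CounterState class and its step field are hypothetical, for illustration only):

```python
from typing import Any


class CounterState:
    """Minimal sketch of the Stateful protocol that DCP relies on.

    Any object exposing state_dict()/load_state_dict() qualifies:
    dcp.save calls state_dict() and dcp.load calls load_state_dict().
    """

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> dict[str, Any]:
        # Return everything needed to reconstruct this object's state.
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Restore state from a previously saved dictionary.
        self.step = state_dict["step"]


# Round-trip: transfer state from one instance into a fresh one.
src = CounterState()
src.step = 42
dst = CounterState()
dst.load_state_dict(src.state_dict())
```

ModelState follows the same contract, with the addition noted above that its state_dict() offloads tensors to CPU.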
- class nemo_rl.utils.native_checkpoint.OptimizerState(model: torch.nn.Module, optimizer: torch.optim.Optimizer, scheduler: Optional[Any] = None)
Bases:
torch.distributed.checkpoint.stateful.Stateful
Helper class for tracking optimizer state in distributed checkpointing.
This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
- Parameters:
model – The PyTorch model associated with the optimizer.
optimizer – The optimizer to track.
scheduler – Optional learning rate scheduler.
Initialization
- state_dict() → dict[str, Any]
Get the optimizer and scheduler state dictionaries.
- Returns:
Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled.
- Return type:
dict
- load_state_dict(state_dict: dict[str, Any]) → None
Load the state dictionaries into the optimizer and scheduler.
- Parameters:
state_dict (dict) – State dictionary containing optimizer and scheduler states to load.
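Because state_dict() here must carry both optimizer and scheduler state, a common pattern is to nest each component's state under its own key so load_state_dict can route the pieces back. A sketch of that pattern (the class names and the "optim"/"sched" keys are illustrative, not necessarily what OptimizerState uses internally):

```python
from typing import Any, Optional


class DummyComponent:
    """Stand-in for an optimizer or scheduler (hypothetical)."""

    def __init__(self) -> None:
        self.lr = 0.0

    def state_dict(self) -> dict[str, Any]:
        return {"lr": self.lr}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.lr = state_dict["lr"]


class PairedState:
    """Sketch: one Stateful object fronting two stateful components."""

    def __init__(self, optimizer: Any, scheduler: Optional[Any] = None) -> None:
        self.optimizer = optimizer
        self.scheduler = scheduler

    def state_dict(self) -> dict[str, Any]:
        # Nest each component's state under its own key.
        state: dict[str, Any] = {"optim": self.optimizer.state_dict()}
        if self.scheduler is not None:
            state["sched"] = self.scheduler.state_dict()
        return state

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Route each nested dict back to its owner.
        self.optimizer.load_state_dict(state_dict["optim"])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state_dict["sched"])
```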
- nemo_rl.utils.native_checkpoint.save_checkpoint(model: torch.nn.Module, weights_path: str, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, optimizer_path: Optional[str] = None, tokenizer: Optional[Any] = None, tokenizer_path: Optional[str] = None)
Save a checkpoint of the model and optionally optimizer state.
- Parameters:
model – The PyTorch model to save
weights_path – Path to save model weights
optimizer – Optional optimizer to save
scheduler – Optional scheduler to save
optimizer_path – Path to save optimizer state (required if optimizer is provided)
tokenizer – Optional tokenizer to save
tokenizer_path – Path to save tokenizer state (required if tokenizer is provided)
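Note that the optional arguments come in pairs: per the parameter descriptions, optimizer_path is required whenever an optimizer is passed, and tokenizer_path whenever a tokenizer is passed. A standalone sketch of that precondition (the helper name is hypothetical, not part of the module's API):

```python
from typing import Any, Optional


def validate_checkpoint_args(
    optimizer: Optional[Any] = None,
    optimizer_path: Optional[str] = None,
    tokenizer: Optional[Any] = None,
    tokenizer_path: Optional[str] = None,
) -> None:
    """Fail fast when an object is passed without its target path,
    mirroring save_checkpoint's documented argument pairing."""
    if optimizer is not None and optimizer_path is None:
        raise ValueError("optimizer_path is required when optimizer is provided")
    if tokenizer is not None and tokenizer_path is None:
        raise ValueError("tokenizer_path is required when tokenizer is provided")
```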
- nemo_rl.utils.native_checkpoint.load_checkpoint(model: torch.nn.Module, weights_path: str, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, optimizer_path: Optional[str] = None)
Load model weights and optionally optimizer state.
- Parameters:
model – The PyTorch model whose weights to update
weights_path – Path to load model weights from
optimizer – Optional optimizer to load state into
scheduler – Optional scheduler to load state into
optimizer_path – Path to load optimizer state from (required if optimizer is provided)
- nemo_rl.utils.native_checkpoint.convert_dcp_to_hf(dcp_ckpt_path: str, hf_ckpt_path: str, model_name_or_path: str, tokenizer_name_or_path: str, overwrite: bool = False, hf_overrides: Optional[dict[str, Any]] = {})
Convert a Torch DCP checkpoint to a Hugging Face checkpoint.
This is not an optimized utility. If the checkpoint is too large, consider saving the DCP checkpoint during training and using this utility afterwards to convert it to HF format.
- Parameters:
dcp_ckpt_path (str) – Path to DCP checkpoint
hf_ckpt_path (str) – Path to save HF checkpoint
model_name_or_path (str) – Model name or path for config
tokenizer_name_or_path (str, optional) – Tokenizer name or path. Defaults to model_name_or_path if None.
overwrite (bool, optional) – Whether to overwrite an existing checkpoint. Defaults to False.
hf_overrides (dict[str, Any], optional) – Optional overrides for the Hugging Face model configuration. Defaults to {}.
- Returns:
Path to the saved HF checkpoint
- Return type:
str
- Raises:
FileExistsError – If HF checkpoint already exists and overwrite is False
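The overwrite flag follows the usual fail-fast guard pattern: raise FileExistsError if the destination already exists and overwrite is False. A self-contained sketch of that behavior, demonstrated with a throwaway directory (the prepare_output_dir helper is hypothetical, not part of the module's API):

```python
import tempfile
from pathlib import Path


def prepare_output_dir(path: str, overwrite: bool = False) -> Path:
    """Guard sketch mirroring convert_dcp_to_hf's documented behavior:
    refuse to reuse an existing checkpoint dir unless overwrite=True."""
    out = Path(path)
    if out.exists() and not overwrite:
        raise FileExistsError(f"{out} already exists; pass overwrite=True")
    out.mkdir(parents=True, exist_ok=True)
    return out


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "hf_ckpt"
    prepare_output_dir(str(target))            # first call creates the dir
    try:
        prepare_output_dir(str(target))        # second call without overwrite fails
    except FileExistsError:
        pass
    prepare_output_dir(str(target), overwrite=True)  # explicitly allowed
```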