Checkpoint Loading#

The PyTorch backend provides a flexible and extensible infrastructure for loading model checkpoints from different formats, such as HuggingFace (HF). This system allows you to load models from various sources (e.g., HuggingFace or custom formats) by implementing the required components, such as the checkpoint’s weight loader, mapper, and configuration parser.

Table of Contents#

  1. Overview

  2. Core Components

  3. Built-in Checkpoint Formats

  4. Using Checkpoint Loaders

  5. Creating Custom Checkpoint Loaders

Overview#

The checkpoint loading design is built around a plugin-like architecture that is separated into four distinct components:

  • Checkpoint Loaders: Orchestrate the loading process for specific formats

  • Config Loaders: Handle model configuration parsing and validation

  • Weight Loaders: Manage the actual loading of model weights from storage into memory

  • Weight Mappers: Map and transform loaded weights to the TensorRT LLM model's definition

This modular design makes it easy to extend support to new checkpoint formats while maintaining backward compatibility and performance optimizations. Because checkpoint loading is split into four subcomponents, you can reuse any existing components and implement only your own custom, checkpoint-specific ones.

If you wish to support a new checkpoint format, you must implement all four components. However, if the format shares some components with an already supported framework (e.g., HF), only the format-specific components need to be implemented.
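As an illustration of this plugin-like separation, the sketch below shows a minimal format registry in plain Python. All names here (`CHECKPOINT_LOADERS`, `HfStyleLoader`) are hypothetical stand-ins for illustration, not the actual TensorRT LLM API:

```python
from typing import Callable, Type

# Hypothetical registry mapping a checkpoint-format name to its loader class.
CHECKPOINT_LOADERS: dict[str, Type] = {}

def register_checkpoint_loader(fmt: str) -> Callable[[Type], Type]:
    """Decorator registering a loader class under a format name."""
    def decorator(cls: Type) -> Type:
        CHECKPOINT_LOADERS[fmt] = cls
        return cls
    return decorator

@register_checkpoint_loader("HF")
class HfStyleLoader:
    """Stand-in loader that would orchestrate config/weight/mapper components."""
    def load(self, checkpoint_dir: str) -> dict:
        return {"format": "HF", "dir": checkpoint_dir}

# Selecting a loader by format name:
loader = CHECKPOINT_LOADERS["HF"]()
result = loader.load("/models/llama")
```

Registering loaders under a format key is what lets new formats plug in without touching the core loading path.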

Core Components#

BaseCheckpointLoader#

The BaseCheckpointLoader is the central base interface for all checkpoint loading operations. It provides a unified API regardless of the underlying checkpoint format, and is responsible for holding and exposing all objects required for the loading and parsing process.

Key Methods:

  • load_config(checkpoint_dir, **kwargs): Loads and returns a ModelConfig object

  • load_weights(checkpoint_dir, **kwargs): Loads and returns a dictionary of weights

  • get_initialized_weight_mapper(model, config): Returns a runtime-initialized weight mapper for the model

  • cleanup(): Releases resources and cleans up internal state
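The unified API can be pictured with a self-contained stand-in (hypothetical class names, not the real interface, which lives under `tensorrt_llm._torch.models.checkpoints`):

```python
from abc import ABC, abstractmethod
from typing import Any

class CheckpointLoaderSketch(ABC):
    """Hypothetical stand-in mirroring BaseCheckpointLoader's key methods."""

    @abstractmethod
    def load_config(self, checkpoint_dir: str, **kwargs) -> dict:
        """Load and return a parsed model configuration."""

    @abstractmethod
    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
        """Load and return a name -> tensor dictionary."""

    def cleanup(self) -> None:
        """Release cached resources (no-op by default)."""

class DummyLoader(CheckpointLoaderSketch):
    def load_config(self, checkpoint_dir: str, **kwargs) -> dict:
        return {"model_type": "dummy"}

    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
        return {"embed.weight": [0.0, 1.0]}

# Typical call sequence: parse the config, load the weights, then clean up.
loader = DummyLoader()
config = loader.load_config("/models/dummy")
weights = loader.load_weights("/models/dummy")
loader.cleanup()
```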

BaseConfigLoader#

Responsible for loading model configurations from checkpoint directories and parsing them into a TensorRT LLM ModelConfig:

```python
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader

class CustomConfigLoader(BaseConfigLoader):
    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        # Load and parse configuration from your custom format
        pretrained_config = self._get_pretrained_config(checkpoint_dir, **kwargs)
        return ModelConfig(pretrained_config=pretrained_config, ...)

    def _get_pretrained_config(self, checkpoint_dir, **kwargs):
        ...
```

BaseWeightLoader#

Handles the loading of model weights from storage:

```python
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader

class CustomWeightLoader(BaseWeightLoader):
    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
        # Load weights from your custom format
        # Return a dictionary mapping parameter names to tensors
        return weights_dict
```

BaseWeightMapper#

Transforms weights between different naming conventions and applies model-specific transformations to the TRTLLM model object.
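For intuition, weight mapping often amounts to renaming and regrouping tensors, for example fusing HF's separate q/k/v projections into the single fused parameter a TRTLLM attention layer expects. A simplified, framework-free sketch (plain lists stand in for tensors; the real logic lives in BaseWeightMapper subclasses):

```python
# Source checkpoint weights in HF-style naming (lists stand in for tensors).
source_weights = {
    "model.layers.0.self_attn.q_proj.weight": [1.0, 2.0],
    "model.layers.0.self_attn.k_proj.weight": [3.0, 4.0],
    "model.layers.0.self_attn.v_proj.weight": [5.0, 6.0],
}

# Target name -> list of source names to fuse, in order.
mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

def map_layer(weights: dict, layer_prefix: str, mapping: dict) -> dict:
    """Rename and concatenate source weights into their fused target form."""
    mapped = {}
    for target, sources in mapping.items():
        fused = []
        for src in sources:
            fused.extend(weights[f"{layer_prefix}.{src}.weight"])
        mapped[f"{layer_prefix}.{target}.weight"] = fused
    return mapped

mapped = map_layer(source_weights, "model.layers.0.self_attn", mapping)
# mapped["model.layers.0.self_attn.qkv_proj.weight"]
#   == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```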

Built-in Checkpoint Formats#

HuggingFace Format#

Currently, HuggingFace (HF) is the primary built-in checkpoint format, supporting:

  • Weight loading (.safetensors/.bin/.pth) - Loading HF-compatible weights from disk

  • Configuration parsing - Parsing HF-stored configuration information into a TensorRT LLM ModelConfig object

  • Weight mapping - Converting HF weights into a TRTLLM-compatible representation
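As a rough illustration of the weight-loading step, an HF-style loader typically discovers weight shards by file extension before reading them. The sketch below is illustrative only; the real HF weight loader also handles sharded index files, memory mapping, and so on:

```python
import tempfile
from pathlib import Path

WEIGHT_SUFFIXES = (".safetensors", ".bin", ".pth")

def find_weight_files(checkpoint_dir: str) -> list[str]:
    """Return sorted weight-shard filenames in an HF-style checkpoint dir."""
    ckpt = Path(checkpoint_dir)
    return sorted(p.name for p in ckpt.iterdir() if p.suffix in WEIGHT_SUFFIXES)

# Demonstrate on a throwaway directory:
with tempfile.TemporaryDirectory() as d:
    for name in ("model-00001-of-00002.safetensors",
                 "model-00002-of-00002.safetensors",
                 "config.json"):
        (Path(d) / name).touch()
    print(find_weight_files(d))
    # ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
```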

Using Checkpoint Loaders#

Basic Usage#

There are two main approaches to using the checkpoint loading objects.

The first approach is through the LLM API, as shown in the following example:

```python
from tensorrt_llm import LLM

hf_model_dir = "llama-models-v2/llama-v2-13b-hf"
llm = LLM(model=hf_model_dir)
```

In this example, HfCheckpointLoader will be selected by default.

To explicitly set the checkpoint loader, pass the required checkpoint-specific loader:

```python
from tensorrt_llm import LLM
from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader

hf_model_dir = "llama-models-v2/llama-v2-13b-hf"
llm = LLM(model=hf_model_dir, checkpoint_loader=HfCheckpointLoader())
```

Similarly, to use a built-in checkpoint loader with a specific custom subcomponent, provide that subcomponent as needed:

```python
from tensorrt_llm import LLM
from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader

hf_model_dir = "llama-models-v2/llama-v2-13b-hf"
llm = LLM(model=hf_model_dir,
          checkpoint_loader=HfCheckpointLoader(weight_loader=MyCustomWeightLoader()))
```

In the second approach, the checkpoint loading components can be used directly:

```python
from tensorrt_llm._torch.models.checkpoints.hf.gemma3_weight_mapper import \
    Gemma3HfWeightMapper
from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3ForCausalLM

gemma3 = Gemma3ForCausalLM(model_config)
weight_mapper = Gemma3HfWeightMapper()
weight_mapper.init_model_and_config(gemma3, model_config)
gemma3.load_weights(hf_gemma3.state_dict(), weight_mapper)
```

Creating Custom Checkpoint Loaders#

To support a new checkpoint format, you need to implement all four components. This section provides minimal templates for each component.

When to Create Custom Components#

  • Complete New Format: Implement all four components when supporting a completely new checkpoint format

  • Custom Weight Storage: Only implement a custom weight loader if you have a unique weight storage format (e.g., custom binary format, database storage, etc.)

  • Custom Configuration: Only implement a custom config loader if your configuration format cannot be parsed by existing parsers.

  • Custom Weight Mapping: Only implement a custom weight mapper if your model has unique weight naming or transformation requirements that are checkpoint-specific.

Step 1: Create the Checkpoint Loader#

```python
from typing import Optional

from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader


@register_checkpoint_loader("CUSTOM_FORMAT")
class CustomCheckpointLoader(BaseCheckpointLoader):

    def __init__(self,
                 *,
                 weight_loader: Optional[BaseWeightLoader] = None,
                 weight_mapper: Optional[BaseWeightMapper] = None,
                 config_loader: Optional[BaseConfigLoader] = None):
        self._weight_loader = weight_loader or self.get_default_weight_loader()
        self._config_loader = config_loader or self.get_default_config_loader()
        self._weight_mapper = weight_mapper
        self._checkpoint_format = "CUSTOM_FORMAT"

    def get_default_weight_loader(self) -> BaseWeightLoader:
        return CustomWeightLoader()

    def get_default_config_loader(self) -> BaseConfigLoader:
        return CustomConfigLoader()
```

Step 2: Create the Checkpoint Weight Loader#

```python
from typing import Any

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_weight_loader


@register_checkpoint_weight_loader("CUSTOM_FORMAT")
class CustomWeightLoader(BaseWeightLoader):

    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
        """
        Load weights from your custom format.

        Args:
            checkpoint_dir: Directory containing checkpoint files
            **kwargs: Additional loading parameters

        Returns:
            Dictionary mapping parameter names to tensors
        """
        weights = {}
        # Implement your custom weight loading logic here
        # Examples:
        # - Load from custom binary files
        # - Load from databases
        # - Load from compressed archives
        # - Apply custom preprocessing
        return weights
```
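To make the template concrete, here is a toy, framework-free loader that reads weights stored as JSON lists on disk. This is purely illustrative (the filenames and helper are hypothetical; a real loader would return torch tensors and handle shard indexes):

```python
import json
import tempfile
from pathlib import Path

def load_json_weights(checkpoint_dir: str) -> dict:
    """Toy loader: merge every *.json weight shard in the directory."""
    weights = {}
    for shard in sorted(Path(checkpoint_dir).glob("*.json")):
        with open(shard) as f:
            weights.update(json.load(f))
    return weights

# Demonstrate on a throwaway directory:
with tempfile.TemporaryDirectory() as d:
    Path(d, "shard0.json").write_text(json.dumps({"wte.weight": [0.1, 0.2]}))
    Path(d, "shard1.json").write_text(json.dumps({"lm_head.weight": [0.3]}))
    print(sorted(load_json_weights(d)))
    # ['lm_head.weight', 'wte.weight']
```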

Step 3: Create the Checkpoint Config Loader#

```python
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.modeling_utils import register_config_loader


@register_config_loader("CUSTOM_FORMAT")
class CustomConfigLoader(BaseConfigLoader):

    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        """
        Load and parse configuration from your custom format.

        Args:
            checkpoint_dir: Directory containing configuration files
            **kwargs: Additional loading parameters

        Returns:
            ModelConfig object containing parsed configuration
        """
        # Load your custom configuration format
        # Examples:
        # - Parse YAML/TOML files
        # - Convert from proprietary formats
        pretrained_config = self._load_pretrained_config(checkpoint_dir, **kwargs)

        return ModelConfig(
            pretrained_config=pretrained_config,
            # Add other ModelConfig parameters as needed
        )

    def _load_pretrained_config(self, checkpoint_dir: str, **kwargs):
        """Load the raw configuration from your custom format."""
        pass
```
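A comparable toy for the config step, parsing a config.json into a small config object. ToyModelConfig and load_toy_config are hypothetical stand-ins for ModelConfig and the real parsing logic:

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path

@dataclass
class ToyModelConfig:
    """Hypothetical stand-in for ModelConfig."""
    hidden_size: int
    num_layers: int

def load_toy_config(checkpoint_dir: str) -> ToyModelConfig:
    """Parse an HF-style config.json into the toy config object."""
    raw = json.loads((Path(checkpoint_dir) / "config.json").read_text())
    return ToyModelConfig(hidden_size=raw["hidden_size"],
                          num_layers=raw["num_hidden_layers"])

# Demonstrate on a throwaway directory:
with tempfile.TemporaryDirectory() as d:
    Path(d, "config.json").write_text(
        json.dumps({"hidden_size": 4096, "num_hidden_layers": 32}))
    cfg = load_toy_config(d)
    print(cfg.hidden_size, cfg.num_layers)  # 4096 32
```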

Step 4: Create the Checkpoint Weight Mapper#

```python
from torch import nn

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper


@register_mapper("CUSTOM_FORMAT")
class CustomWeightMapper(BaseWeightMapper):

    def __init__(self):
        super().__init__()
        # Define any weight transformation callbacks
        self._callbacks = [
            # Add your custom weight transformation functions
            # self._custom_transform_function,
        ]

    def map_weights(self) -> None:
        """
        Define mappings between source and target weight names.
        """
        self.mapping.update({
            # Map source names to target names
            # 'target_module_name': ['source_param1', 'source_param2'],
            # Example: 'qkv_proj': ['q_proj', 'k_proj', 'v_proj']
        })

    def apply_callbacks(self, module: nn.Module, module_name: str,
                        module_names_breakdown: list[str],
                        weights: dict) -> list[dict]:
        """
        Apply weight transformations for modules that require special handling.

        Args:
            module: The target module
            module_name: The specific module name being processed
            module_names_breakdown: Module path components
            weights: Source weights dictionary

        Returns:
            List of transformed weight dictionaries
        """
        module_weights = []

        for new_name in self._mapping[module_name]:
            # Filter weights for this specific parameter
            fw = self.filter_weights(
                '.'.join(module_names_breakdown + [new_name]), weights)

            # Apply transformation callbacks
            for callback in self._callbacks:
                fw = callback(module, new_name, fw)

            module_weights.append(fw)

        return module_weights

    def should_skip_module(self, module_name: str) -> bool:
        """
        Define which modules should be skipped during loading.
        """
        # Add logic to skip specific modules based on your requirements
        # Examples:
        # - Skip LoRA-specific modules
        # - Skip temporary/auxiliary modules
        return super().should_skip_module(module_name)
```

Note: when creating a custom mapper, you can define a checkpoint-format-specific mapper. For example:

```python
@register_mapper("CUSTOM_FORMAT")
class CustomWeightMapper(BaseWeightMapper):
    ...
```

Alternatively, you can define a checkpoint-model-specific mapper. For example:

```python
@register_mapper("CUSTOM_FORMAT", "Gemma3ForCausalLM")
class CustomWeightMapper(BaseWeightMapper):
    ...
```

By setting the model name, the registered mapper will be associated with that specific model.