# Adding a New Model
## Introduction

This guide provides a step-by-step process for adding a new model to the PyTorch backend.
## Prerequisites

Before you begin, ensure you have the following:

- A working installation of TensorRT-LLM. Follow the TensorRT-LLM installation instructions.
## Step-by-Step Guide

### Model Configuration

Suppose you want to support a new model named `MyModel`. If the model is already supported in HuggingFace’s transformers, you should bring over its PyTorch modeling code and reuse HuggingFace’s configuration class. For example, our `tensorrt_llm/_torch/models/modeling_llama.py` was adapted from HuggingFace’s `modeling_llama.py`; in the modeling code, we reuse the configuration class:
```python
from transformers import LlamaConfig
```
If the model is not registered in HuggingFace’s transformers, you need to define the configuration class in your `configuration_mymodel.py`, following HuggingFace’s `configuration_llama.py`:
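```python
from transformers.configuration_utils import PretrainedConfig


class MyConfig(PretrainedConfig):

    def __init__(self, **kwargs):
        # Store your model's hyperparameters as attributes here
        ...
        super().__init__(**kwargs)
```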
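As a concrete illustration, a filled-in configuration class might look like the sketch below. The `model_type` string `"mymodel"` and the hyperparameter names and defaults are assumptions chosen to mirror common HuggingFace decoder configs; use whatever your checkpoint's `config.json` actually contains.

```python
from transformers import PretrainedConfig


class MyConfig(PretrainedConfig):
    # Hypothetical identifier; HuggingFace uses model_type to match config.json files
    model_type = "mymodel"

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 4096,
        intermediate_size: int = 11008,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        rms_norm_eps: float = 1e-6,
        **kwargs,
    ):
        # Model-specific hyperparameters (illustrative names and defaults)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.rms_norm_eps = rms_norm_eps
        # The base class handles generic fields such as `architectures`
        super().__init__(**kwargs)


config = MyConfig(num_hidden_layers=2, architectures=["MyModelForCausalLM"])
print(config.num_hidden_layers, config.hidden_size)  # 2 4096
```

Fields not passed explicitly fall back to the defaults, and `to_dict()` serializes everything so the config can round-trip through `config.json`.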
### Model Definition

Remove any unnecessary code (e.g., training-specific code), and then rewrite some PyTorch modules. For a typical Transformer decoder model, you need to implement your `modeling_mymodel.py` like this:
```python
from typing import Optional

import torch
from torch import nn
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import DecoderModel, DecoderModelForCausalLM
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer

from configuration_mymodel import MyConfig


class MyAttention(Attention):

    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: Optional[int] = None):
        # Use model_config to initialize the Attention module
        super().__init__(...)


class MyDecoderLayer(DecoderLayer):

    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: int):
        super().__init__()
        # Use model_config to initialize the submodules
        self.input_layernorm = ...
        self.self_attn = MyAttention(model_config, layer_idx)
        self.post_attention_layernorm = ...
        self.mlp = ...

    def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, **kwargs):
        # Define the forward computation of a single decoder layer
        ...


class MyModel(DecoderModel):

    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(model_config)
        # Use model_config to initialize the submodules
        self.embed_tokens = ...
        self.layers = nn.ModuleList([
            MyDecoderLayer(model_config, layer_idx)
            for layer_idx in range(model_config.pretrained_config.num_hidden_layers)
        ])

    def forward(self,
                attn_metadata: AttentionMetadata,
                input_ids: Optional[torch.IntTensor] = None,
                position_ids: Optional[torch.IntTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None):
        # Define the forward computation of the model
        ...


class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(MyModel(model_config),
                         config=model_config,
                         hidden_size=model_config.pretrained_config.hidden_size,
                         vocab_size=model_config.pretrained_config.vocab_size)
```
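The `forward` bodies above are left elided. For a LLaMA-style pre-norm decoder, the data flow of a single layer typically looks like the plain-PyTorch sketch below. It is not TensorRT-LLM code: `ToyDecoderLayer` is hypothetical, `nn.LayerNorm` stands in for an RMS norm, and a single `nn.Linear` stands in for `MyAttention` purely to show the residual wiring. A real layer would pass `attn_metadata` through to `self.self_attn`. Note that activations are 2-D, `[num_tokens, hidden_size]`.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyDecoderLayer(nn.Module):
    """Hypothetical LLaMA-style pre-norm decoder layer operating on packed tokens."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)  # stand-in for RMSNorm
        # Stand-in for MyAttention; real attention would also consume attn_metadata
        self.self_attn = nn.Linear(hidden_size, hidden_size, bias=False)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        # SwiGLU-style MLP: gate and up projections, then a down projection
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size] (packed mode, no batch dimension)
        residual = hidden_states
        hidden_states = residual + self.self_attn(self.input_layernorm(hidden_states))

        residual = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        mlp_out = self.down_proj(F.silu(self.gate_proj(normed)) * self.up_proj(normed))
        return residual + mlp_out


layer = ToyDecoderLayer(hidden_size=16, intermediate_size=32)
x = torch.randn(5, 16)  # 5 packed tokens from any number of sequences
out = layer(x)
print(out.shape)  # torch.Size([5, 16])
```

The two residual additions are what let each sublayer learn a correction on top of its input; the layer never needs to know how the 5 tokens are split across requests.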
Note that `MyAttention` inherits from our `Attention` module (in `tensorrt_llm/_torch/modules/attention.py`), so that the attention computation is compatible with our PyTorch runtime. Relatedly, the module inputs must also be adapted:

- `attn_metadata` stores the metadata of the batched input and the KV cache for the attention backend. It is created and passed in by the runtime, and model developers need to ensure that `attn_metadata` is correctly passed down to the attention module.
- The input tensors (i.e., `input_ids`, `position_ids`, `hidden_states`) are in packed mode: there is no padding, and the first dimension corresponds to the total number of tokens in the batch.
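To make the packed layout concrete, the sketch below packs two hypothetical prompts of lengths 3 and 2 into one 1-D `input_ids` tensor of 5 tokens. It assumes every sequence is a fresh prompt, so position IDs restart at 0 per sequence. The helper `pack_sequences` is illustrative only; in practice the runtime builds these tensors (and the corresponding `attn_metadata`) for you.

```python
from typing import List, Tuple

import torch


def pack_sequences(seqs: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical helper: concatenate token-ID lists into packed tensors.

    Returns (input_ids, position_ids, seq_lens); input_ids and position_ids
    have shape [num_tokens], seq_lens has shape [batch_size].
    """
    input_ids = torch.tensor([tok for seq in seqs for tok in seq], dtype=torch.int32)
    # Assumes each sequence starts at position 0 (all requests are new prompts)
    position_ids = torch.cat([torch.arange(len(seq), dtype=torch.int32) for seq in seqs])
    seq_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32)
    return input_ids, position_ids, seq_lens


input_ids, position_ids, seq_lens = pack_sequences([[101, 7, 42], [101, 9]])
print(input_ids.tolist())     # [101, 7, 42, 101, 9]
print(position_ids.tolist())  # [0, 1, 2, 0, 1]
print(seq_lens.tolist())      # [3, 2]
```

After embedding, `hidden_states` would therefore have shape `[5, hidden_size]` rather than a padded `[2, 3, hidden_size]`; sequence boundaries live only in the metadata.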
Additionally, `MyDecoderLayer`, `MyModel`, and `MyModelForCausalLM` are subclasses of `DecoderLayer`, `DecoderModel`, and `DecoderModelForCausalLM`, respectively. The base classes define interfaces and provide generic scaffolding for defining model layers, loading weights, etc.
Optionally, you may replace the native PyTorch modules with our implementations to enable features or achieve higher performance:
- `Linear` (in `tensorrt_llm/_torch/modules/linear.py`): Enables tensor parallelism and quantization.
- `Embedding` (in `tensorrt_llm/_torch/modules/embedding.py`): Enables tensor parallelism for embedding.
- `RotaryEmbedding` (in `tensorrt_llm/_torch/modules/rotary_embedding.py`): Enables performant rotary embedding.
- `RMSNorm` (in `tensorrt_llm/_torch/modules/rms_norm.py`): Enables performant RMS norm.
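For reference, the computation that an optimized RMS norm replaces is short. The sketch below is an unfused RMS norm in the style of HuggingFace’s LLaMA implementation (normalize in float32, then cast back). It can serve as a numerical baseline when swapping in the optimized module, but it is not TensorRT-LLM’s implementation, and the class name `ReferenceRMSNorm` is hypothetical.

```python
import torch
from torch import nn


class ReferenceRMSNorm(nn.Module):
    """Hypothetical unfused RMS norm: y = x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.to(torch.float32)  # compute statistics in fp32 for stability
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)


norm = ReferenceRMSNorm(4)
x = torch.tensor([[3.0, 4.0, 0.0, 0.0]])
y = norm(x)  # mean(x^2) = 6.25, rms = 2.5, so roughly [1.2, 1.6, 0.0, 0.0]
```

Unlike LayerNorm, there is no mean subtraction and no bias, which is part of why a fused kernel for it is cheap.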
For a concrete reference, check out `tensorrt_llm/_torch/models/modeling_llama.py`.
### Weight Loading

The base class `DecoderModelForCausalLM` provides a `load_weights` method that loads the weights from the checkpoint file and assigns them to the corresponding layers in the model. However, if the default method does not work for `MyModelForCausalLM`, you need to implement your own `load_weights`:
```python
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def load_weights(self, weights: dict):
        # Define the weight loading logic
        ...
```
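For example, HuggingFace’s LLaMA model uses three linear layers for the Q/K/V projections, resulting in three weight tensors in the checkpoint (the shapes below assume standard multi-head attention; with grouped-query attention, the K and V projections have fewer output rows):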
```python
>>> weights
{
    ...,
    "model.layers.0.self_attn.q_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    "model.layers.0.self_attn.k_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    "model.layers.0.self_attn.v_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    ...,
}
```
However, our LLaMA model fuses the three layers into one linear layer:
```python
>>> llama.model.layers[0].self_attn.qkv_proj.weight.data
torch.Tensor([hidden_size * 3, hidden_size])
```
Hence, `load_weights` needs to collect the three weight tensors from the original checkpoint, concatenate them, and assign the result to the fused linear layer. With tensor parallelism and quantization, the process becomes more complicated. We recommend calling the predefined module-level `load_weights` (e.g., of `Linear` and `Embedding`) when implementing your model-level `load_weights` method.
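The concatenation step, and one reason tensor parallelism complicates it, can be sketched in plain PyTorch. The helpers `fuse_qkv` and `shard_fused_qkv` are hypothetical; they assume column-parallel sharding along the output dimension and sizes divisible by `tp_size`. The key point is that each rank needs its own slice of Q, K, and V, so slicing the already-fused tensor would be wrong. This only illustrates the layout; the module-level `load_weights` recommended above is the intended route.

```python
from typing import Dict

import torch


def fuse_qkv(weights: Dict[str, torch.Tensor], prefix: str) -> torch.Tensor:
    """Hypothetical helper: stack q/k/v projection weights along the output dim."""
    return torch.cat([weights[f"{prefix}.{n}_proj.weight"] for n in ("q", "k", "v")], dim=0)


def shard_fused_qkv(weights: Dict[str, torch.Tensor], prefix: str,
                    tp_size: int, tp_rank: int) -> torch.Tensor:
    """Hypothetical helper: shard q, k, v separately, then concatenate this rank's pieces.

    Slicing the fused tensor instead would hand rank 0 only Q rows.
    Assumes each projection's output dim (and head count) divides evenly by tp_size.
    """
    shards = []
    for n in ("q", "k", "v"):
        w = weights[f"{prefix}.{n}_proj.weight"]
        shards.append(torch.chunk(w, tp_size, dim=0)[tp_rank])
    return torch.cat(shards, dim=0)


hidden_size = 8
prefix = "model.layers.0.self_attn"
weights = {f"{prefix}.{n}_proj.weight": torch.randn(hidden_size, hidden_size)
           for n in ("q", "k", "v")}

fused = fuse_qkv(weights, prefix)
print(fused.shape)  # torch.Size([24, 8])

rank0 = shard_fused_qkv(weights, prefix, tp_size=2, tp_rank=0)
print(rank0.shape)  # torch.Size([12, 8])
```

With `tp_size=2`, rank 0 holds the first half of the Q, K, and V rows, laid out as `[q_shard; k_shard; v_shard]`, which matches what a fused column-parallel layer on that rank expects.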
Overall, `load_weights` should handle any discrepancy between `MyModelForCausalLM` and the weights loaded from the checkpoint, so that `MyModelForCausalLM` performs forward computation equivalent to that of the original model.
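One cheap way to check that weight loading preserved the original computation is to compare a fused layer against the unfused ones on random inputs. The sketch below does this for the Q/K/V case using plain `nn.Linear`, not TensorRT-LLM’s `Linear`, so it demonstrates the idea of the check rather than the library’s API.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size = 16

# Unfused projections, as stored in the HuggingFace checkpoint
q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

# Fused projection, "loaded" by concatenating q/k/v weights along the output dim
qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

# Packed-mode input: [num_tokens, hidden_size]
x = torch.randn(5, hidden_size)
q, k, v = qkv_proj(x).split(hidden_size, dim=-1)

torch.testing.assert_close(q, q_proj(x))
torch.testing.assert_close(k, k_proj(x))
torch.testing.assert_close(v, v_proj(x))
print("fused QKV matches unfused projections")
```

The same pattern scales up: run the original HuggingFace model and your ported model on identical inputs and compare logits within a tolerance suited to the dtype.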
### Model Registration

The new model needs to be registered so that it can be recognized by the PyTorch runtime. Registration is done simply by adding the `register_auto_model` decorator to `MyModelForCausalLM`:
```python
from tensorrt_llm._torch.models.modeling_utils import register_auto_model


@register_auto_model("MyModelForCausalLM")
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def __init__(self, model_config: ModelConfig[MyConfig]):
        ...
```
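Conceptually, a registration decorator records the class under an architecture name so the runtime can look it up later, for example from the `architectures` field of a checkpoint’s `config.json`. The sketch below is a minimal, hypothetical registry (`MODEL_REGISTRY`, `register_model`, and `resolve_model_class` are invented names) meant only to show the pattern; it is not how `register_auto_model` is actually implemented.

```python
from typing import Callable, Dict, List

# Hypothetical registry mapping architecture names to model classes
MODEL_REGISTRY: Dict[str, type] = {}


def register_model(name: str) -> Callable[[type], type]:
    """Class decorator that records `cls` under `name` and returns it unchanged."""
    def decorator(cls: type) -> type:
        if name in MODEL_REGISTRY:
            raise ValueError(f"architecture {name!r} is already registered")
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def resolve_model_class(architectures: List[str]) -> type:
    """Return the first registered class matching a config.json `architectures` entry."""
    for arch in architectures:
        if arch in MODEL_REGISTRY:
            return MODEL_REGISTRY[arch]
    raise ValueError(f"no registered model for architectures {architectures}")


@register_model("MyModelForCausalLM")
class MyModelForCausalLM:  # placeholder standing in for the real model class
    pass


print(resolve_model_class(["MyModelForCausalLM"]).__name__)  # MyModelForCausalLM
```

Because the decorator runs when the class definition executes, registration happens at import time: the module defining the model must be imported before the architecture name is resolved.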
#### Core Models

To add the new model to the core models, place `modeling_mymodel.py` (and potentially `configuration_mymodel.py`) in `tensorrt_llm/_torch/models`. Then, import the modeling code in `tensorrt_llm/_torch/models/__init__.py`:
```python
from .modeling_mymodel import MyModelForCausalLM

__all__ = [
    ...,
    "MyModelForCausalLM",
]
```
#### Out-of-Tree Models

Alternatively, you can register the new model as an out-of-tree model, so that you can use it without touching the TensorRT-LLM codebase. To do so, place `modeling_mymodel.py` (and potentially `configuration_mymodel.py`) in your working directory, and import the modeling code in your script:
```python
from tensorrt_llm import LLM

import modeling_mymodel  # noqa: F401  (imported for its side effect: registering the model)


def main():
    llm = LLM(...)


if __name__ == '__main__':
    main()
```
We provide an out-of-tree modeling example in `examples/pytorch/out_of_tree_example`. The model is implemented in `modeling_opt.py`, and you can run the example with:
```bash
python examples/pytorch/out_of_tree_example/main.py
```