Support fp32 head for qwen and internlm models #4160
RunningLeon wants to merge 8 commits into InternLM:main
Conversation
Pull request overview
This PR adds support for FP32-precision language model heads for the Qwen and InternLM model families. The feature allows the embedding and lm_head layers to compute in FP32 for improved numerical stability while the rest of the model runs in lower precision (e.g., FP16/BF16). This is enabled through a new `enforce_fp32_head` configuration option passed via `hf_overrides`.
Changes:
- Added `DeployModelMixinV1` base class with FP32 head support (a rough sketch of the idea follows this list)
- Enhanced `ParallelEmbedding` with a `force_dtype` parameter for FP32 weight storage
- Refactored 12+ model classes to use the new mixin and build methods
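To make the first bullet concrete, here is a minimal sketch of the pattern the mixin enables. The method names (`build_lm_head`, `get_lm_head`, `get_logits`) and the dtype-cast logic come from this PR and the review snippets below; the `nn.Linear` head and the `self.config.enforce_fp32_head` lookup are simplifying assumptions, since the real code in `lmdeploy/pytorch/models/utils/model.py` likely uses lmdeploy's parallel linear layers and its own config plumbing.

```python
import torch
from torch import nn


class DeployModelMixinV1:
    """Illustrative sketch only; the real mixin may differ in detail."""

    def build_lm_head(self,
                      hidden_size: int,
                      vocab_size: int,
                      bias: bool = False,
                      dtype: torch.dtype = None,
                      device: torch.device = None) -> nn.Module:
        # Assumption: the enforce_fp32_head flag from hf_overrides ends up on the
        # model config (set by the host model, e.g. via self.config); when it is
        # set, the head weights are created in fp32 regardless of the model dtype.
        if getattr(self.config, 'enforce_fp32_head', False):
            dtype = torch.float32
        return nn.Linear(hidden_size, vocab_size, bias=bias, dtype=dtype, device=device)

    def get_lm_head(self) -> nn.Module:
        # Default accessor; models that store the head under another name
        # (e.g. InternLM2's `self.output`) are expected to override this.
        return self.lm_head

    def get_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Cast hidden states to the head's dtype before projecting, so logits
        # are computed in fp32 whenever the fp32 head is enabled.
        lm_head = self.get_lm_head()
        head_dtype = lm_head.weight.dtype
        if hidden_states.dtype != head_dtype:
            hidden_states = hidden_states.to(dtype=head_dtype)
        return lm_head(hidden_states)
```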
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 13 comments.
Summary per file:
| File | Description |
|---|---|
| `lmdeploy/pytorch/nn/embedding.py` | Added `force_dtype` parameter to `ParallelEmbedding` for fp32 embeddings with dtype conversion on output (see the sketch after this table) |
| `lmdeploy/pytorch/models/utils/model.py` | Introduced `DeployModelMixinV1` with `build_lm_head()` and `get_logits()` methods supporting the fp32 head |
| `lmdeploy/pytorch/config.py` | Added config handling to extract and propagate `enforce_fp32_head` from `hf_overrides` |
| `lmdeploy/pytorch/models/qwen*.py` | Refactored Qwen models to use `DeployModelMixinV1` and `ParallelEmbedding` with fp32 support |
| `lmdeploy/pytorch/models/internlm*.py` | Refactored InternLM models to use `DeployModelMixinV1` and `ParallelEmbedding` with fp32 support |
| `lmdeploy/pytorch/models/internvl*.py` | Updated InternVL models to use `DeployModelMixinV1` and delegate `get_lm_head` to `language_model` |
| `lmdeploy/pytorch/models/phi3*.py` | Updated Phi3 models to use `DeployModelMixinV1` |
| `lmdeploy/pytorch/models/qwen*_vl.py` | Updated Qwen VL models to use `DeployModelMixinV1` and `ParallelEmbedding` |
| `lmdeploy/pytorch/models/gpt_oss.py` | Updated GPT OSS model to use `DeployModelMixinV1` and `ParallelEmbedding` |
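As referenced in the `ParallelEmbedding` row above, here is a hedged single-GPU analogue of what a `force_dtype` parameter with dtype conversion on the output can look like. `FP32Embedding` is a hypothetical name used for illustration; the real `ParallelEmbedding` in `lmdeploy/pytorch/nn/embedding.py` is tensor-parallel and will differ in detail.

```python
import torch
from torch import nn


class FP32Embedding(nn.Module):
    """Illustrative analogue of ParallelEmbedding's force_dtype option."""

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 dtype: torch.dtype = torch.bfloat16,
                 force_dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        # Assumption: force_dtype overrides the storage dtype of the weight
        # (e.g. fp32), while `dtype` stays the dtype the rest of the model uses.
        self.out_dtype = dtype
        weight_dtype = force_dtype or dtype
        # In practice the weight is loaded from a checkpoint; torch.empty is
        # only a placeholder for the sketch.
        self.weight = nn.Parameter(
            torch.empty(num_embeddings, embedding_dim, dtype=weight_dtype, device=device))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Look up in the (possibly fp32) weight, then convert the output back
        # to the model's runtime dtype so downstream layers are unaffected.
        embeds = nn.functional.embedding(input_ids, self.weight)
        return embeds.to(self.out_dtype)
```

With `dtype=torch.bfloat16` and `force_dtype=torch.float32`, the stored weight is fp32 while the returned embeddings are bf16, so the transformer blocks see the dtype they expect.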
Comments suppressed due to low confidence (1)
lmdeploy/pytorch/models/internlm2.py:318
The `InternLM2ForCausalLM` class uses `self.output` as the name for its language model head, but the parent class `DeployModelMixinV1.get_lm_head()` expects the attribute to be named `self.lm_head`. This mismatch will cause an `AttributeError` when the inherited `get_logits` method tries to access `self.get_lm_head().weight.dtype`.
You need to either:
- Override `get_lm_head()` in `InternLM2ForCausalLM` to return `self.output` (a sketch of this option follows the code excerpt below), or
- Keep the existing `get_logits()` override and update it to match the fp32 head behavior from `DeployModelMixinV1`
```python
class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Rewrote model of InternLM2ForCausalLM."""

    packed_modules_mapping = {
        'gate_up_proj': [
            'w1',
            'w3',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build Model
        self.model = InternLM2Model(config, dtype=dtype, device=device)
        # build lm_head
        self.output = self.build_lm_head(config.hidden_size,
                                         config.vocab_size,
                                         bias=False,
                                         dtype=dtype,
                                         device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return logits."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.output(hidden_states)
```
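If the first option above were taken, the change could be as small as the override sketched here (hypothetical code, not the PR's actual fix); the existing `get_logits` override could then be dropped so the mixin's fp32-aware version applies.

```python
class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    ...

    def get_lm_head(self):
        """Expose the head stored under `self.output` to DeployModelMixinV1."""
        return self.output
```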
```python
head_dtype = self.get_lm_head().weight.dtype
if hidden_states.dtype != head_dtype:
    hidden_states = hidden_states.to(dtype=head_dtype)
hidden_states = self.get_lm_head()(hidden_states)
```
Copilot AI commented Jan 21, 2026
The `get_logits` method calls `self.get_lm_head()` twice (lines 59 and 62). While this works correctly, it's inefficient as it involves two method calls and could potentially cause issues if `get_lm_head()` has side effects. Consider storing the result in a local variable to avoid the duplicate call.
Suggested change:
```diff
-head_dtype = self.get_lm_head().weight.dtype
-if hidden_states.dtype != head_dtype:
-    hidden_states = hidden_states.to(dtype=head_dtype)
-hidden_states = self.get_lm_head()(hidden_states)
+lm_head = self.get_lm_head()
+head_dtype = lm_head.weight.dtype
+if hidden_states.dtype != head_dtype:
+    hidden_states = hidden_states.to(dtype=head_dtype)
+hidden_states = lm_head(hidden_states)
```
grimoire commented Jan 21, 2026
Should we put …
Motivation
Support an FP32 lm_head for Qwen and InternLM models, so that logits are computed with better numerical stability while the rest of the model runs in FP16/BF16.
Modification
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
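One use case worth listing here is turning the feature on from user code. Below is a minimal sketch, assuming `enforce_fp32_head` is passed through `PytorchEngineConfig.hf_overrides` as the overview describes; the exact key name and plumbing should be confirmed against the merged PR.

```python
from lmdeploy import PytorchEngineConfig, pipeline

# Keep the transformer blocks in bf16, but compute the embedding and lm_head in
# fp32. `enforce_fp32_head` is the new option introduced by this PR, forwarded
# through hf_overrides into the model config (assumed interface).
backend_config = PytorchEngineConfig(
    dtype='bfloat16',
    hf_overrides={'enforce_fp32_head': True},
)

pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
print(pipe(['Hi, please introduce yourself.']))
```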
Checklist