Support fp32 head for qwen and internlm models #4160

Open

RunningLeon wants to merge 8 commits into InternLM:main from RunningLeon:fp32-head

Conversation

@RunningLeon (Collaborator) commented Nov 27, 2025 (edited)

Motivation

Support fp32 head for qwen and internlm models

Modification

from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig

if __name__ == '__main__':
    backend_config = PytorchEngineConfig(hf_overrides=dict(enforce_fp32_head=True))
    model_path = 'Qwen/Qwen3-30B-A3B'
    pipe = pipeline(model_path, backend_config=backend_config)
    resps = pipe(['Hi.'])
    for res in resps:
        print(res)

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

RunningLeon marked this pull request as ready for review on January 21, 2026 11:53
Copilot AI review was requested due to automatic review settings on January 21, 2026 11:53
RunningLeon changed the title from "[WIP]: Support fp32 head for qwen and internlm models" to "Support fp32 head for qwen and internlm models" on Jan 21, 2026
Copilot AI (Contributor) left a comment


Pull request overview

This PR adds support for FP32 precision language model heads for the Qwen and InternLM model families. The feature allows the embedding and lm_head layers to compute in FP32 for improved numerical stability while the rest of the model runs in lower precision (e.g., FP16/BF16). This is enabled through a new enforce_fp32_head configuration option passed via hf_overrides.

Changes:

  • Added DeployModelMixinV1 base class with FP32 head support (a minimal sketch follows this list)
  • Enhanced ParallelEmbedding with a force_dtype parameter for FP32 weight storage
  • Refactored 12+ model classes to use the new mixin and build methods
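
For orientation, here is a minimal sketch of how such a mixin could wire an FP32 head, assuming the build_lm_head()/get_logits() interface described elsewhere in this review; it is an illustration, not the PR's actual implementation:

import torch
from torch import nn


class DeployModelMixinV1:
    """Illustrative mixin: optionally build and use an fp32 lm_head."""

    def build_lm_head(self,
                      hidden_size: int,
                      vocab_size: int,
                      bias: bool = False,
                      dtype: torch.dtype = None,
                      device: torch.device = None,
                      enforce_fp32_head: bool = False) -> nn.Module:
        # Keep head weights in fp32 when requested, otherwise follow the model dtype.
        head_dtype = torch.float32 if enforce_fp32_head else dtype
        return nn.Linear(hidden_size, vocab_size, bias=bias, dtype=head_dtype, device=device)

    def get_lm_head(self) -> nn.Module:
        # Subclasses whose head has a different attribute name would override this.
        return self.lm_head

    def get_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        lm_head = self.get_lm_head()
        head_dtype = lm_head.weight.dtype
        if hidden_states.dtype != head_dtype:
            # Cast activations up to fp32 only when the head weights require it.
            hidden_states = hidden_states.to(dtype=head_dtype)
        return lm_head(hidden_states)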

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 13 comments.

Summary per file:

  • lmdeploy/pytorch/nn/embedding.py: Added force_dtype parameter to ParallelEmbedding for fp32 embeddings, with dtype conversion on output (see the sketch after this list)
  • lmdeploy/pytorch/models/utils/model.py: Introduced DeployModelMixinV1 with build_lm_head() and get_logits() methods supporting the fp32 head
  • lmdeploy/pytorch/config.py: Added config handling to extract and propagate enforce_fp32_head from hf_overrides
  • lmdeploy/pytorch/models/qwen*.py: Refactored Qwen models to use DeployModelMixinV1 and ParallelEmbedding with fp32 support
  • lmdeploy/pytorch/models/internlm*.py: Refactored InternLM models to use DeployModelMixinV1 and ParallelEmbedding with fp32 support
  • lmdeploy/pytorch/models/internvl*.py: Updated InternVL models to use DeployModelMixinV1 and delegate get_lm_head to language_model
  • lmdeploy/pytorch/models/phi3*.py: Updated Phi3 models to use DeployModelMixinV1
  • lmdeploy/pytorch/models/qwen*_vl.py: Updated Qwen VL models to use DeployModelMixinV1 and ParallelEmbedding
  • lmdeploy/pytorch/models/gpt_oss.py: Updated GPT OSS model to use DeployModelMixinV1 and ParallelEmbedding
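
A minimal sketch of the force_dtype behavior described for lmdeploy/pytorch/nn/embedding.py above, using a plain nn.Embedding stand-in; names other than force_dtype are illustrative, and the real ParallelEmbedding additionally handles tensor parallelism:

import torch
from torch import nn


class ParallelEmbeddingSketch(nn.Module):
    """Illustrative only: store embedding weights in an enforced dtype (e.g. fp32)
    and cast outputs back to the model's compute dtype."""

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 dtype: torch.dtype = None,
                 force_dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        weight_dtype = force_dtype if force_dtype is not None else dtype
        self.out_dtype = dtype
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, dtype=weight_dtype, device=device)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        out = self.embedding(input_ids)
        # Convert back so downstream layers keep running in fp16/bf16.
        if self.out_dtype is not None and out.dtype != self.out_dtype:
            out = out.to(self.out_dtype)
        return out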
Comments suppressed due to low confidence (1)

lmdeploy/pytorch/models/internlm2.py:318

  • The InternLM2ForCausalLM class uses self.output as the name for its language model head, but the parent class DeployModelMixinV1.get_lm_head() expects the attribute to be named self.lm_head. This mismatch will cause an AttributeError when the inherited get_logits method tries to access self.get_lm_head().weight.dtype.

You need to either:

  1. Override get_lm_head() in InternLM2ForCausalLM to return self.output (a sketch of this option follows the snippet below), or
  2. Keep the existing get_logits() override and update it to match the fp32 head behavior from DeployModelMixinV1
class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Rewrote model of InternLM2ForCausalLM."""

    packed_modules_mapping = {
        'gate_up_proj': [
            'w1',
            'w3',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build Model
        self.model = InternLM2Model(config, dtype=dtype, device=device)
        # build lm_head
        self.output = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return logits."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.output(hidden_states)
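
A short sketch of option 1, assuming the mixin looks up the head via get_lm_head(); the maintainers may prefer option 2 instead:

class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    # ... constructor and forward as in the snippet above ...

    def get_lm_head(self):
        # InternLM2 names its head `self.output` rather than `self.lm_head`;
        # returning it here lets the inherited fp32-aware get_logits() work,
        # and the local get_logits() override can then be dropped.
        return self.output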


Comment on lines +59 to +62
head_dtype = self.get_lm_head().weight.dtype
if hidden_states.dtype != head_dtype:
hidden_states = hidden_states.to(dtype=head_dtype)
hidden_states = self.get_lm_head()(hidden_states)


The get_logits method calls self.get_lm_head() twice (lines 59 and 62). While this works correctly, it is inefficient as it involves two method calls and could potentially cause issues if get_lm_head() has side effects. Consider storing the result in a local variable to avoid the duplicate call.

Suggested change:

-        head_dtype = self.get_lm_head().weight.dtype
-        if hidden_states.dtype != head_dtype:
-            hidden_states = hidden_states.to(dtype=head_dtype)
-        hidden_states = self.get_lm_head()(hidden_states)
+        lm_head = self.get_lm_head()
+        head_dtype = lm_head.weight.dtype
+        if hidden_states.dtype != head_dtype:
+            hidden_states = hidden_states.to(dtype=head_dtype)
+        hidden_states = lm_head(hidden_states)

@grimoire (Collaborator) commented:

Should we put force_fp32_head in BuildModelContext so we don't have to read and set args in every model?
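
For illustration, a rough sketch of that suggestion, assuming a BuildModelContext-style dataclass is threaded through model construction (the field and helper names here are hypothetical, not the PR's actual API):

from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class BuildModelContext:
    # Hypothetical shared build context: models read the flag from here
    # instead of parsing hf_overrides individually.
    dtype: torch.dtype = torch.float16
    enforce_fp32_head: bool = False


def build_lm_head(hidden_size: int, vocab_size: int, ctx: BuildModelContext) -> nn.Module:
    head_dtype = torch.float32 if ctx.enforce_fp32_head else ctx.dtype
    return nn.Linear(hidden_size, vocab_size, bias=False, dtype=head_dtype)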


Reviewers

Copilot code review: Copilot left review comments

@grimoire approved these changes

@lvhan028: awaiting requested review

At least 2 approving reviews are required to merge this pull request.

3 participants

@RunningLeon, @grimoire, @lvhan028
