Commit 263b2b7

Moved vllm fq export code to separate files (#612)

## What does this PR do?

**Type of change:** Bug fix

**Overview:** Moved vLLM fakequant checkpoint export code to separate files:

1. For HF export -> `modelopt.torch.export.plugins.vllm_fq_hf`
2. For Megatron export -> `modelopt.torch.export.plugins.vllm_fq_megatron`

## Usage

Refer to [README.md](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/096ee13ea62bbb0ce0a4e4128c439651374d6235/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip).

## Testing

- Tested the HF approach by exporting a bf16 model using the QAT script and running the vLLM server; verified that the amax values match.
- Tested the MCore approach by quantizing and exporting a bf16 model using the quantize.sh and export.sh scripts and running the vLLM server; verified that the amax values match.
- Tested using the unit tests in `tests/gpu/torch/export/test_vllm_fq_hf_export.py` and `tests/gpu/torch/export/test_vllm_fq_megatron_export.py`.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: NA
- **Did you add or update any necessary documentation?**: Yes
- **Did you update the [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: NA

## Summary by CodeRabbit

### Release Notes

* **New Features**
  * Added dedicated export functions for the vLLM fakequant checkpoint format, supporting both HuggingFace and Megatron Core models.
* **Refactor**
  * Simplified the export API by removing conditional export flags for cleaner, more predictable behavior.
  * Reorganized export functionality into focused plugin modules for improved maintainability.

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>

1 parent 0a4f0a8 commit 263b2b7

File tree

10 files changed: +355 −310 lines changed

examples/vllm_serve/README.md

Lines changed: 3 additions & 3 deletions

```diff
@@ -57,10 +57,10 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=<model_name>,
 Overwrite the calibrated amax value with prepared values from either QAT/PTQ.

-Step 1: export the model with bf16 weights and amax values.
+Step 1: export the model with bf16 weights and amax values. To export the model:

-- For HF model set `export_bf16_weights_amax` to export the model with function `modelopt.torch.export.unified_export_hf.export_hf_checkpoint`.
-- For MCore model use `export_bf16_weights_amax` to export the model with function `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`.
+- For HF model use `modelopt.torch.export.export_hf_vllm_fq_checkpoint` function.
+- For MCore model use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq` function.

 Step 2: configure <quant_amax.pth> from exported model using AMAX_FILE_PATH environment variable in step 1. For example:
```
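To make Step 2 concrete, here is a minimal sketch, assuming the serving script picks up the amax file via the `AMAX_FILE_PATH` environment variable as the README describes; the export path is a placeholder:

```python
# Illustrative sketch of Step 2: expose the quant_amax.pth produced in Step 1
# through the AMAX_FILE_PATH environment variable before launching vLLM serving.
import os

os.environ["AMAX_FILE_PATH"] = "./exported_model/quant_amax.pth"  # placeholder path
# ... then start the vLLM server as described in examples/vllm_serve ...
```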

modelopt/torch/export/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,6 +19,7 @@
 from .model_config import *
 from .model_config_export import *
 from .model_utils import *
+from .plugins import *
 from .transformer_engine import *
 from .unified_export_hf import *
 from .unified_export_megatron import *
```

modelopt/torch/export/plugins/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -21,3 +21,7 @@
 from .megatron_importer import *

 from .hf_spec_export import *
+from .vllm_fakequant_hf import *
+
+with import_plugin("vllm_fakequant_megatron"):
+    from .vllm_fakequant_megatron import *
```

modelopt/torch/export/plugins/vllm_fakequant.py

Lines changed: 0 additions & 125 deletions

This file was deleted.
modelopt/torch/export/plugins/vllm_fakequant_hf.py (new file, per the `.vllm_fakequant_hf` import above)

Lines changed: 62 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Export HuggingFace model to vLLM fakequant checkpoint."""

from pathlib import Path

import torch
import torch.nn as nn

from modelopt.torch.export.layer_utils import is_quantlinear
from modelopt.torch.quantization.utils import get_quantizer_state_dict

__all__ = ["export_hf_vllm_fq_checkpoint"]


def export_hf_vllm_fq_checkpoint(
    model: nn.Module,
    export_dir: Path | str,
):
    """Exports the torch model weights and amax values separately.

    This function:
    1. Extracts amax values for calibration
    2. Deletes all quantizer parameters from the state dict to store only weights in the original dtype
    3. Saves the model weights

    Args:
        model: The quantized model to export
        export_dir: Directory to save the amax values
    """
    export_dir = Path(export_dir)
    export_dir.mkdir(parents=True, exist_ok=True)

    amax_dict = {
        name + "._amax": param["_amax"].detach().clone().cpu()
        for name, param in get_quantizer_state_dict(model).items()
        if "_amax" in param
    }

    # Remove quantizers from the model
    for _, module in model.named_modules():
        if is_quantlinear(module):
            for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]:
                if hasattr(module, attr):
                    delattr(module, attr)
            module.export()
    torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
    # Save model
    model.save_pretrained(export_dir, state_dict=model.state_dict(), save_modelopt_state=False)
```
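For reference, a minimal usage sketch of this new HF entry point. The model id, quantization config, and calibration loop below are illustrative assumptions, not part of this commit:

```python
# Hypothetical end-to-end sketch: quantize an HF model with ModelOpt, then
# export bf16 weights plus a separate quant_amax.pth for vLLM fakequant serving.
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_vllm_fq_checkpoint
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("my-org/my-qat-model")  # placeholder id


def forward_loop(model):
    """Placeholder calibration loop; feed representative batches here."""
    ...


model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)  # example config

# Writes the weights via save_pretrained() and the amax values to
# <export_dir>/quant_amax.pth, after stripping the quantizer modules.
export_hf_vllm_fq_checkpoint(model, export_dir="./exported_model")
```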
modelopt/torch/export/plugins/vllm_fakequant_megatron.py (new file, per the `.vllm_fakequant_megatron` import above)

Lines changed: 120 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Export Megatron Core Model to HuggingFace vLLM fakequant checkpoint."""

import os
import tempfile
from pathlib import Path

import torch

from modelopt.torch.export.model_config import QUANTIZATION_NONE
from modelopt.torch.export.unified_export_megatron import GPTModelExporter

__all__ = ["export_mcore_gpt_to_hf_vllm_fq"]


def gather_mcore_vllm_fq_quantized_state_dict(
    model, state_dict: dict[str, torch.Tensor], save_directory: str | os.PathLike
):
    """Gather all quantized state dict from all ranks and save them to a file.

    Args:
        state_dict: The state dictionary of the module.
        save_directory: The directory to save the quantized state dict.

    Returns:
        The state dictionary of the module without quantized state.
    """
    amax_state_dict = {
        k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax")
    }

    # Gather all amax dicts to rank 0
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    if rank == 0:
        # Rank 0 will collect all amax values
        all_amax_dicts = [None] * world_size
        torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0)

        # Merge all amax dicts into one
        merged_amax_dict = {}
        for amax_dict in all_amax_dicts:
            if amax_dict is not None:
                merged_amax_dict.update(amax_dict)

        print(f"Total amax entries from all ranks: {len(merged_amax_dict.keys())}")
        torch.save(merged_amax_dict, save_directory + "/quant_amax.pth")
    else:
        # Other ranks just send their amax values
        torch.distributed.gather_object(amax_state_dict, None, dst=0)

    torch.distributed.barrier()


class VllmFqGPTModelExporter(GPTModelExporter):
    """VLLM fakequant GPTModel exporter."""

    def save_pretrained(
        self,
        save_directory: str | os.PathLike,
        pretrained_model_name_or_path: str | os.PathLike | None = None,
    ):
        os.makedirs(save_directory, exist_ok=True)
        gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory)
        assert not (self.is_multimodal and pretrained_model_name_or_path is not None), (
            "Exporting weights in bf16 and amax values is not supported for multimodal models "
            "when pretrained_model_name_or_path is not None"
        )
        assert not self.export_extra_modules, (
            "Exporting extra modules is not supported for vLLM fakequant"
        )
        super().save_pretrained(save_directory, pretrained_model_name_or_path)

    def _get_quantization_format(self, module: torch.nn.Module):
        return QUANTIZATION_NONE


def export_mcore_gpt_to_hf_vllm_fq(
    model: torch.nn.Module,
    pretrained_model_name_or_path: str | os.PathLike | None = None,
    export_extra_modules: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    export_dir: Path | str = tempfile.gettempdir(),
    moe_router_dtype: torch.dtype | None = None,
):
    """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.

    Args:
        model: The Megatron Core GPTModel instance.
        pretrained_model_name_or_path: Can be either: the *model id* of a
            pretrained model hosted inside a model repo on huggingface.co; or
            a *directory* containing model weights saved using
            [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
        export_extra_modules: If True, export extra modules like medusa_heads or
            eagle_module. Otherwise, only export the base model.
        dtype: The weights data type to export the unquantized layers.
        export_dir: The target export path.
    """
    exporter = VllmFqGPTModelExporter(
        model,
        pretrained_model_name_or_path,
        export_extra_modules=export_extra_modules,
        dtype=dtype,
        moe_router_dtype=moe_router_dtype,
    )
    exporter.save_pretrained(export_dir, pretrained_model_name_or_path)
```
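And a corresponding sketch for the Megatron Core path. The model handle and paths are assumed to come from the surrounding quantize/export scripts mentioned in the PR description:

```python
# Hypothetical sketch: export a quantized Megatron Core GPTModel to a
# HuggingFace-layout checkpoint plus quant_amax.pth (runs under torch.distributed,
# since the exporter gathers amax values across ranks).
from modelopt.torch.export import export_mcore_gpt_to_hf_vllm_fq

# `mcore_gpt_model` is assumed to be an already-quantized Megatron Core GPTModel.
export_mcore_gpt_to_hf_vllm_fq(
    mcore_gpt_model,
    pretrained_model_name_or_path="my-org/base-model",  # placeholder HF model id
    export_dir="./exported_model",
)
```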

modelopt/torch/export/unified_export_hf.py

Lines changed: 1 addition & 9 deletions

```diff
@@ -59,7 +59,6 @@
 )
 from .model_utils import get_language_model_from_vl, is_multimodal_model
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
-from .plugins.vllm_fakequant import export_hf_vllm_fq_checkpoint
 from .quant_utils import (
     fuse_prequant_layernorm,
     fuse_prequant_to_linear,
@@ -559,7 +558,6 @@ def export_hf_checkpoint(
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
-    export_vllm_fq_weights_qstate: bool = False,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.
@@ -568,8 +566,6 @@ def export_hf_checkpoint(
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
-        export_vllm_fq_weights_qstate: whether to export the weights and quantization state separately for vLLM
-            fakequant serving.
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
@@ -583,11 +579,7 @@ def export_hf_checkpoint(
         return

     try:
-        if export_vllm_fq_weights_qstate:
-            post_state_dict = export_hf_vllm_fq_checkpoint(model, export_dir)
-            hf_quant_config = None
-        else:
-            post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)

         if hf_quant_config is not None:
             # Save hf_quant_config.json for backward compatibility
```
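In short, callers that previously passed the removed flag would migrate roughly like this (a sketch; `model` is assumed to be an already-quantized HF model):

```python
# Migration sketch for the removed export_vllm_fq_weights_qstate flag.
from modelopt.torch.export import export_hf_vllm_fq_checkpoint

# Before this commit (no longer supported):
# export_hf_checkpoint(model, export_dir="./out", export_vllm_fq_weights_qstate=True)

# After this commit, use the dedicated function instead:
export_hf_vllm_fq_checkpoint(model, export_dir="./out")
```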

0 commit comments
