Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit7a36ccc

Browse files
Added support to export for BF16 weight and amax for vLLM fakequant QAT (#579)
## What does this PR do?**Type of change:** New Feature**Overview:** Support for vLLM fakequantize QAT/QAD checkpoint evaluation. This MRadds function to export checkpoint as BF16 weights and amax using`export_hf_checkpoint` for HF and `export_mcore_gpt_to_hf` for MCoreusing `export_bf16_weights_amax` option. The exported weights and amaxcan be used with[vllm_serve_fakequant.py](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/096ee13ea62bbb0ce0a4e4128c439651374d6235/examples/vllm_serve/vllm_serve_fakequant.py)script to run saved checkpoint.## UsageRefer to[README.md](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/096ee13ea62bbb0ce0a4e4128c439651374d6235/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip)## Testing- Tested HF approach by exporting bf16 model using QAT script andrunning vllm server, verified amax values match- Tested MCore approach by quantizing and exporting bf16 model usingquantize.sh and export.sh script and running vllm server, verified amaxvalues match## Before your PR is "*Ready for review*"<!-- If you haven't finished some of the above items you can still open`Draft` PR. -->- **Make sure you read and follow [Contributorguidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)**and your commits are signed.- **Is this change backward compatible?**: Yes- **Did you write any new necessary tests?**: No- **Did you add or update any necessary documentation?**: Yes- **Did you update[Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**:Yes## Additional InformationMCore export script doesn't have the option to export enable currently---------Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent5842d73 commit7a36ccc

File tree

9 files changed

+547
-240
lines changed

9 files changed

+547
-240
lines changed

‎CHANGELOG.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Model Optimizer Changelog (Linux)
1717
- Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
1818
- Add support for PyTorch Geometric quantization.
1919
- Add per tensor and per channel MSE calibrator support.
20+
- Added support for PTQ/QAT checkpoint export and loading for running fakequant evaluation in vLLM. See ``examples/vllm_serve/README.md#load-qatptq-model-and-serve-in-vllm-wip`` for more details.
2021

2122
**Documentation**
2223

‎examples/vllm_serve/README.md‎

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,19 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=<model_name>,
5555

5656
##Load QAT/PTQ model and serve in vLLM (WIP)
5757

58-
Overwrite the calibrated amax value with prepared values from eitherPTQ/QAT. This is only tested for Llama3.1
58+
Overwrite the calibrated amax value with prepared values from either QAT/PTQ.
5959

60-
Step 1: convert amax to merged amax, using llama3.1 as an example:
60+
Step 1: export the model with bf16 weights and amax values.
61+
62+
- For HF model set`export_bf16_weights_amax` to export the model with function`modelopt.torch.export.unified_export_hf.export_hf_checkpoint`.
63+
- For MCore model use`export_bf16_weights_amax` to export the model with function`modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`.
64+
65+
Step 2: configure <quant_amax.pth> from exported model using AMAX_FILE_PATH environment variable in step 1. For example:
6166

6267
```bash
63-
pythonconvert_amax_hf2vllm.py-i<amax.pth> -o<vllm_amax.pth>
68+
AMAX_FILE_PATH=<vllm_amax.pth> QUANT_CFG=<quant_config>pythonvllm_serve_fakequant.py<model_path> -tp 8 --host 0.0.0.0 --port 8000
6469
```
6570

66-
Step 2: add`<vllm_amax.pth>` to`quant_config` in`vllm_serve_fakequant.py`
67-
6871
##Important Notes
6972

7073
**Amax Synchronization across Tensor Parallel (TP):**
@@ -85,3 +88,5 @@ torch.distributed.barrier()
8588
##Known Problems
8689

8790
1. AWQ is not yet supported in vLLM.
91+
2. PTQ/QAT checkpoint doesn't work with KV Cache quantization enabled.
92+
3. Mixed precision checkpoint doesn't work currently.

‎examples/vllm_serve/convert_amax_hf2vllm.py‎

Lines changed: 0 additions & 213 deletions
This file was deleted.

‎examples/vllm_serve/fakequant_worker.py‎

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
importdataclasses
1717
importos
18+
importre
1819
importwarnings
20+
fromcollectionsimportdefaultdict
1921
fromcontextlibimportcontextmanager
2022
fromtypingimportAny
2123

@@ -30,6 +32,99 @@
3032
frommodelopt.torch.utils.dataset_utilsimportget_dataset_dataloader
3133

3234

35+
defconvert_amax_hf2vllm(
36+
hf_state_dict:dict[str,torch.Tensor],fuse_experts:bool=False
37+
)->dict[str,torch.Tensor]:
38+
"""
39+
Convert amax values from HuggingFace format to vLLM format.
40+
41+
This function merges:
42+
- q_proj, k_proj, v_proj amax values into qkv_proj (taking max)
43+
- gate_proj, up_proj amax values into gate_up_proj (taking max)
44+
45+
Args:
46+
hf_state_dict: HuggingFace state dict containing amax values
47+
48+
Returns:
49+
vLLM format state dict with merged amax values
50+
"""
51+
vllm_state_dict= {}
52+
53+
# Group keys by their base pattern (without the specific projection name)
54+
merge_groups=defaultdict(list)
55+
56+
forkey,valueinhf_state_dict.items():
57+
if"_amax"notinkey:
58+
# Copy non-amax keys as-is
59+
vllm_state_dict[key]=value
60+
continue
61+
62+
# Check if this is a q/k/v projection that needs merging
63+
qkv_match=re.search(r"(.*\.)([qkv])_proj(\..+_amax)$",key)
64+
ifqkv_match:
65+
base_pattern=qkv_match.group(1)+"qkv_proj"+qkv_match.group(3)
66+
merge_groups[base_pattern].append((key,value))
67+
continue
68+
69+
# Check if this is an expert gate/up projection
70+
# Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and
71+
# model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax
72+
# Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax
73+
expert_gate_up_match= (
74+
"mixer"notinkey
75+
andfuse_experts
76+
andre.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$",key)
77+
)
78+
ifexpert_gate_up_match:
79+
base_pattern=expert_gate_up_match.group(1)+".w13_"+expert_gate_up_match.group(3)
80+
merge_groups[base_pattern].append((key,value))
81+
continue
82+
83+
# Check if this is a non-expert gate/up projection that needs merging
84+
gate_up_match= (
85+
"mixer"notinkey
86+
and"experts"notinkey
87+
andre.search(r"(.*\.)(gate|up)_proj(\..+_amax)$",key)
88+
)
89+
ifgate_up_match:
90+
base_pattern=gate_up_match.group(1)+"gate_up_proj"+gate_up_match.group(3)
91+
merge_groups[base_pattern].append((key,value))
92+
continue
93+
94+
# Check if this is an expert down_proj
95+
# Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax
96+
# Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax
97+
expert_down_match= (
98+
"mixer"notinkey
99+
andfuse_experts
100+
andre.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$",key)
101+
)
102+
ifexpert_down_match:
103+
base_pattern=expert_down_match.group(1)+".w2_"+expert_down_match.group(2)
104+
merge_groups[base_pattern].append((key,value))
105+
continue
106+
107+
# Copy other amax keys as-is (like o_proj, down_proj)
108+
vllm_state_dict[key]=value
109+
110+
# Merge grouped amax values by taking the maximum
111+
formerged_key,key_value_pairsinmerge_groups.items():
112+
iflen(key_value_pairs)>1:
113+
# Take the maximum across all values for this merged key
114+
values= [valuefor_,valueinkey_value_pairs]
115+
merged_value=torch.stack(values).max(dim=0)[0]
116+
vllm_state_dict[merged_key]=merged_value
117+
print(f"Merged{len(key_value_pairs)} keys into{merged_key}")
118+
fororig_key,_inkey_value_pairs:
119+
print(f" -{orig_key}")
120+
else:
121+
# Single key, just rename it
122+
_,value=key_value_pairs[0]
123+
vllm_state_dict[merged_key]=value
124+
125+
returnvllm_state_dict
126+
127+
33128
@contextmanager
34129
defdisable_compilation(model):
35130
do_not_compile=True
@@ -154,8 +249,17 @@ def calibrate_loop(model: Any = None) -> None:
154249
ifamax_file_path:
155250
print(f"Loading amax values from{amax_file_path}")
156251
saved_amax_dict=torch.load(amax_file_path)
157-
current_state_dict=model.state_dict()
252+
# convert amax keys to vLLM format
253+
ifhasattr(self.model_runner.model,"hf_to_vllm_mapper"):
254+
saved_amax_dict=self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict)
255+
saved_amax_dict= {
256+
key.replace("quantizer_amax","quantizer._amax"):value
257+
forkey,valueinsaved_amax_dict.items()
258+
ifkey.endswith("quantizer_amax")
259+
}
260+
saved_amax_dict=convert_amax_hf2vllm(saved_amax_dict,fuse_experts=True)
158261

262+
current_state_dict=model.state_dict()
159263
# Count amax keys in checkpoint and model
160264
checkpoint_amax_keys= [keyforkeyinsaved_amax_dictifkey.endswith("_amax")]
161265
model_amax_keys= [keyforkeyincurrent_state_dictifkey.endswith("_amax")]

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp