Commit 1d0ee04

[OMNIML-2932] Fusing pre_quant_scale for NVFP4 AWQ (#421)
## What does this PR do?

**Type of change:** ?

**Overview:** This PR and NVIDIA/TensorRT-LLM#8698 enable NVFP4 AWQ deployment for TRT-LLM. Specifically, this PR fuses pre_quant_scale in the following two cases:

* For MLP, the pre_quant_scale of the gate_proj layer is fused into up_proj's weight, so no extra handling is needed in downstream fused MoE kernels.
* For attention, we try to fuse the pre_quant_scale of o_proj into v_proj if their dimensions match, which means fusion is skipped for MQA/GQA models.

## Testing

Unit tests, plus e2e tests for Qwen3 dense and MoE models.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

---------

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent a8641bc · commit 1d0ee04
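The core idea of the commit can be illustrated outside of ModelOpt: an AWQ pre_quant_scale multiplies a linear layer's input per channel, and that multiplication can instead be folded into the weight columns of the layer that produces the input. The sketch below is a minimal hypothetical illustration (plain tensors, not the ModelOpt API) of that equivalence.

```python
# Hypothetical minimal sketch (not the ModelOpt API): folding a per-channel
# pre_quant_scale applied to a linear layer's input into the weight itself.
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)          # activations entering the layer
scale = torch.rand(8) + 0.5    # per-channel pre_quant_scale
W = torch.randn(16, 8)         # weight, shape (out_features, in_features)

# Original: scale the input, then apply the linear layer.
y_ref = (x * scale) @ W.T

# Fused: fold the scale into the weight columns; the input is left untouched,
# so downstream kernels need no special pre_quant_scale handling.
y_fused = x @ (W * scale).T

assert torch.allclose(y_ref, y_fused, atol=1e-5)
```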

File tree

4 files changed: +330 −19 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -47,6 +47,7 @@ Model Optimizer Changelog (Linux)
 - Enabled native Modelopt quantization support for FP8 and NVFP4 formats in SGLang. See `SGLang quantization documentation <https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/quantization.md#using-nvidia-modelopt>`_ for more details.
 - Added modelopt quantized checkpoints in vLLM/SGLang CI/CD pipelines (PRs are under review).
 - Add support for exporting QLoRA checkpoint fintuned using ModelOpt.
+- Update NVFP4 AWQ checkpoint export. It now fuses scaling factors of o_proj and down_proj layers into the model when possible to facilitate deployment.

 **Documentation**
```

modelopt/torch/export/quant_utils.py

Lines changed: 131 additions & 19 deletions

```diff
@@ -489,7 +489,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames

     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
         return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
```
```diff
@@ -959,18 +959,145 @@ def all_items_same(item_list):
     return all(x == item_list[0] for x in item_list)


+def _update_pre_quant_scale(module, new_pre_quant_scale):
+    old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+    # do the processing in fp32 for numerical stability
+    dtype = module.weight.dtype
+    module.weight = nn.Parameter(
+        (
+            module.weight.to(torch.float32)
+            * old_pre_quant_scale.to(dtype=torch.float32, device=module.weight.device)
+            / new_pre_quant_scale.to(dtype=torch.float32, device=module.weight.device)
+        ).to(dtype)
+    )
+    module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+
+    # Redo weights collection
+    module.weight_quantizer.reset_amax()
+    enable_stats_collection(module.weight_quantizer)
+    module.weight_quantizer(module.weight)
+    finish_stats_collection(module.weight_quantizer)
+
+
+# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
+PQS_FUSE_MODULE_MAPPING = [
+    # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
+    # Mathematical equivalence:
+    #   Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
+    #   After:  o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
+    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
+    # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
+    # Mathematical equivalence:
+    #   Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
+    #   After:  down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
+    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
+]
+
+
+def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
+    """Fuse pre_quant_scale to the linear weights if possible.
+
+    Args:
+        model: The model to fuse pre_quant_scale to.
+        fuse_grouped_heads: If True, fuse the pre_quant_scale even if the dimensions of the
+            pre_quant_scale and the linear weights do not match.
+
+    Returns:
+        fused_modules: A list of modules whose pre_quant_scale is fused to the previous linear layer.
+    """
+    # Fuse pre_quant_scale to the linear weights
+    for _, module in model.named_modules():
+        for module_map in PQS_FUSE_MODULE_MAPPING:
+            target_module_list = module_map[0]
+            linear_pair = module_map[1]
+            if any(module_name in type(module).__name__ for module_name in target_module_list):
+                linear_fuse_into = module.get_submodule(linear_pair[0])
+                linear_pqs_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
+                    linear_pqs_from.input_quantizer, "_pre_quant_scale"
+                ):
+                    pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
+
+                    # for GQA/MQA models, we can apply averaging to the pre_quant_scale for shared head groups
+                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
+                        if (
+                            not fuse_grouped_heads
+                            or "attention" not in type(module).__name__.lower()
+                        ):
+                            warn(
+                                f"Skipping pattern fuse prequant for {type(module).__name__}: "
+                                f"pre_quant_scale dim {pre_quant_scale.numel()} != "
+                                f"out_channel dim {linear_fuse_into.weight.shape[-2]}"
+                            )
+                            continue
+                        config = module.config
+                        num_kv_heads = config.num_key_value_heads
+                        kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
+                        n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                        # Reshape: (num_kv_heads, n_rep, kv_head_dim)
+                        # n_rep is the number of query groups
+                        averaged_scale = pre_quant_scale.view(
+                            num_kv_heads, n_rep, kv_head_dim
+                        ).mean(dim=1)
+
+                        # To update o_proj, we need to repeat back to the original shape
+                        repeated_scale = (
+                            averaged_scale.unsqueeze(1)
+                            .expand(num_kv_heads, n_rep, kv_head_dim)
+                            .reshape(-1)
+                        )
+                        # Update o_proj's pre_quant_scale
+                        _update_pre_quant_scale(linear_pqs_from, repeated_scale)
+
+                        # Use averaged scale (flattened) for v_proj fusion
+                        pre_quant_scale = averaged_scale.reshape(-1)
+
+                    # Fuse the pre_quant_scale to weight
+                    linear_fuse_into.weight = torch.nn.Parameter(
+                        linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
+                    )
+                    if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
+                        linear_fuse_into.bias = torch.nn.Parameter(
+                            linear_fuse_into.bias * pre_quant_scale
+                        )
+
+                    # Recalibrate the weight quantizer for linear_fuse_into
+                    linear_fuse_into.weight_quantizer.reset_amax()
+                    enable_stats_collection(linear_fuse_into.weight_quantizer)
+                    linear_fuse_into.weight_quantizer(linear_fuse_into.weight)
+                    finish_stats_collection(linear_fuse_into.weight_quantizer)
+
+                    delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_pqs_from, "fused_with_prequant", True)
+
+
 def fuse_prequant_layernorm(
     layernorm_module: torch.nn.Module,
     modules: list[torch.Tensor],
 ):
-    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
+    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.
+
+    original:
+        layernorm_output = (normalization(input) * weight) + bias
+        layernorm_output_scaled = layernorm_output * pre_quant_scale
+
+    fused:
+        fused_weight = weight * avg_pre_quant_scale
+        fused_bias = bias * avg_pre_quant_scale
+        layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
+    """
     layernorm_module.weight = torch.nn.Parameter(
         layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
     )
+    if hasattr(layernorm_module, "bias") and layernorm_module.bias is not None:
+        layernorm_module.bias = torch.nn.Parameter(
+            layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
+        )
     # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
     for module in modules:
         delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)
```
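The GQA/MQA branch above averages o_proj's per-channel scale over the query heads that share each KV head, then repeats it back to o_proj's input width. A minimal standalone sketch of that reshape/average/repeat pattern (illustrative dimensions, not tied to any real model config):

```python
# Hypothetical sketch of the GQA scale-averaging step: o_proj sees one scale
# per query-head channel, but v_proj only has num_kv_heads * kv_head_dim
# output channels, so scales are averaged across each group of n_rep heads.
import torch

num_kv_heads, n_rep, kv_head_dim = 2, 3, 4  # assumed toy dimensions
pre_quant_scale = torch.arange(
    num_kv_heads * n_rep * kv_head_dim, dtype=torch.float32
)

# One averaged scale per KV-head output channel, usable for v_proj fusion.
averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)
assert averaged_scale.shape == (num_kv_heads, kv_head_dim)

# Repeat back to o_proj's input width so o_proj can be updated consistently.
repeated_scale = (
    averaged_scale.unsqueeze(1)
    .expand(num_kv_heads, n_rep, kv_head_dim)
    .reshape(-1)
)
assert repeated_scale.numel() == pre_quant_scale.numel()
```

Because the averaged scale replaces distinct per-head scales, this fusion is approximate for GQA/MQA, which is why it is gated behind `fuse_grouped_heads`.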
```diff
@@ -992,22 +1119,7 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False

     for module in modules:
         if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
-            module.weight = nn.Parameter(
-                module.weight
-                * module.input_quantizer.pre_quant_scale.to(
-                    dtype=module.weight.dtype, device=module.weight.device
-                )
-                / avg_prequant_scale.to(
-                    dtype=module.weight.dtype, device=module.weight.device
-                )
-            )
-            module.input_quantizer.pre_quant_scale = avg_prequant_scale
-
-            # Redo weights collection
-            module.weight_quantizer.reset_amax()
-            enable_stats_collection(module.weight_quantizer)
-            module.weight_quantizer(module.weight)
-            finish_stats_collection(module.weight_quantizer)
+            _update_pre_quant_scale(module, avg_prequant_scale)

     if resmooth_only:
         return
```
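The docstring added to `fuse_prequant_layernorm` claims that scaling the LayerNorm weight and bias by the pre_quant_scale is equivalent to applying the scale after normalization. That identity can be checked directly with plain PyTorch (a standalone sketch; the tensor names are illustrative):

```python
# Verifying the layernorm-fusion identity from the docstring:
#   (normalization(x) * weight + bias) * scale
#     == normalization(x) * (weight * scale) + (bias * scale)
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)
weight = torch.rand(8) + 0.5
bias = torch.randn(8)
scale = torch.rand(8) + 0.5  # stands in for avg_pre_quant_scale

y_ref = F.layer_norm(x, (8,), weight, bias) * scale
y_fused = F.layer_norm(x, (8,), weight * scale, bias * scale)

assert torch.allclose(y_ref, y_fused, atol=1e-5)
```

This is why the commit also scales the bias when one is present: without the `fused_bias` term the identity only holds for bias-free norms such as RMSNorm.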

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -60,6 +60,7 @@
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
+    fuse_prequant_to_linear,
     get_activation_scaling_factor,
     get_quant_config,
     get_quantization_format,
@@ -107,6 +108,10 @@ def _output_hook(module, input, output):
     fused_linears = {}
     module_names = set()

+    # Fuse pre_quant_scale to the linear weights if possible
+    if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
+        fuse_prequant_to_linear(model)
+
     for name, module in model.named_modules():
         module_names.add(name)
```