Commit 93f5bbf

[OMNIML-3015] Add per tensor/per channel MSE calibrator (#540)
## What does this PR do?

**Type of change:** new feature

**Overview:** Add per tensor/per channel MSE calibrator.

## Usage

Can be enabled with the "algorithm" field in quantization configs.

```
"algorithm": {"method": "mse", "num_steps": 20, "stop_multiplier": 8.0},
```

## Testing

Unit test for the MseCalibrator, and E2E tests with NVFP4 and INT8.

**Results** (start_multiplier=0.25, stop_multiplier=4.0, num_steps=20)

**Qwen3-8B MMLU** (BF16 baseline: 72.94)

| Calib Algo | NVFP4 | FP8 | INT8 |
| ------ | ------ | ------ | ------ |
| MSE | 70.88 | 72.65 | 55.46 |
| MAX | 70.83 | 72.7 | 24.52 |

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

TODO for the follow-up PR:

- [ ] TP sync for HF models
- [ ] Calculate weight quantizer only once

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added MSE-based quantization calibration supporting per-tensor and per-channel optimization with configurable parameters (step count, multiplier ranges).
* **Tests**
  * Added comprehensive test coverage for MSE calibration functionality.
* **Documentation**
  * Updated changelog to reflect MSE calibrator support.

---------

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
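As a rough illustration of the usage note in the description, the sketch below swaps the `"algorithm"` entry of a config dictionary for the MSE settings reported in the testing section. `BASE_CFG` and `with_mse_algorithm` are hypothetical names made up for this example, and the `quant_cfg` contents are placeholders rather than a config shipped by the library.

```python
import copy

# Hypothetical stand-in for a quantization config (placeholder contents).
BASE_CFG = {
    "quant_cfg": {"*weight_quantizer": {"num_bits": 8, "axis": 0}},
    "algorithm": "max",
}


def with_mse_algorithm(cfg, num_steps=20, start_multiplier=0.25, stop_multiplier=4.0):
    """Return a copy of cfg whose "algorithm" field selects MSE calibration."""
    new_cfg = copy.deepcopy(cfg)  # leave the caller's config untouched
    new_cfg["algorithm"] = {
        "method": "mse",
        "num_steps": num_steps,
        "start_multiplier": start_multiplier,
        "stop_multiplier": stop_multiplier,
    }
    return new_cfg


# Values taken from the testing section of the description.
mse_cfg = with_mse_algorithm(BASE_CFG)
print(mse_cfg["algorithm"])
```

The resulting dictionary would then be handed to whatever quantization entry point normally consumes the config; that call is omitted here because it depends on the model and calibration loop.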
1 parent 01e24fd, commit 93f5bbf

10 files changed: +950 -25 lines changed

‎CHANGELOG.rst‎

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ Model Optimizer Changelog (Linux)
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 - Add support for PyTorch Geometric quantization.
+- Add per tensor and per channel MSE calibrator support.
 
 **Documentation**
 

‎modelopt/torch/quantization/calib/__init__.py‎

Lines changed: 1 addition & 0 deletions
@@ -23,3 +23,4 @@
 from .calibrator import *
 from .histogram import *
 from .max import *
+from .mse import *

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Calibrator that returns the MSE amax of all collected tensors."""
+
+from collections.abc import Callable
+
+import torch
+import torch.nn.functional as F
+
+from .. import utils as quant_utils
+from .calibrator import _Calibrator
+
+__all__ = ["MseCalibrator"]
+
+
+class MseCalibrator(_Calibrator):
+    """Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x."""
+
+    def __init__(
+        self,
+        amax: torch.Tensor,
+        axis: int | tuple | list | None = None,
+        num_steps: int = 10,
+        start_multiplier: float = 0.25,
+        stop_multiplier: float = 4.0,
+        quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+    ):
+        """Initialize MSE calibrator.
+
+        Args:
+            amax: Initial amax value (required).
+            axis: Quantization axis. None means per-tensor quantization.
+            num_steps: Number of amax candidates to try.
+            start_multiplier: Starting multiplier for amax search.
+            stop_multiplier: Ending multiplier for amax search.
+            quant_func: Function that quantizes input tensor given an amax value.
+                Should have signature: quant_func(x, amax) -> quantized_x.
+            error_func: Function to compute error between x and xq.
+                Default is F.mse_loss(x, xq, reduction='none').
+        """
+        super().__init__(num_bits=None, axis=axis, unsigned=None)
+        self._initial_amax = amax
+        self._num_steps = num_steps
+        self._start_multiplier = start_multiplier
+        self._stop_multiplier = stop_multiplier
+        self._quant_func = quant_func
+        self._error_func = error_func
+        self._losses_sum = [None] * num_steps
+        self._candidate_amaxs = [None] * num_steps
+
+        self._amax = None
+
+    @torch.no_grad()
+    def collect(self, x: torch.Tensor):
+        """Collect input tensor statistics and compute losses for MSE calibration.
+
+        Args:
+            x: Input tensor.
+        """
+        if self._quant_func is None:
+            raise RuntimeError(
+                "Quantization function not set. Msecalibrator requires a quant_func to be provided."
+            )
+
+        x = x.detach().to(dtype=torch.float32)
+
+        device = x.device
+        multipliers = torch.linspace(
+            self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
+        )
+
+        # Get reduce axis for per-channel quantization
+        reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)
+
+        for step, multiplier in enumerate(multipliers):
+            candidate_amax = self._initial_amax * multiplier
+            xq = self._quant_func(x, candidate_amax)
+
+            if self._error_func is not None:
+                error = self._error_func(x, xq)
+            else:
+                error = F.mse_loss(x, xq, reduction="none")
+
+            loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)
+
+            if self._candidate_amaxs[step] is None:
+                self._candidate_amaxs[step] = candidate_amax
+
+            if self._losses_sum[step] is None:
+                self._losses_sum[step] = loss.clone()
+            else:
+                self._losses_sum[step] += loss
+
+    def reset(self):
+        """Reset the stored losses and amax value."""
+        self._losses_sum = [None] * self._num_steps
+        self._candidate_amaxs = [None] * self._num_steps
+        self._amax = None
+
+    @torch.no_grad()
+    def compute_amax(self, verbose: bool = False):
+        """Return the amax value that minimizes quantization error.
+
+        Args:
+            verbose: If True, print the ratio of best_amax to initial_amax.
+        """
+        if not any(loss_sum is not None for loss_sum in self._losses_sum):
+            return None
+
+        # Check if this is per-tensor or per-channel based on the first loss
+        first_loss_sum = None
+        for loss_sum in self._losses_sum:
+            if loss_sum is not None:
+                first_loss_sum = loss_sum
+                break
+
+        if first_loss_sum is None:
+            return None
+
+        # Collect losses for all steps
+        losses_per_step = []
+        for step in range(self._num_steps):
+            if self._losses_sum[step] is not None:
+                losses_per_step.append(self._losses_sum[step])
+            # No data for this step, use inf
+            elif first_loss_sum.ndim == 0:
+                losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device))
+            else:
+                losses_per_step.append(torch.full_like(first_loss_sum, float("inf")))
+
+        # Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel
+        losses_per_step = torch.stack(losses_per_step)
+
+        # Find best step(s): scalar for per-tensor, [num_channels] for per-channel
+        best_steps = torch.argmin(losses_per_step, dim=0)
+
+        # Stack candidate amaxs and select based on best_steps
+        candidate_amaxs = torch.stack(self._candidate_amaxs)
+
+        if first_loss_sum.ndim == 0:
+            # Per-tensor case: best_steps is a scalar
+            self._amax = self._candidate_amaxs[best_steps.item()]
+        else:
+            # Per-channel case: best_steps is a tensor
+            num_channels = best_steps.shape[0]
+            self._amax = candidate_amaxs[
+                best_steps, torch.arange(num_channels, device=best_steps.device)
+            ]
+            self._amax = self._amax.reshape(self._initial_amax.shape)
+
+        if verbose:
+            ratio = self._amax / self._initial_amax
+            if ratio.ndim == 0:
+                print(f"MSE Calibrator: best_amax/initial_amax ratio = {ratio.item():.4f}")
+            else:
+                print(
+                    f"MSE Calibrator: best_amax/initial_amax ratio - "
+                    f"mean: {ratio.mean().item():.4f}, "
+                    f"min: {ratio.min().item():.4f}, "
+                    f"max: {ratio.max().item():.4f}"
+                )
+
+        return self._amax
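
The collect-then-select flow of `MseCalibrator` can be sketched without PyTorch. The version below is a simplified pure-Python illustration, not the code from this commit: `fake_quant`, `linspace`, `mse_search`, and `mse_search_per_channel` are hypothetical helpers, and the quantizer is assumed to be a symmetric integer quantize-dequantize with `q_max = 127`. As in the diff, the loss is a sum of squared errors accumulated over batches, and the candidate with the lowest total wins.

```python
def fake_quant(x, amax, q_max=127):
    """Symmetric uniform quantize-dequantize of a list of floats (assumed INT8-style)."""
    scale = amax / q_max
    out = []
    for v in x:
        q = round(v / scale)
        q = max(-q_max, min(q_max, q))  # clip to the representable integer range
        out.append(q * scale)
    return out


def linspace(start, stop, steps):
    """Evenly spaced values including both endpoints."""
    if steps == 1:
        return [start]
    delta = (stop - start) / (steps - 1)
    return [start + i * delta for i in range(steps)]


def mse_search(batches, initial_amax, num_steps=10, start=0.25, stop=4.0):
    """Per-tensor search: return (best_amax, accumulated_losses)."""
    candidates = [initial_amax * m for m in linspace(start, stop, num_steps)]
    losses = [0.0] * num_steps
    for x in batches:  # mirrors collect(): accumulate summed squared error per candidate
        for i, amax in enumerate(candidates):
            xq = fake_quant(x, amax)
            losses[i] += sum((a - b) ** 2 for a, b in zip(x, xq))
    best = min(range(num_steps), key=losses.__getitem__)  # mirrors compute_amax(): argmin
    return candidates[best], losses


def mse_search_per_channel(batches, initial_amaxs, **kwargs):
    """Per-channel search: each batch is a list of channel rows, searched independently."""
    return [
        mse_search([batch[c] for batch in batches], amax, **kwargs)[0]
        for c, amax in enumerate(initial_amaxs)
    ]


batches = [[0.1, -0.2, 0.3, 8.0], [0.05, 0.4, -0.1, -0.3]]
best_amax, losses = mse_search(batches, initial_amax=8.0, num_steps=16)

channel_batches = [[[0.1, -0.2, 0.3], [2.0, -4.0, 1.0]], [[0.2, 0.1, -0.3], [3.0, -1.0, 0.5]]]
per_channel = mse_search_per_channel(channel_batches, [0.3, 4.0], num_steps=16)
print(best_amax, per_channel)
```

Searching each channel independently is equivalent to the stacked `argmin` over `[num_steps, num_channels]` in the diff, just less efficient; the tensor version evaluates all channels for a candidate in one pass.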

‎modelopt/torch/quantization/config.py‎

Lines changed: 39 additions & 0 deletions
@@ -981,6 +981,45 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
     )
 
 
+class MseCalibConfig(QuantizeAlgorithmConfig):
+    """Configuration for per-tensor MSE calibration.
+
+    Finds a scale s (via amax a, with s = a / q_max) that minimizes the
+    reconstruction error of a tensor after uniform Q→DQ:
+
+        s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}
+    """
+
+    method: Literal["mse"] = ModeloptField("mse")
+
+    num_steps: int | None = ModeloptField(
+        default=10,
+        ge=1,
+        title="Number of amax candidates to try.",
+        description="Number of amax candidates to search over for MSE minimization.",
+    )
+
+    start_multiplier: float | None = ModeloptField(
+        default=0.25,
+        gt=0.0,
+        title="Starting multiplier for amax search.",
+        description="Starting multiplier for amax search range (multiplies initial amax).",
+    )
+
+    stop_multiplier: float | None = ModeloptField(
+        default=4.0,
+        gt=0.0,
+        title="Ending multiplier for amax search.",
+        description="Ending multiplier for amax search range (multiplies initial amax).",
+    )
+
+    distributed_sync: bool | None = ModeloptField(
+        default=True,
+        title="Whether to sync the amax across the distributed processes.",
+        description="If True, the amax will be synced across the distributed processes.",
+    )
+
+
 class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
     """The config for ``smoothquant`` algorithm (SmoothQuant).
 
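
The objective in the `MseCalibConfig` docstring can be spelled out against the three search parameters. The formulation below is a sketch under assumptions: the round-and-clip form of Q applies to a symmetric integer grid only (it is not stated in the diff and would differ for FP8 or NVFP4), and the symbols $a_0$ (initial amax), $K$ (`num_steps`, taken as greater than 1), $m_{\text{start}}$ and $m_{\text{stop}}$ are introduced here for illustration.

```latex
\begin{aligned}
s &= \frac{a}{q_{\max}}, \\
Q(x; s) &= \operatorname{clip}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right),\, -q_{\max},\, q_{\max}\right),
\qquad DQ(q) = s\,q, \\
a_k &= m_k\, a_0, \qquad
m_k = m_{\text{start}} + k\,\frac{m_{\text{stop}} - m_{\text{start}}}{K - 1},
\quad k = 0, \dots, K-1, \\
k^{*} &= \arg\min_{k} \sum_{x \in X} \bigl(x - DQ(Q(x;\, a_k / q_{\max}))\bigr)^2,
\qquad a^{*} = a_{k^{*}}.
\end{aligned}
```

The sum over collected samples differs from the expectation in the docstring only by a constant factor, so both have the same minimizer. With the defaults ($K = 10$, $m_{\text{start}} = 0.25$, $m_{\text{stop}} = 4.0$) the grid spacing is $3.75 / 9 \approx 0.417$, so the multipliers are $0.25, 0.667, 1.083, \dots$ and the initial amax itself ($m = 1$) is not one of the candidates; a grid of 16 steps has spacing $0.25$ and does include it.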

‎modelopt/torch/quantization/mode.py‎

Lines changed: 14 additions & 1 deletion
@@ -38,6 +38,7 @@
     AWQLiteCalibConfig,
     CompressConfig,
     MaxCalibConfig,
+    MseCalibConfig,
     QuantizeAlgoCfgType,
     QuantizeAlgorithmConfig,
     QuantizeConfig,
@@ -54,7 +55,7 @@
     restore_svdquant_model,
     update_quantize_metadata,
 )
-from .model_calib import awq, max_calibrate, smoothquant, svdquant
+from .model_calib import awq, max_calibrate, mse_calibrate, smoothquant, svdquant
 
 __all__ = ["BaseCalibrateModeDescriptor"]
 
@@ -363,6 +364,18 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
     _calib_func = max_calibrate
 
 
+@CalibrateModeRegistry.register_mode
+class MseCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
+    """Mode for mse calibration algorithm."""
+
+    @property
+    def config_class(self) -> type[QuantizeAlgorithmConfig]:
+        """Specifies the config class for the mode."""
+        return MseCalibConfig
+
+    _calib_func = mse_calibrate
+
+
 @CalibrateModeRegistry.register_mode
 class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
     """Mode for smoothquant calibration algorithm."""

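The descriptor above follows a decorator-based registration pattern: a class is registered under a mode name and carries the calibration function to dispatch to. The sketch below illustrates that general pattern only; it does not reproduce how `CalibrateModeRegistry` works internally, and `ModeRegistry`, `MseMode`, `mse_calibrate_stub`, and `calibrate` are all hypothetical names.

```python
class ModeRegistry:
    """Hypothetical registry keyed by the algorithm's method name."""

    def __init__(self):
        self._modes = {}

    def register_mode(self, cls):
        # Used as a class decorator: store the class and return it unchanged.
        self._modes[cls.method] = cls
        return cls

    def get(self, method):
        return self._modes[method]


registry = ModeRegistry()


def mse_calibrate_stub(model, **kwargs):
    """Placeholder for a calibration routine; just echoes its inputs."""
    return (model, kwargs)


@registry.register_mode
class MseMode:
    """Hypothetical descriptor pairing a method name with its calibration function."""

    method = "mse"
    _calib_func = staticmethod(mse_calibrate_stub)


def calibrate(model, algorithm):
    """Look up the mode named by algorithm["method"] and forward the other fields."""
    params = dict(algorithm)
    mode = registry.get(params.pop("method"))
    return mode._calib_func(model, **params)


result = calibrate("model", {"method": "mse", "num_steps": 20, "stop_multiplier": 8.0})
print(result)
```

Keeping the lookup keyed on the `method` string is what lets a config entry such as `{"method": "mse", ...}` select the calibration routine without the caller importing it directly.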