Commit 6abded4

[5455919] Fix Q/DQ/Cast placement in 'FP32 required' custom ops (#554)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Fix incorrect quantization of custom ops when some input tensors are required to be in INT8 and some in FP32.

| Before fix | After fix |
|------------|-----------|
| <img width="841" height="623" alt="snap_custom_op_quant_incorrect" src="https://github.com/user-attachments/assets/88e4d460-fbae-4bcb-86c8-139d23ce04c8" /> | <img width="786" height="286" alt="snap_custom_op_quant_correct" src="https://github.com/user-attachments/assets/475079c2-a565-4f0d-b167-6d801ab83dfc" /> |

## Usage

```sh
$ python -m modelopt.onnx.quantization --onnx_path=$MODEL_PATH.onnx \
    --trt_plugins $PLUGIN_PATH.so \
    --trt_plugins_precision $CUSTOM_OP_NAME:$PRECISION
```

## Testing

### 1. BEVFormer model

- Follow step 1 in the [README](https://github.com/NVIDIA/DL4AGX/tree/master/AV-Solutions/bevformer-int8-eq#1-export-model-to-onnx-and-compile-plugins).
- In the quantization step, run:

```sh
$ python -m modelopt.onnx.quantization --onnx_path=/mnt/models/bevformer_tiny_epoch_24_cp2_op13.onnx \
    --trt_plugins=$PLUGIN_PATH \
    --trt_plugins_precision MultiScaleDeformableAttnTRT:[int8,int32,fp32,int8,int8]:[int8] \
    --high_precision_dtype fp16
```

> See the table in "Overview" for the expected graph structure.

### 2. 5455919 model

Validated the model in bug 5455919.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update the [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

- /pull/363: Feature expansion.
- /pull/524: The graph cleanup is actually needed after Q/DQ trimming around custom ops. Moved the cleanup lines to inside that function.

---------

Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
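As context for the flag format used above, here is a minimal, hypothetical sketch (not the modelopt implementation) of how a `--trt_plugins_precision` value such as `MultiScaleDeformableAttnTRT:[int8,int32,fp32,int8,int8]:[int8]` could map onto the per-op dictionary of 'FP32 required' I/O indices that the diffs below call `tensor_block_dict`. The helper name and the `"out"` key are assumptions; the diffs only consult `"inp"`.

```python
def parse_plugin_precision(spec: str) -> dict[str, dict[str, list[int]]]:
    """Hypothetical helper: map one OP:[inp,...]:[out,...] spec to FP32 I/O indices."""
    op_type, inp_str, out_str = spec.split(":")
    inp = inp_str.strip("[]").split(",")
    out = out_str.strip("[]").split(",")
    return {
        op_type: {
            # Indices of tensors that must stay in FP32.
            "inp": [i for i, p in enumerate(inp) if p == "fp32"],
            "out": [i for i, p in enumerate(out) if p == "fp32"],
        }
    }


print(parse_plugin_precision("MultiScaleDeformableAttnTRT:[int8,int32,fp32,int8,int8]:[int8]"))
# -> {'MultiScaleDeformableAttnTRT': {'inp': [2], 'out': []}}
```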
1 parent e20d218 · commit 6abded4

File tree

7 files changed (+77 / -41 lines changed)


CHANGELOG.rst

Lines changed: 2 additions & 0 deletions

@@ -20,12 +20,14 @@ Model Optimizer Changelog (Linux)
 **Bug Fixes**
 
 - Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice messing up the ordering.
+- Fix Q/DQ/Cast node placements in 'FP32 required' tensors in custom ops in the ONNX quantization workflow.
 
 **New Features**
 
 - Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
 - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
+- Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 
 
 0.39 (2025-11-11)

modelopt/onnx/autocast/convert.py

Lines changed: 3 additions & 1 deletion

@@ -194,6 +194,7 @@ def convert_to_f16(
     low_precision_type: str = "fp16",
     keep_io_types: bool = True,
     op_block_list: list[str] = [],
+    tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     trt_plugins: list[str] | None = [],
 ) -> onnx.ModelProto:
     """Convert model to mixed precision, using PrecisionConverter.
@@ -204,8 +205,8 @@ def convert_to_f16(
         model: ONNX model to convert.
         low_precision_type: Target precision to reduce to ('fp16' or 'bf16').
         keep_io_types: Whether to preserve input/output types.
-        disable_shape_infer: Whether to disable shape inference.
         op_block_list: List of operation types that should remain in FP32.
+        tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
         trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
     """
     assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"
@@ -235,6 +236,7 @@ def convert_to_f16(
         keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
         custom_ops=sanitizer.custom_ops,
+        tensor_block_dict=tensor_block_dict,
     )
     high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list]
     low_precision_nodes = [
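A hedged usage sketch of the extended `convert_to_f16` signature above; the model path is a placeholder, and the choice of input index 2 is illustrative (it matches the `fp32` entry in the BEVFormer example from the PR description):

```python
import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

model = onnx.load("model.onnx")  # placeholder path

# Convert the graph to FP16 overall, but keep input index 2 of the
# custom op in FP32 via the new tensor_block_dict argument.
model = convert_to_f16(
    model,
    low_precision_type="fp16",
    keep_io_types=True,
    tensor_block_dict={"MultiScaleDeformableAttnTRT": {"inp": [2]}},
)
onnx.save(model, "model_fp16.onnx")
```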

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 38 additions & 16 deletions

@@ -99,6 +99,7 @@ def __init__(
         min_opset: int = 13,
         max_ir_version: int | None = None,
         trt_plugins: list[str] | None = [],
+        tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     ) -> None:
         """Initialize PrecisionConverter.
 
@@ -112,6 +113,10 @@ def __init__(
             init_conversion_max_bytes: Maximum size in bytes for initializer conversion. Larger initializers will be
                 cast at runtime.
             custom_ops: List of custom ops.
+            min_opset: Minimum opset for conversion.
+            max_ir_version: Max IR version for conversion.
+            trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library).
+            tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
         """
         self.model = deepcopy(model)
         self.value_info_map = value_info_map
@@ -148,6 +153,9 @@ def __init__(
             )
         )
 
+        # Custom mapping of op types to indices of inputs that should not be converted to low precision
+        self.skip_inputs_map = self._create_skip_inputs_mapping(tensor_block_dict)
+
     def convert(
         self,
         high_precision_nodes: list[str],
@@ -211,7 +219,8 @@ def convert(
                 # For the low precision nodes that take a FP32 input, we don't exclude it from
                 # casting up so that the input can be converted to FP32 as expected.
                 exclude_consumers = list(
-                    set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
+                    set(low_precision_nodes)
+                    - {n.name for n in fp32_input_to_low_precision_node[tensor_name]}
                 )
                 self._add_cast(
                     tensor_name,
@@ -467,12 +476,14 @@
         return high_precision_nodes, low_precision_nodes
 
     def _get_tensors_to_cast(
-        self, low_precision_nodes: list[str]
-    ) -> tuple[list[str], list[str], dict[str, onnx.NodeProto]]:
+        self,
+        low_precision_nodes: list[str],
+        high_precision_tensors: dict[str, dict[str, list[int]]] = {},
+    ) -> tuple[list[str], list[str], dict[str, list[onnx.NodeProto]]]:
         cast_to_fp16 = []  # Tensors to cast down to FP16
         cast_to_fp32 = []  # Tensors to cast up to FP32
         # Keep track of the low precision nodes that take a FP32 input.
-        fp32_input_to_low_precision_node = {}
+        fp32_input_to_low_precision_node = defaultdict(list)
 
         # Get tensors for FP16 nodes
         for node in self.model.graph.node:
@@ -481,7 +492,7 @@ def _get_tensors_to_cast(
             for input in node.input:
                 if self._should_skip_low_precision_input_conversion(node, input):
                     cast_to_fp32.append(input)
-                    fp32_input_to_low_precision_node[input] = node
+                    fp32_input_to_low_precision_node[input].append(node)
                 else:
                     cast_to_fp16.append(input)
 
@@ -536,7 +547,7 @@ def _convert_initializers(
             low_precision_nodes: List of node names that should use low precision initializers.
             high_precision_nodes: List of node names that should use high precision initializers.
         """
-        # 1. Compute a mapping from initiailizers to high precision nodes & low precision nodes that use them.
+        # 1. Compute a mapping from initializers to high precision nodes & low precision nodes that use them.
         low_precision_nodes_set: set[str] = set(low_precision_nodes)
         high_precision_nodes_set: set[str] = set(high_precision_nodes)
         initializer_to_nodes: dict[str, InitializerConsumerTracker] = defaultdict(
@@ -888,7 +899,7 @@ def _add_cast(
             )
 
         if tensor_to_consumers is None:
-            utils.get_consumer_nodes(self.model, tensor_name)
+            consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
         else:
            consumer_nodes = tensor_to_consumers.get(tensor_name, [])
         consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers]
@@ -1272,13 +1283,9 @@ def _sanitize_model(self):
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model
 
-    def _should_skip_low_precision_input_conversion(
-        self, node: onnx.NodeProto, input_name: str
-    ) -> bool:
-        """Check if the input should be skipped for low precision conversion.
-
-        This is used for nodes that have inputs that MUST remain in FP32.
-        """
+    def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}):
+        """Create mapping of op types to indices of inputs that should not be converted to low precision."""
+        skip_inputs_map = {}
         match self.low_precision_type.str_short:
             case "fp16":
                 skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
@@ -1287,12 +1294,27 @@ def _should_skip_low_precision_input_conversion(
             case _:
                 raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")
 
-        if node.op_type in skip_inputs_map:
+        # Update mapping with user-defined information
+        for op, tensor_map in tensor_block_dict.items():
+            high_precision_tensor = tensor_map.get("inp", [])
+            if high_precision_tensor:
+                skip_inputs_map.update({op: set(high_precision_tensor)})
+
+        return skip_inputs_map
+
+    def _should_skip_low_precision_input_conversion(
+        self, node: onnx.NodeProto, input_name: str
+    ) -> bool:
+        """Check if the input should be skipped for low precision conversion.
+
+        This is used for nodes that have inputs that MUST remain in FP32.
+        """
+        if node.op_type in self.skip_inputs_map:
            # Figure out the index of the input in the node input
             inputs_lst = list(node.input)
             if input_name not in inputs_lst:
                 raise ValueError(f"Input {input_name} not found in node {node.name}.")
             input_index = inputs_lst.index(input_name)
             # Check if we should skip this input for low precision conversion
-            return input_index in skip_inputs_map[node.op_type]
+            return input_index in self.skip_inputs_map[node.op_type]
         return False
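The merge performed by `_create_skip_inputs_mapping` can be shown standalone. In this sketch the contents of `SKIP_LOW_PRECISION_MAPPING_FP16` are placeholders, and only the `"inp"` indices of `tensor_block_dict` are consulted, as in the diff:

```python
# Illustrative built-in rule; the real mapping lives in precisionconverter.py.
SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {1, 2}}


def create_skip_inputs_mapping(
    tensor_block_dict: dict[str, dict[str, list[int]]],
) -> dict[str, set[int]]:
    skip_inputs_map = dict(SKIP_LOW_PRECISION_MAPPING_FP16)
    # User-defined entries extend (or override) the built-in mapping.
    for op, tensor_map in tensor_block_dict.items():
        high_precision_tensor = tensor_map.get("inp", [])
        if high_precision_tensor:
            skip_inputs_map[op] = set(high_precision_tensor)
    return skip_inputs_map


print(create_skip_inputs_mapping({"MultiScaleDeformableAttnTRT": {"inp": [2]}}))
# -> {'Resize': {1, 2}, 'MultiScaleDeformableAttnTRT': {2}}
```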

modelopt/onnx/quantization/fp8.py

Lines changed: 2 additions & 0 deletions

@@ -169,6 +169,7 @@ def quantize(
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
     op_types_to_exclude_fp16: list[str] | None = None,
+    custom_ops_to_cast_fp32: dict | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -324,6 +325,7 @@ def quantize(
             onnx_model,
             keep_io_types=not direct_io_types,
             op_block_list=op_types_to_exclude_fp16 or [],
+            tensor_block_dict=custom_ops_to_cast_fp32 or {},
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )

modelopt/onnx/quantization/int8.py

Lines changed: 2 additions & 0 deletions

@@ -120,6 +120,7 @@ def quantize(
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
     op_types_to_exclude_fp16: list[str] | None = None,
+    custom_ops_to_cast_fp32: dict | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -285,6 +286,7 @@
             onnx_model,
             keep_io_types=not direct_io_types,
             op_block_list=op_types_to_exclude_fp16 or [],
+            tensor_block_dict=custom_ops_to_cast_fp32 or {},
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 29 additions & 14 deletions

@@ -872,22 +872,32 @@ def remove_input_dq_and_output_q(
         )
 
         # Only remove DQs from the inputs of custom ops
-        if consumers[0].op_type not in quantizable_custom_ops:
+        has_cast = consumers[0].op_type == "Cast"
+        consumers_2 = tensor_consumers[consumers[0].output[0]] if has_cast else consumers
+        if consumers_2[0].op_type not in quantizable_custom_ops:
             continue
 
-        # Rewire graph to connect Q with the node after DQ (skip DQ)
-        for consumer in consumers:
-            for cons_idx, cons_inp in enumerate(consumer.input):
-                if cons_inp == node.output[0]:
-                    # If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
-                    if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
-                        consumer.input[cons_idx] = q_node.output[0]
-                    else:
-                        q_node_prev = tensor_producers.get(q_node.input[0], None)
-                        consumer.input[cons_idx] = (
-                            q_node_prev.output[0] if q_node_prev else q_node.input[0]
-                        )
-                    break
+        if has_cast:
+            # Assume that this input tensor is not meant to be quantized as there's a Cast node between DQ
+            # and the custom op. Keep the Cast node and delete both Q/DQ nodes.
+            q_node_prev = tensor_producers.get(q_node.input[0], None)
+            consumers[0].input[0] = (
+                q_node_prev.output[0] if q_node_prev else q_node.input[0]
+            )
+        else:
+            # Rewire graph to connect Q with the node after DQ (skip DQ)
+            for consumer in consumers:
+                for cons_idx, cons_inp in enumerate(consumer.input):
+                    if cons_inp == node.output[0]:
+                        # If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
+                        if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
+                            consumer.input[cons_idx] = q_node.output[0]
+                        else:
+                            q_node_prev = tensor_producers.get(q_node.input[0], None)
+                            consumer.input[cons_idx] = (
+                                q_node_prev.output[0] if q_node_prev else q_node.input[0]
+                            )
+                        break
 
         # Track DequantizeLinear node indices for cleanup
         dq_indices.append(node_idx)
@@ -944,6 +954,11 @@
             f"{len(dq_indices)} DQ node{'' if len(dq_indices) == 1 else 's'}"
         )
 
+    # Cleanup graph to remove any dangling Q/DQ nodes
+    graph = gs.import_onnx(onnx_model)
+    graph.cleanup()
+    onnx_model = gs.export_onnx(graph)
+
     # TODO: remove manual ir_version change once ORT supports ir_version 11
     onnx_model.ir_version = 10

modelopt/onnx/quantization/quantize.py

Lines changed: 1 addition & 10 deletions

@@ -430,16 +430,6 @@ def quantize(
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
 
-    # Update list with op types to exclude from FP16/BF16 conversion
-    op_types_to_exclude_fp16 = list(
-        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
-    )
-    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
-        logger.warning(
-            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
-            "Since the model won't be converted to a lower precision, this flag is void."
-        )
-
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
@@ -485,6 +475,7 @@
         op_types_to_quantize=op_types_to_quantize,
         op_types_to_exclude=op_types_to_exclude,
         op_types_to_exclude_fp16=op_types_to_exclude_fp16,
+        custom_ops_to_cast_fp32=custom_ops_to_cast_fp32,
         nodes_to_quantize=nodes_to_quantize,
         nodes_to_exclude=nodes_to_exclude,
         use_external_data_format=use_external_data_format,
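Finally, the end-to-end flow from the Testing section, driven from Python rather than the CLI. This is a sketch under the assumption that the `quantize` entry point mirrors the CLI flags; paths are placeholders:

```python
from modelopt.onnx.quantization import quantize

# Assumed keyword names mirror the CLI flags shown in the PR description.
quantize(
    onnx_path="/mnt/models/bevformer_tiny_epoch_24_cp2_op13.onnx",
    trt_plugins=["/path/to/plugin.so"],  # placeholder plugin path
    trt_plugins_precision=["MultiScaleDeformableAttnTRT:[int8,int32,fp32,int8,int8]:[int8]"],
    high_precision_dtype="fp16",
)
```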
