Commit 0a4f0a8

[OMNIML-2244] Implement the ONNX quantization exporter for INT4 (#575)

## What does this PR do?

**Type of change:** New Feature

**Overview:**

- Created an abstract parent class for ONNXQuantExporter
- Created child classes for individual precisions
- Implemented the INT4QuantExporter
- Removed quantize_weights_to_int4
- Added a method to quantize the weights of the ONNX model to low precision

## Testing

```
python torch_quant_to_onnx.py --quantize_mode=int4_awq \
    --onnx_save_path=<onnx_path>
```

## Before your PR is "Ready for review"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update the [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>

1 parent 261858c · commit 0a4f0a8

File tree

10 files changed: +595 −182 lines

modelopt/onnx/export/__init__.py

Lines changed: 32 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ONNX export utilities."""

__all__ = [
    "FP8QuantExporter",
    "INT4QuantExporter",
    "INT8QuantExporter",
    "MXFP8QuantExporter",
    "NVFP4QuantExporter",
    "ONNXQuantExporter",
]

from .base_exporter import ONNXQuantExporter
from .fp8_exporter import FP8QuantExporter
from .int4_exporter import INT4QuantExporter
from .int8_exporter import INT8QuantExporter
from .mxfp8_exporter import MXFP8QuantExporter
from .nvfp4_exporter import NVFP4QuantExporter
```
modelopt/onnx/export/base_exporter.py

Lines changed: 53 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for ONNX quantizer exporters."""

from abc import ABC, abstractmethod

import onnx


class ONNXQuantExporter(ABC):
    """Base class for ONNX quantizer exporters."""

    @classmethod
    def process_model(cls, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Processes the ONNX model."""
        onnx_model = cls.pre_process(onnx_model)
        onnx_model = cls.compute_scales(onnx_model)
        onnx_model = cls.compress_weights(onnx_model)
        onnx_model = cls.post_process(onnx_model)
        return onnx_model

    @staticmethod
    @abstractmethod
    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Pre-processes the ONNX model. Converts all DQ -> * -> op patterns to DQ -> op."""

    @staticmethod
    @abstractmethod
    def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Computes the scales for the weights in the ONNX model."""

    @staticmethod
    @abstractmethod
    def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Compresses the weights in the ONNX model."""

    @staticmethod
    @abstractmethod
    def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Post-processes the ONNX model."""
```
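The base class's `process_model` is a template method: it fixes the pipeline order (pre-process, compute scales, compress weights, post-process) while each precision-specific subclass supplies the stages. The sketch below mirrors that flow with a plain string standing in for `onnx.ModelProto`, so it runs without `onnx` installed; `TracingExporter` is a hypothetical subclass used only to make the stage ordering visible.

```python
from abc import ABC, abstractmethod


class QuantExporterSketch(ABC):
    """Structural stand-in for ONNXQuantExporter's template-method flow."""

    @classmethod
    def process_model(cls, model):
        # Same fixed stage order as ONNXQuantExporter.process_model.
        model = cls.pre_process(model)
        model = cls.compute_scales(model)
        model = cls.compress_weights(model)
        model = cls.post_process(model)
        return model

    @staticmethod
    @abstractmethod
    def pre_process(model): ...

    @staticmethod
    @abstractmethod
    def compute_scales(model): ...

    @staticmethod
    @abstractmethod
    def compress_weights(model): ...

    @staticmethod
    @abstractmethod
    def post_process(model): ...


class TracingExporter(QuantExporterSketch):
    """Hypothetical subclass: each stage just tags the 'model' string."""

    @staticmethod
    def pre_process(model):
        return model + " -> pre"

    @staticmethod
    def compute_scales(model):
        return model + " -> scales"

    @staticmethod
    def compress_weights(model):
        return model + " -> compress"

    @staticmethod
    def post_process(model):
        return model + " -> post"


print(TracingExporter.process_model("model"))
# prints: model -> pre -> scales -> compress -> post
```

Because every stage takes and returns the model, subclasses stay stateless and the base class never needs to know which precision it is driving.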
modelopt/onnx/export/fp8_exporter.py

Lines changed: 41 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FP8 quantization exporter."""

import onnx

from .base_exporter import ONNXQuantExporter


# TODO: Implement the FP8QuantExporter
class FP8QuantExporter(ONNXQuantExporter):
    """Exporter for FP8 quantization."""

    @staticmethod
    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Pre-processes the ONNX model for FP8 quantization."""

    @staticmethod
    def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Computes the scales for the weights in the ONNX model for FP8 quantization."""

    @staticmethod
    def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Compresses the weights in the ONNX model for FP8 quantization."""

    @staticmethod
    def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Post-processes the ONNX model for FP8 quantization."""
```
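One detail worth noting about this stub: each FP8 method has a docstring but no `return` statement, so each call evaluates to `None`, and the base class's pipeline would then thread `None` through every later stage. The pure-Python sketch below (hypothetical `StubExporter`, no `onnx` dependency) illustrates why the TODO has to be resolved before this exporter is usable.

```python
class StubExporter:
    """Hypothetical stand-in showing how unimplemented stages behave."""

    @classmethod
    def process_model(cls, model):
        # Two stages are enough to show the failure mode.
        model = cls.pre_process(model)
        model = cls.compute_scales(model)
        return model

    @staticmethod
    def pre_process(model):
        """Stub: docstring only, no return, so the call yields None."""

    @staticmethod
    def compute_scales(model):
        """Stub: passes the (already None) model straight through."""
        return model


print(StubExporter.process_model("model"))
# prints: None
```

Once the TODO is implemented, each stage should return the (possibly modified) `onnx.ModelProto`, matching the base class's contract.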

