Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commite20d218

Browse files
authored
[OMNIML-2857] [Experimental] Support the DeepSeek V3.2 model (#435)
## What does this PR do?**Type of change:** ? New model support**Overview:** ?## UsagePlease see examples/deepseek/README.md<!-- This is an auto-generated comment: release notes by coderabbit.ai-->## Summary by CodeRabbit* New Features* Support for DeepSeek V3.2 quantization and automatic detection ofavailable DeepSeek versions.* Triton-backed weight dequantization utility and MoE-aware calibrationmode to improve calibration fidelity.* Documentation* DeepSeek examples README expanded with setup, conversion, calibration,and FP8→FP4 quantization workflows for R1, V3, and V3.2.* Bug Fixes* More robust, failure-tolerant copying of auxiliary files/assets duringquantization.* Chores * Updated changelog and lint/ignore rules for example artifacts.<!-- end of auto-generated comment: release notes by coderabbit.ai -->---------Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
1 parent9cd0824 commite20d218

File tree

8 files changed

+203
-23
lines changed

8 files changed

+203
-23
lines changed

‎CHANGELOG.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Model Optimizer Changelog (Linux)
2727
- Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md<https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
2828
- Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
2929

30+
3031
0.39 (2025-11-11)
3132
^^^^^^^^^^^^^^^^^
3233

‎examples/deepseek/.gitignore‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
DeepSeek-V3/
2+
DeepSeek-V3.2-Exp/

‎examples/deepseek/README.md‎

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,69 @@
1-
#Quantize DeepseekR1 to FP4
1+
#Quantize Deepseekmodels to FP4
22

3-
This example will demonstrate the steps to quantize DeepSeekR1 model to FP4 and export a unified checkpoint that can be deployed with TRT-LLM.
3+
This example will demonstrate the steps to quantize DeepSeekmodels to FP4 and export a unified checkpoint that can be deployed with TRT-LLM.
44

55
##Setup
66

77
Due to the model size, currently it requires 8xH200 or 16xH100 to quantize the FP8 model, we will use 8xH200 as example.
88

9-
###Convert the HF checkpoint for deepseek FP8 inference
9+
##Convert the HF checkpoint for deepseek FP8 inference
1010

1111
```bash
1212
# set up variables to run the example
1313
export HF_FP8_CKPT={path_to_downloaded_hf_checkpoint}
1414
export DS_CKPT={path_to_save_converted_checkpoint}
1515
export FP4_QUANT_PATH={path_to_save_quantization_results}
1616
export HF_FP4_PATH={path_to_save_the_final_FP4_checkpoint}
17+
```
18+
19+
###DeepSeek V3 R1 V3.1
1720

18-
# download the FP8 checkpoint from Hugginface
21+
```bash
22+
# download the FP8 checkpoint from Hugginface. This is an example of DeepSeek-R1
1923
huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir$HF_FP8_CKPT
2024

2125
# clone DeepSeek-V3 (base model of R1) Github repository for FP8 inference,
2226
git clone https://github.com/deepseek-ai/DeepSeek-V3.git&&cd DeepSeek-V3&& git checkout 1398800
27+
```
28+
29+
###[Experimental] DeepSeek V3.2
2330

31+
```bash
32+
# download the FP8 checkpoint from Hugginface.
33+
huggingface-cli download deepseek-ai/DeepSeek-V3.2-Exp --local-dir$HF_FP8_CKPT
34+
35+
# clone DeepSeek-V3.2 Github repository for FP8 inference,
36+
git clone https://github.com/deepseek-ai/DeepSeek-V3.2-Exp.git&&cd DeepSeek-V3.2-Exp&& git checkout 3b99a53
37+
38+
# Install requirements
39+
pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
40+
pip install -r inference/requirements.txt
41+
```
42+
43+
###Convert the Checkpoint
44+
45+
```bash
2446
# convert the HF checkpoint to a specific format for Deepseek
2547
python inference/convert.py --hf-ckpt-path$HF_FP8_CKPT --save-path$DS_CKPT --n-experts 256 --model-parallel 8
2648
```
2749

28-
###Post-training quantization
50+
##Post-training quantization
51+
52+
###Run the calibration scripts
2953

30-
####Run the calibration scripts
54+
DeepSeek V3, R1, V3.1
3155

3256
```bash
3357
torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path$DS_CKPT --config DeepSeek-V3/inference/configs/config_671B.json --quant_cfg NVFP4_DEFAULT_CFG --output_path$FP4_QUANT_PATH
3458
```
3559

36-
####Quantize the FP8 hf checkpoint to FP4
60+
DeepSeek V3.2
61+
62+
```bash
63+
torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path$DS_CKPT --config DeepSeek-V3.2-Exp/inference/config_671B_v3.2.json --quant_cfg NVFP4_DEFAULT_CFG --output_path$FP4_QUANT_PATH
64+
```
65+
66+
###Quantize the FP8 hf checkpoint to FP4
3767

3868
We provide a one-step-script which will:
3969

‎examples/deepseek/ds_kernel.py‎

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# MIT License
17+
18+
# Copyright (c) 2023 DeepSeek
19+
20+
# Permission is hereby granted, free of charge, to any person obtaining a copy
21+
# of this software and associated documentation files (the "Software"), to deal
22+
# in the Software without restriction, including without limitation the rights
23+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
24+
# copies of the Software, and to permit persons to whom the Software is
25+
# furnished to do so, subject to the following conditions:
26+
27+
# The above copyright notice and this permission notice shall be included in all
28+
# copies or substantial portions of the Software.
29+
30+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
31+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
32+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
33+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
34+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
35+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
36+
# SOFTWARE.
37+
38+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
39+
# SPDX-License-Identifier: Apache-2.0
40+
#
41+
# Licensed under the Apache License, Version 2.0 (the "License");
42+
# you may not use this file except in compliance with the License.
43+
# You may obtain a copy of the License at
44+
#
45+
# http://www.apache.org/licenses/LICENSE-2.0
46+
#
47+
# Unless required by applicable law or agreed to in writing, software
48+
# distributed under the License is distributed on an "AS IS" BASIS,
49+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50+
# See the License for the specific language governing permissions and
51+
# limitations under the License.
52+
53+
importtorch
54+
importtriton
55+
importtriton.languageastl
56+
57+
"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py"""
58+
59+
60+
@triton.jit
61+
defweight_dequant_kernel(x_ptr,s_ptr,y_ptr,M,N,BLOCK_SIZE:tl.constexpr):
62+
"""
63+
Dequantizes weights using the provided scaling factors and stores the result.
64+
65+
Args:
66+
x_ptr (tl.pointer): Pointer to the quantized weights.
67+
s_ptr (tl.pointer): Pointer to the scaling factors.
68+
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
69+
M (int): Number of rows in the weight matrix.
70+
N (int): Number of columns in the weight matrix.
71+
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
72+
73+
Returns:
74+
None
75+
"""
76+
pid_m=tl.program_id(axis=0)
77+
pid_n=tl.program_id(axis=1)
78+
n=tl.cdiv(N,BLOCK_SIZE)
79+
offs_m=pid_m*BLOCK_SIZE+tl.arange(0,BLOCK_SIZE)
80+
offs_n=pid_n*BLOCK_SIZE+tl.arange(0,BLOCK_SIZE)
81+
offs=offs_m[:,None]*N+offs_n[None, :]
82+
mask= (offs_m[:,None]<M)& (offs_n[None, :]<N)
83+
x=tl.load(x_ptr+offs,mask=mask).to(tl.float32)
84+
s=tl.load(s_ptr+pid_m*n+pid_n)
85+
y=x*s
86+
tl.store(y_ptr+offs,y,mask=mask)
87+
88+
89+
defweight_dequant(x:torch.Tensor,s:torch.Tensor,block_size:int=128)->torch.Tensor:
90+
"""
91+
Dequantizes the given weight tensor using the provided scale tensor.
92+
93+
Args:
94+
x (torch.Tensor): The quantized weight tensor of shape (M, N).
95+
s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
96+
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
97+
98+
Returns:
99+
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
100+
101+
Raises:
102+
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
103+
"""
104+
assertx.is_contiguous()ands.is_contiguous(),"Input tensors must be contiguous"
105+
assertx.dim()==2ands.dim()==2,"Input tensors must have 2 dimensions"
106+
M,N=x.size()
107+
y=torch.empty_like(x,dtype=torch.get_default_dtype())
108+
grid=lambdameta: (triton.cdiv(M,meta["BLOCK_SIZE"]),triton.cdiv(N,meta["BLOCK_SIZE"]))
109+
weight_dequant_kernel[grid](x,s,y,M,N,BLOCK_SIZE=block_size)
110+
returny

‎examples/deepseek/ptq.py‎

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,21 @@
6464
frommodelopt.torch.utils.dataset_utilsimportget_dataset_dataloader
6565
frommodelopt.torch.utils.distributedimportParallelState
6666

67-
sys.path.append(str(Path(__file__).resolve().parent/"DeepSeek-V3/inference"))
68-
importmodelasdeekseep_model
69-
fromkernelimportact_quant,fp8_gemm,weight_dequant
67+
DS_V3_PATH=Path(__file__).resolve().parent/"DeepSeek-V3/inference"
68+
DS_V3_2_PATH=Path(__file__).resolve().parent/"DeepSeek-V3.2-Exp/inference"
69+
70+
ifDS_V3_2_PATH.exists():
71+
sys.path.append(str(DS_V3_2_PATH))
72+
elifDS_V3_PATH.exists():
73+
sys.path.append(str(DS_V3_PATH))
74+
else:
75+
raiseValueError(
76+
f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in{Path(__file__).resolve().parent}"
77+
)
78+
79+
importmodelasdeekseep_model# noqa: E402
80+
fromds_kernelimportweight_dequant# noqa: E402
81+
fromkernelimportact_quant,fp8_gemm# noqa: E402
7082

7183

7284
defmonkey_patch_deepseek_model():
@@ -186,6 +198,26 @@ def _setup(self):
186198
self.kv_bmm_quantizer=TensorQuantizer()
187199
self.pe_bmm_quantizer=TensorQuantizer()
188200

201+
classCalibMoe(deekseep_model.MoE):
202+
def__init__(self,*args,**kwargs):
203+
super().__init__(*args,**kwargs)
204+
self._setup()
205+
206+
def_setup(self):
207+
self._original_topk=self.gate.topk
208+
self._original_topk_groups=self.gate.topk_groups
209+
210+
defforward(self,x:torch.Tensor)->torch.Tensor:
211+
# Forward all tokens to all experts for calibration
212+
self.gate.topk=self.n_routed_experts
213+
self.gate.topk_groups=self.gate.n_groups
214+
super().forward(x)
215+
# Restore the original topk and topk_groups
216+
self.gate.topk=self._original_topk
217+
self.gate.topk_groups=self._original_topk_groups
218+
219+
returnsuper().forward(x)
220+
189221
mtq.register(
190222
original_cls=deekseep_model.RowParallelLinear,
191223
quantized_cls=QuantRowParallelLinear,
@@ -196,6 +228,7 @@ def _setup(self):
196228
)
197229
mtq.register(original_cls=deekseep_model.Linear,quantized_cls=QuantLinear)
198230
mtq.register(original_cls=deekseep_model.MLA,quantized_cls=QuantMLA)
231+
mtq.register(original_cls=deekseep_model.MoE,quantized_cls=CalibMoe)
199232

200233

201234
defload_deepseek_model(model_config:str,model_path:str,batch_size:int):
@@ -243,10 +276,10 @@ def ptq(
243276
## create dataset
244277
device=next(model.parameters()).device
245278
calib_dataset=get_dataset_dataloader(
246-
dataset_name="cnn_dailymail",
279+
dataset_name=["cnn_dailymail","nemotron-post-training-dataset-v2"],
247280
tokenizer=tokenizer,
248281
batch_size=batch_size,
249-
num_samples=calib_size,
282+
num_samples=[calib_size,calib_size],
250283
device=device,
251284
)
252285

@@ -307,6 +340,13 @@ def state_dict_filter(state_dict):
307340
os.path.join(output_path,f"amax_dict_rank{rank}-mp{world_size}.pt"),
308341
)
309342

343+
# if rank == 0:
344+
# with open("expert_activation_counts.txt", "w") as f:
345+
# for name, module in model.named_modules():
346+
# if isinstance(module, deekseep_model.MoE):
347+
# counts = module.activated_expert_counts()
348+
# f.writelines(f"{name}: {count}\n" for count in counts)
349+
310350
quant_config=get_quant_config(model.named_modules())
311351

312352
ifenable_fp8_kvcache:

‎examples/deepseek/quantize_fp8_to_nvfp4.sh‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ fi
7878

7979
# Copy miscellaneous files to the quantized checkpoint
8080
mkdir -p$FP4_PATH
81-
cp$FP8_HF_PATH/*.json$FP8_HF_PATH/*.py$FP4_PATH/
81+
cp$FP8_HF_PATH/*.json$FP4_PATH/
82+
cp$FP8_HF_PATH/*.py$FP4_PATH/||true
83+
cp -r$FP8_HF_PATH/assets$FP4_PATH/||true
8284

8385
# Run the quantization command
8486
echo"Running quantization..."

‎examples/deepseek/quantize_to_nvfp4.py‎

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,15 @@
4141
importglob
4242
importjson
4343
importos
44-
importsys
45-
frompathlibimportPath
4644
fromtypingimportAny
4745

4846
importtorch
47+
fromds_kernelimportweight_dequant
4948
fromsafetensors.torchimportload_file,save_file
5049
fromtqdmimporttqdm
5150

5251
frommodelopt.torch.quantization.qtensorimportNVFP4QTensor
5352

54-
sys.path.append(str(Path(__file__).resolve().parent/"DeepSeek-V3/inference"))
55-
fromkernelimportweight_dequant
56-
5753

5854
def_remap_key(key_dict:dict[str,Any]):
5955
# renaming the module to match HF modeling
@@ -155,7 +151,7 @@ def convert_fp8_ckpt_to_nvfp4(
155151
per_layer_quant_config,
156152
):
157153
defamax_to_nvfp4_scaling_factor_2(amax):
158-
returnamax.float()/6.0/448.0
154+
returnamax.float()/(6.0*448.0)
159155

160156
defamax_to_fp8_scaling_factor(amax):
161157
returnamax.float()/448.0

‎pyproject.toml‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ extend-ignore = [
6262
"__init__.py" = ["F401","F403"]
6363
"examples/*" = ["D"]
6464
"tests/*" = ["B017","D","E402","PT012"]
65-
"*/_[a-zA-Z]*" = ["D"]# Private packages (_abc/*.py) or modules (_xyz.py)
66-
"*.ipynb" = ["D","E501"]# Ignore missing docstrings or line length for Jupyter notebooks
67-
"modelopt/torch/quantization/triton/*" = ["N803","N806","E731"]# triton style
68-
65+
"*/_[a-zA-Z]*" = ["D"]# Private packages (_abc/*.py) or modules (_xyz.py)
66+
"*.ipynb" = ["D","E501"]# Ignore missing docstrings or line length for Jupyter notebooks
67+
"modelopt/torch/quantization/triton/*" = ["N803","N806","E731"]# triton style
68+
"examples/deepseek/ds_kernel.py" = ["N803","N806","E731"]# triton style
6969

7070
[tool.ruff.lint.pycodestyle]
7171
max-line-length =120# Line length limit for comments and docstrings

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp