Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commite35db17

Browse files
authored
Support Wan2.2 t2v diffusers quantization (#556)
## What does this PR do?**Type of change:**new feature**Overview:**Support Wan2.2 t2v diffusers quantization1. fix torch2.9 support2. add Wan2.2 t2v diffusers pipeline quantization Main difference of the Wan2.2 pipeline comparing to exisiting pipelines is that there are 2 backbone models for denoising. For the quantization therefore we need to quantize both of them. However, it turns out our base library does not well support quantization of multiple models in the same time. Therefore, the change here just stick to quantize a single model each time, and then run the quantization multiple times. So, we need to allow users to pick which backbone to quantize, therefore adding a new argment for it3. add a workaround for the exporting ONNX issue when we upgrade diffusers to >= 0.35.0. The issue lies is the exporting of the torch.nn.RMSNorm. Some pipelines in the diffusers > 0.35.0 use the torch version RMSNorm while before that they use the diffusers' own version of RMSNorm. It turns out they are directly replacable so the workaround is to simply replace the torch RMSNorm usages with diffusers RMSNorm. But we need to fix it properly soon by porting our ONNX export to be based on torch dynamo instead of torchscript. Issuereported from external user:#2624. allow use of a prompts file, which is simply a text file with a list of prompts, one prompt each line5. allow each component of a pipeline to have different dtype accuracy. added a new list stype command line arg --component-dtype for this. example: --component-dtype vae:Float6. print the summary of the quantized model so users can capture issuesfrom log## Usagepython quantize.py \ --model wan2.2-t2v-14b \ --format fp8 \ --batch-size 4 \ --calib-size 64 \ --n-steps 20 \ --backbone transformer \ --model-dtype BFloat16 \ --component-dtype vae:Float \ --trt-high-precision-dtype BFloat16 \ --quantized-torch-ckpt-save-path ./wan_transformer.pt \ --onnx-dir wan-transformer-onnx \ --prompts-file wan-prompts.txt## TestingTested SDXL_BASE, LTX_VIDEO_DEV, WAN22_T2V## Before your PR is "*Ready for review*"<!-- If you haven't finished some of the above items you can still open`Draft` PR. -->- **Make sure you read and follow [Contributorguidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)**and your commits are signed.- **Is this change backward compatible?**: Yes/No <!--- If No, explainwhy. -->- **Did you write any new necessary tests?**: Yes/No- **Did you add or update any necessary documentation?**: Yes/No- **Did you update[Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**:Yes/No <!--- Only for new features, API changes, critical bug fixes orbw breaking changes. -->## Additional Information#262---------Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parentbc52b6c commite35db17

File tree

4 files changed

+349
-113
lines changed

4 files changed

+349
-113
lines changed

‎examples/diffusers/quantization/onnx_utils/export.py‎

Lines changed: 95 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
importonnx
3939
importonnx_graphsurgeonasgs
4040
importtorch
41-
fromdiffusers.models.transformersimportFluxTransformer2DModel,SD3Transformer2DModel
41+
fromdiffusers.models.transformersimport (
42+
FluxTransformer2DModel,
43+
SD3Transformer2DModel,
44+
WanTransformer3DModel,
45+
)
4246
fromdiffusers.models.transformers.transformer_ltximportLTXVideoTransformer3DModel
4347
fromdiffusers.models.unetsimportUNet2DConditionModel
4448
fromtorch.onnximportexportasonnx_export
@@ -104,6 +108,11 @@
104108
"encoder_attention_mask": {0:"batch_size"},
105109
"video_coords": {0:"batch_size",2:"latent_dim"},
106110
},
111+
"wan2.2-t2v-14b": {
112+
"hidden_states": {0:"batch_size",2:"frame_num",3:"height",4:"width"},
113+
"encoder_hidden_states": {0:"batch_size"},
114+
"timestep": {0:"batch_size"},
115+
},
107116
}
108117

109118

@@ -159,7 +168,7 @@ def _gen_dummy_inp_and_dyn_shapes_sdxl(backbone, min_bs=1, opt_bs=1):
159168
"added_cond_kwargs.time_ids": {"min": [min_bs,6],"opt": [opt_bs,6]},
160169
}
161170

162-
dummy_input= {
171+
dummy_kwargs= {
163172
"sample":torch.randn(*dynamic_shapes["sample"]["min"]),
164173
"timestep":torch.ones(1),
165174
"encoder_hidden_states":torch.randn(*dynamic_shapes["encoder_hidden_states"]["min"]),
@@ -169,9 +178,9 @@ def _gen_dummy_inp_and_dyn_shapes_sdxl(backbone, min_bs=1, opt_bs=1):
169178
},
170179
"return_dict":False,
171180
}
172-
dummy_input=torch_to(dummy_input,dtype=backbone.dtype)
181+
dummy_kwargs=torch_to(dummy_kwargs,dtype=backbone.dtype)
173182

174-
returndummy_input,dynamic_shapes
183+
returndummy_kwargs,dynamic_shapes
175184

176185

177186
def_gen_dummy_inp_and_dyn_shapes_sd3(backbone,min_bs=1,opt_bs=1):
@@ -196,16 +205,16 @@ def _gen_dummy_inp_and_dyn_shapes_sd3(backbone, min_bs=1, opt_bs=1):
196205
},
197206
}
198207

199-
dummy_input= {
208+
dummy_kwargs= {
200209
"hidden_states":torch.randn(*dynamic_shapes["hidden_states"]["min"]),
201210
"timestep":torch.ones(1),
202211
"encoder_hidden_states":torch.randn(*dynamic_shapes["encoder_hidden_states"]["min"]),
203212
"pooled_projections":torch.randn(*dynamic_shapes["pooled_projections"]["min"]),
204213
"return_dict":False,
205214
}
206-
dummy_input=torch_to(dummy_input,dtype=backbone.dtype)
215+
dummy_kwargs=torch_to(dummy_kwargs,dtype=backbone.dtype)
207216

208-
returndummy_input,dynamic_shapes
217+
returndummy_kwargs,dynamic_shapes
209218

210219

211220
def_gen_dummy_inp_and_dyn_shapes_flux(backbone,min_bs=1,opt_bs=1):
@@ -237,7 +246,7 @@ def _gen_dummy_inp_and_dyn_shapes_flux(backbone, min_bs=1, opt_bs=1):
237246
dynamic_shapes["guidance"]= {"min": [1],"opt": [1]}
238247

239248
dtype=backbone.dtype
240-
dummy_input= {
249+
dummy_kwargs= {
241250
"hidden_states":torch.randn(*dynamic_shapes["hidden_states"]["min"],dtype=dtype),
242251
"encoder_hidden_states":torch.randn(
243252
*dynamic_shapes["encoder_hidden_states"]["min"],dtype=dtype
@@ -251,9 +260,9 @@ def _gen_dummy_inp_and_dyn_shapes_flux(backbone, min_bs=1, opt_bs=1):
251260
"return_dict":False,
252261
}
253262
ifcfg.guidance_embeds:# flux-dev
254-
dummy_input["guidance"]=torch.full((1,),3.5,dtype=torch.float32)
263+
dummy_kwargs["guidance"]=torch.full((1,),3.5,dtype=torch.float32)
255264

256-
returndummy_input,dynamic_shapes
265+
returndummy_kwargs,dynamic_shapes
257266

258267

259268
def_gen_dummy_inp_and_dyn_shapes_ltx(backbone,min_bs=2,opt_bs=2):
@@ -282,7 +291,7 @@ def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2):
282291
"opt": [opt_bs,3,video_dim],
283292
},
284293
}
285-
dummy_input= {
294+
dummy_kwargs= {
286295
"hidden_states":torch.randn(*dynamic_shapes["hidden_states"]["min"],dtype=dtype),
287296
"encoder_hidden_states":torch.randn(
288297
*dynamic_shapes["encoder_hidden_states"]["min"],dtype=dtype
@@ -293,7 +302,57 @@ def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2):
293302
),
294303
"video_coords":torch.randn(*dynamic_shapes["video_coords"]["min"],dtype=dtype),
295304
}
296-
returndummy_input,dynamic_shapes
305+
306+
returndummy_kwargs,dynamic_shapes
307+
308+
309+
def_gen_dummy_inp_and_dyn_shapes_wan(backbone,min_bs=1,opt_bs=2):
310+
assertisinstance(backbone,WanTransformer3DModel)
311+
dtype=backbone.dtype
312+
313+
channels=16# latent channels from VAE
314+
hidden_size=4096# text encoder hidden size (UMT5-XXL)
315+
316+
# num of frames for wan is 4*n+1, as from the official codebase:
317+
# https://github.com/Wan-Video/Wan2.2/blob/e9783574ef77be11fcab9aa5607905402538c08d/generate.py#L126
318+
# picking n == 1 as min, n = 20 as opt as 81 is the default num of frames in their code base
319+
min_num_frames=4*1+1
320+
opt_num_frames=4*20+1
321+
322+
# height and width configs are from their codebase:
323+
# https://github.com/Wan-Video/Wan2.2/blob/e9783574ef77be11fcab9aa5607905402538c08d/wan/configs/__init__.py#L21
324+
min_height=480
325+
min_width=480
326+
327+
# height max can be 1280, but opt setting is 1280x720, so use 720 here
328+
opt_height=720
329+
opt_width=1280
330+
331+
min_latent_height=min_height//8
332+
min_latent_width=min_width//8
333+
opt_latent_height=opt_height//8
334+
opt_latent_width=opt_width//8
335+
336+
dynamic_shapes= {
337+
"hidden_states": {
338+
"min": [min_bs,channels,min_num_frames,min_latent_height,min_latent_width],
339+
"opt": [opt_bs,channels,opt_num_frames,opt_latent_height,opt_latent_width],
340+
},
341+
"encoder_hidden_states": {
342+
"min": [min_bs,512,hidden_size],
343+
"opt": [opt_bs,512,hidden_size],
344+
},
345+
"timestep": {"min": [min_bs],"opt": [opt_bs]},
346+
}
347+
348+
dummy_kwargs= {
349+
"hidden_states":torch.randn(*dynamic_shapes["hidden_states"]["min"],dtype=dtype),
350+
"encoder_hidden_states":torch.randn(
351+
*dynamic_shapes["encoder_hidden_states"]["min"],dtype=dtype
352+
),
353+
"timestep":torch.ones(*dynamic_shapes["timestep"]["min"],dtype=dtype),
354+
}
355+
returndummy_kwargs,dynamic_shapes
297356

298357

299358
defupdate_dynamic_axes(model_id,dynamic_axes):
@@ -327,30 +386,32 @@ def _create_dynamic_shapes(dynamic_shapes):
327386
defgenerate_dummy_inputs_and_dynamic_axes_and_shapes(model_id,backbone):
328387
"""Generate dummy inputs, dynamic axes, and dynamic shapes for the given model."""
329388
ifmodel_idin ["sdxl-1.0","sdxl-turbo"]:
330-
dummy_input,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_sdxl(
331-
backbone,min_bs=2,opt_bs=16
389+
dummy_kwargs,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_sdxl(
390+
backbone,min_bs=1,opt_bs=16
332391
)
333392
elifmodel_idin ["sd3-medium","sd3.5-medium"]:
334-
dummy_input,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_sd3(
335-
backbone,min_bs=2,opt_bs=16
393+
dummy_kwargs,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_sd3(
394+
backbone,min_bs=1,opt_bs=16
336395
)
337396
elifmodel_idin ["flux-dev","flux-schnell"]:
338-
dummy_input,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_flux(
339-
backbone,min_bs=1,opt_bs=1
397+
dummy_kwargs,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_flux(
398+
backbone,min_bs=1,opt_bs=2
340399
)
341400
elifmodel_id=="ltx-video-dev":
342-
dummy_input,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_ltx(
343-
backbone,min_bs=2,opt_bs=2
401+
dummy_kwargs,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_ltx(
402+
backbone,min_bs=1,opt_bs=2
403+
)
404+
elifmodel_id=="wan2.2-t2v-14b":
405+
dummy_kwargs,dynamic_shapes=_gen_dummy_inp_and_dyn_shapes_wan(
406+
backbone,min_bs=1,opt_bs=2
344407
)
345408
else:
346409
raiseNotImplementedError(f"Unsupported model_id:{model_id}")
347410

348-
dummy_input=torch_to(dummy_input,device=backbone.device)
349-
dummy_inputs= (dummy_input,)
411+
dummy_kwargs=torch_to(dummy_kwargs,device=backbone.device)
350412
dynamic_axes=MODEL_ID_TO_DYNAMIC_AXES[model_id]
351-
dynamic_shapes=_create_dynamic_shapes(dynamic_shapes)
352413

353-
returndummy_inputs,dynamic_axes,dynamic_shapes
414+
returndummy_kwargs,dynamic_axes,dynamic_shapes
354415

355416

356417
defget_io_shapes(model_id,onnx_load_path,dynamic_shapes):
@@ -415,7 +476,7 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
415476
configure_linear_module_onnx_quantizers(backbone)ifprecision=="fp4"elsenullcontext()
416477
)
417478

418-
dummy_inputs,dynamic_axes,_=generate_dummy_inputs_and_dynamic_axes_and_shapes(
479+
dummy_kwargs,dynamic_axes,_=generate_dummy_inputs_and_dynamic_axes_and_shapes(
419480
model_name,backbone
420481
)
421482

@@ -449,6 +510,13 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
449510
"video_coords",
450511
]
451512
output_names= ["latent"]
513+
elifmodel_namein ["wan2.2-t2v-14b"]:
514+
input_names= [
515+
"hidden_states",
516+
"timestep",
517+
"encoder_hidden_states",
518+
]
519+
output_names= ["latent"]
452520
else:
453521
raiseNotImplementedError(f"Unsupported model_id:{model_name}")
454522

@@ -458,8 +526,9 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
458526
withquantizer_context,torch.inference_mode():
459527
onnx_export(
460528
backbone,
461-
dummy_inputs,
529+
(),
462530
f=tmp_output.as_posix(),
531+
kwargs=dummy_kwargs,
463532
input_names=input_names,
464533
output_names=output_names,
465534
dynamic_axes=dynamic_axes,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp