Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 81aeb50

Browse files
committed
[None][chore] AutoDeploy: replace HF's deprecated keyword torch_dtype --> dtype
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent 5ff4f88 commit 81aeb50

File tree

4 files changed

+9
-14
lines changed

4 files changed

+9
-14
lines changed

‎3rdparty/cutlass‎

Submodule cutlass updated 96 files

‎tensorrt_llm/_torch/auto_deploy/models/hf.py‎

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,6 @@ def __init__(self, *args, **kwargs):
107107
self.model_kwargs,
108108
)
109109

110-
# special handling for torch_dtype in model_kwargs since HF does not correctly update
111-
# torch_dtype string to an actual torch.dtype object (only with default)
112-
if "torch_dtype" in self.model_kwargs:
113-
dtype = self.model_kwargs["torch_dtype"]
114-
if isinstance(dtype, str):
115-
dtype = getattr(torch, self.model_kwargs["torch_dtype"])
116-
assert isinstance(dtype, torch.dtype), f"Invalid dtype:{dtype}"
117-
self.model_kwargs["torch_dtype"] = dtype
118-
119110
# set sharding config source to huggingface
120111
self._sharding_config["source"]=ShardingConfigSource.HUGGINGFACE
121112

@@ -159,6 +150,10 @@ def _recursive_update_config(
159150
setattr(config, key, updated_value)
160151
if child_unused:
161152
nested_unused_kwargs[key] = child_unused
153+
elif key in ["torch_dtype", "dtype"] and isinstance(value_new, str):
154+
# check special handling of torch_dtype (DEPRECATED!) and dtype key to ensure we
155+
# use the correct torch.dtype object instead of a string.
156+
setattr(config, key, getattr(torch, value_new))
162157
else:
163158
# Direct update for simple values
164159
setattr(config, key, value_new)
@@ -278,7 +273,7 @@ def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
278273
"trust_remote_code": True,
279274
"tp_plan": "auto",
280275
**unused_kwargs,
281-
"torch_dtype": "auto",  # takes precedence over unused_kwargs!
276+
"dtype": "auto",  # takes precedence over unused_kwargs!
282277
},
283278
)
284279
model.eval()

‎tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
465465
"ibm-ai-platform/Bamba-9B-v2": {
466466
"llm_models_subdir": "Bamba-9B-v2",
467467
"model_kwargs": {
468-
"torch_dtype": "bfloat16",
468+
"dtype": "bfloat16",
469469
"hidden_size": 64,
470470
"intermediate_size": 128,
471471
"mamba_chunk_size": 64,
@@ -484,7 +484,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
484484
"nvidia/NVIDIA-Nemotron-Nano-12B-v2": {
485485
"llm_models_subdir": "NVIDIA-Nemotron-Nano-12B-v2",
486486
"model_kwargs": {
487-
"torch_dtype": "bfloat16",
487+
"dtype": "bfloat16",
488488
"hidden_size": 32,
489489
"intermediate_size": 64,
490490
"mamba_head_dim": 40,

‎tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_bamba_patches(model_dir: str, run_verify_generation: bool):
4848
**common_kwargs,
4949
"model_kwargs": {
5050
"use_cache": use_cache,
51-
"torch_dtype": "bfloat16",
51+
"dtype": "bfloat16",
5252
},
5353
}
5454
llm_args = AutoDeployConfig(**llm_args)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp