Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9b54b3b

Browse files
authored
[None][chore] AutoDeploy: replace HF's deprecated keyword torch_dtype --> dtype (#8510)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent8dc4aac commit9b54b3b

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

‎tensorrt_llm/_torch/auto_deploy/models/hf.py‎

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,6 @@ def __init__(self, *args, **kwargs):
107107
self.model_kwargs,
108108
)
109109

110-
# special handling for torch_dtype in model_kwargs since HF does not correctly update
111-
# torch_dtype string to an actual torch.dtype object (only with default)
112-
if"torch_dtype"inself.model_kwargs:
113-
dtype=self.model_kwargs["torch_dtype"]
114-
ifisinstance(dtype,str):
115-
dtype=getattr(torch,self.model_kwargs["torch_dtype"])
116-
assertisinstance(dtype,torch.dtype),f"Invalid dtype:{dtype}"
117-
self.model_kwargs["torch_dtype"]=dtype
118-
119110
# set sharding config source to huggingface
120111
self._sharding_config["source"]=ShardingConfigSource.HUGGINGFACE
121112

@@ -159,6 +150,16 @@ def _recursive_update_config(
159150
setattr(config,key,updated_value)
160151
ifchild_unused:
161152
nested_unused_kwargs[key]=child_unused
153+
elif (
154+
keyin ["torch_dtype","dtype"]
155+
andisinstance(value_new,str)
156+
andvalue_new!="auto"
157+
):
158+
# check special handling of torch_dtype (DEPRECATED!) and dtype key to ensure we
159+
# use the correct torch.dtype object instead of a string.
160+
dtype=getattr(torch,value_new)
161+
assertisinstance(dtype,torch.dtype),f"Invalid{dtype=}"
162+
setattr(config,key,dtype)
162163
else:
163164
# Direct update for simple values
164165
setattr(config,key,value_new)
@@ -278,7 +279,7 @@ def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
278279
"trust_remote_code":True,
279280
"tp_plan":"auto",
280281
**unused_kwargs,
281-
"torch_dtype":"auto",# takes precedence over unused_kwargs!
282+
"dtype":"auto",# takes precedence over unused_kwargs!
282283
},
283284
)
284285
model.eval()

‎tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
465465
"ibm-ai-platform/Bamba-9B-v2": {
466466
"llm_models_subdir":"Bamba-9B-v2",
467467
"model_kwargs": {
468-
"torch_dtype":"bfloat16",
468+
"dtype":"bfloat16",
469469
"hidden_size":64,
470470
"intermediate_size":128,
471471
"mamba_chunk_size":64,
@@ -484,7 +484,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
484484
"nvidia/NVIDIA-Nemotron-Nano-12B-v2": {
485485
"llm_models_subdir":"NVIDIA-Nemotron-Nano-12B-v2",
486486
"model_kwargs": {
487-
"torch_dtype":"bfloat16",
487+
"dtype":"bfloat16",
488488
"hidden_size":32,
489489
"intermediate_size":64,
490490
"mamba_head_dim":40,

‎tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_bamba_patches(model_dir: str, run_verify_generation: bool):
4848
**common_kwargs,
4949
"model_kwargs": {
5050
"use_cache":use_cache,
51-
"torch_dtype":"bfloat16",
51+
"dtype":"bfloat16",
5252
},
5353
}
5454
llm_args=AutoDeployConfig(**llm_args)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp