Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit4c98e8b

Browse files
committed
feat: batched sampling by strategy (supersedes enable_mixed_sampler)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent870cfcf commit4c98e8b

File tree

18 files changed

+1005
-296
lines changed

18 files changed

+1005
-296
lines changed

‎setup.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,4 +260,4 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
260260
install_requires=required_deps,
261261
dependency_links=
262262
extra_URLs,# Warning: Dependency links support has been dropped by pip 19.0
263-
python_requires=">=3.7, <4")
263+
python_requires=">=3.10, <4")

‎tensorrt_llm/_torch/auto_deploy/llm_args.py‎

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
105105
description="Disable the overlap scheduler in trtllm runtime",
106106
)
107107

108-
enable_mixed_sampler:bool=Field(
109-
default=False,
110-
description="If true, will iterate over sampling_params of each request and use the corresponding "
111-
"sampling strategy, e.g. top-k, top-p, etc.",
112-
)
113-
114108
world_size:int=Field(
115109
default=1,
116110
ge=0,

‎tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py‎

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,11 @@ def create_autodeploy_executor(ad_config: LlmArgs):
337337
scheduler=SimpleScheduler(capacitor_scheduler,mb_scheduler)
338338

339339
# search sampler with speculative decoding
340-
# TODO (lucaslie, fridah-nv): some models require enable_mixed_sampler=True to have good outputs, see
341-
# https://github.com/NVIDIA/TensorRT-LLM/issues/5254
342-
# We should expose mixed_sample to our build_and_run_ad script so we can configure this
343-
# correctly for models as needed.
344340
sampler_args=TorchSampler.Args(
345341
max_seq_len=ad_config.max_seq_len,
346342
max_draft_len=max_draft_len,
347343
max_num_sequences=max_num_sequences,
348344
max_beam_width=ad_config.max_beam_width,
349-
enable_mixed_sampler=ad_config.enable_mixed_sampler,
350345
)
351346
sampler=TorchSampler(sampler_args)
352347

‎tensorrt_llm/_torch/modules/rms_norm.py‎

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
# limitations under the License.
1515

1616
importenum
17-
fromtypingimportOptional,Tuple,Union
17+
fromtypesimportEllipsisType# https://stackoverflow.com/a/66636313
18+
fromtypingimportOptional,Tuple,TypeAlias,Union,cast
1819

1920
importtorch
2021
fromtorchimportnn
@@ -24,6 +25,9 @@
2425

2526
classRMSNorm(nn.Module):
2627

28+
_ARGUMENT_NOT_SPECIFIED_SENTINEL= ...
29+
_ArgumentNotSpecifiedSentinelType:TypeAlias=EllipsisType
30+
2731
def__init__(
2832
self,
2933
*,
@@ -48,12 +52,19 @@ def __init__(
4852
defforward(
4953
self,
5054
hidden_states:torch.Tensor,
51-
residual:Optional[torch.Tensor]= ...,
52-
)->Union[torch.Tensor,Tuple[torch.Tensor,torch.Tensor]]:
55+
residual:Union[
56+
Optional[torch.Tensor],
57+
_ArgumentNotSpecifiedSentinelType]=_ARGUMENT_NOT_SPECIFIED_SENTINEL,
58+
)->Union[torch.Tensor,Tuple[torch.Tensor,Optional[torch.Tensor]]]:
59+
return_residual=True
60+
ifresidualisself._ARGUMENT_NOT_SPECIFIED_SENTINEL:
61+
return_residual=False
62+
residual=None
63+
5364
ifIS_FLASHINFER_AVAILABLE:
5465
from ..custom_opsimport (flashinfer_fused_add_rmsnorm,
5566
flashinfer_rmsnorm)
56-
ifisinstance(residual,torch.Tensor):
67+
ifresidualisnotNone:
5768
flashinfer_fused_add_rmsnorm(hidden_states,residual,
5869
self.weight,self.variance_epsilon)
5970
else:
@@ -62,7 +73,7 @@ def forward(
6273
else:
6374
input_dtype=hidden_states.dtype
6475
hidden_states=hidden_states.to(torch.float32)
65-
ifisinstance(residual,torch.Tensor):
76+
ifresidualisnotNone:
6677
hidden_states=hidden_states+residual.to(torch.float32)
6778
residual=hidden_states.to(input_dtype)
6879

@@ -71,20 +82,22 @@ def forward(
7182
self.variance_epsilon)
7283
hidden_states=self.weight*hidden_states.to(input_dtype)
7384

74-
ifresidualis ...:
75-
returnhidden_states
85+
ifreturn_residual:
86+
returnhidden_states,cast(Optional[torch.Tensor],residual)
7687
else:
77-
returnhidden_states,residual
88+
returnhidden_states
7889

7990
defskip_forward(
8091
self,
8192
hidden_states:torch.Tensor,
82-
residual:Optional[torch.Tensor]= ...,
83-
)->Union[torch.Tensor,Tuple[torch.Tensor,torch.Tensor]]:
84-
ifresidualis ...:
93+
residual:Union[
94+
Optional[torch.Tensor],
95+
_ArgumentNotSpecifiedSentinelType]=_ARGUMENT_NOT_SPECIFIED_SENTINEL,
96+
)->Union[torch.Tensor,Tuple[torch.Tensor,Optional[torch.Tensor]]]:
97+
ifresidualisself._ARGUMENT_NOT_SPECIFIED_SENTINEL:
8598
returnhidden_states
8699
else:
87-
returnhidden_states,residual
100+
returnhidden_states,cast(Optional[torch.Tensor],residual)
88101

89102

90103
classGroupRMSNormKernelSelection(enum.Enum):

‎tensorrt_llm/_torch/pyexecutor/_util.py‎

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ def create_py_executor_instance(
686686

687687

688688
defcreate_torch_sampler_args(mapping:Mapping,*,max_seq_len:int,
689-
enable_mixed_sampler:bool,max_batch_size:int,
689+
max_batch_size:int,
690690
speculative_config:SpeculativeConfig,
691691
max_beam_width:int):
692692
max_num_sequences=max_batch_size*mapping.pp_size
@@ -697,7 +697,6 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
697697
max_draft_len=max_draft_len,
698698
max_num_sequences=max_num_sequences,
699699
max_beam_width=max_beam_width,
700-
enable_mixed_sampler=enable_mixed_sampler,
701700
)
702701

703702

@@ -711,7 +710,6 @@ def instantiate_sampler(engine: PyTorchModelEngine,
711710
sampler_args=create_torch_sampler_args(
712711
mapping,
713712
max_seq_len=engine.max_seq_len,
714-
enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler,
715713
max_batch_size=max_batch_size,
716714
speculative_config=speculative_config,
717715
max_beam_width=max_beam_width)

‎tensorrt_llm/_torch/pyexecutor/config.py‎

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,6 @@ class PyTorchConfig:
5656

5757
moe_disable_finalize_fusion:bool=False
5858

59-
enable_mixed_sampler:bool=False
60-
"""
61-
If true, will iterate over sampling_params of each request and use the
62-
corresponding sampling strategy, e.g. top-k, top-p, etc.
63-
"""
6459
sampler_type:SamplerType=SamplerType.auto
6560
"""
6661
The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto.

‎tensorrt_llm/_torch/pyexecutor/llm_request.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def __init__(
365365
exclude_last_generation_logits)
366366
self.child_requests= []
367367

368-
self._py_embedding_bias_1d=None
368+
self._py_embedding_bias_1d:Optional[torch.Tensor]=None
369369
ifhasattr(self,'embedding_bias')andself.embedding_biasisnotNone:
370370
# Pre-squeeze to 1D if needed (remove batch dimension)
371371
ifself.embedding_bias.dim()>1:

‎tensorrt_llm/_torch/pyexecutor/py_executor_creator.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def create_py_executor(
350350
if_get_allow_chain_drafter():
351351
use_chain_drafter= (
352352
guided_decoding_configisNone
353-
andnotpytorch_backend_config.enable_mixed_sampler
353+
anddraft_spec_config._allow_greedy_draft_tokens
354354
andpytorch_backend_config.attn_backend=="TRTLLM")
355355
else:
356356
use_chain_drafter=False

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp