Commit 597bd11

ixlmar authored and yufeiwu-nv committed

[None][fix] restore list[list[list[int]]] in add_token (NVIDIA#8502)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>

1 parent 46fb72a, commit 597bd11
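The rationale for restoring nested lists, per the "NB" comment added in sampler.py below, is that repeated scalar reads are much cheaper on plain Python lists than on a torch.Tensor, where every access pays indexing overhead plus an .item() call. A minimal, self-contained sketch of that comparison (shapes and iteration counts are arbitrary and not part of the commit):

import timeit

import torch

# Arbitrary [steps, slots, beams] layout, mirroring new_tokens in the diff below.
new_tokens_tensor = torch.randint(0, 32_000, (8, 64, 1))
new_tokens_list = new_tokens_tensor.tolist()  # list[list[list[int]]]

def read_tensor() -> int:
    # Each read builds intermediate tensor views and then calls .item().
    return int(new_tokens_tensor[3][17][0].item())

def read_list() -> int:
    # Each read is three plain list lookups; the leaf is already a Python int.
    return new_tokens_list[3][17][0]

print("tensor reads:", timeit.timeit(read_tensor, number=10_000))
print("list reads:  ", timeit.timeit(read_list, number=10_000))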

File tree

4 files changed: +54, -23 lines

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 44 additions & 19 deletions

@@ -290,10 +290,13 @@ def _group_requests_by_strategy_key(
     }


-def add_token(request: LlmRequest, new_tokens: torch.Tensor, *, beam: int, step: int = 0) -> int:
+def add_token(
+    request: LlmRequest, new_tokens: list[list[list[int]]], *, beam: int, step: int = 0
+) -> int:
+    # NB: Accessing nested lists faster than torch.Tensor or numpy.ndarray
     seq_slot = request.py_seq_slot
     assert seq_slot is not None
-    new_token = cast(int, new_tokens[step][seq_slot][beam].item())
+    new_token = new_tokens[step][seq_slot][beam]
     request.add_new_token(new_token, beam)
     return new_token

@@ -700,7 +703,7 @@ def handle_logprobs(
     def _process_draft_tokens_greedy(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens: list[list[list[int]]],
     ) -> int:
         new_token = add_token(request, new_tokens, beam=self.BEAM)
         stop = self._handle_stop_criteria(request, new_token)

@@ -722,7 +725,8 @@ def _process_draft_tokens_greedy(
     def _process_draft_tokens_tree(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         spec_tree_manager: SpecTreeManager,
     ) -> int:
         """Tree verification for draft token tree based speculative decoding.

@@ -757,7 +761,7 @@ def _process_draft_tokens_tree(
             # TODO: For the last layer of the dynamic tree, we need to resampling all the draft tokens.
             cur_layer_num_nodes = sum(spec_tree_manager.get_top_k_list(cur_draft_layer_idx))
             for i in range(cur_layer_num_nodes):
-                new_token = add_token(request, new_tokens, beam=0, step=i)
+                new_token = add_token(request, new_tokens_list, beam=0, step=i)
             return 0
         else:
             # handle the target model request

@@ -767,7 +771,9 @@ def _process_draft_tokens_tree(
             eagle_paths = spec_tree_manager.get_eagle_paths(seq_slot)

             all_draft_tokens = request.py_draft_tokens  # [max_total_draft_tokens]
-            all_target_tokens = new_tokens[:, seq_slot, :].squeeze(-1)  # [max_total_draft_tokens]
+            all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(
+                -1
+            )  # [max_total_draft_tokens]
             assert all_target_tokens.shape[0] == spec_tree_manager.max_total_draft_tokens + 1

             longest_accepted_len = 0

@@ -800,13 +806,15 @@ def _process_draft_tokens_tree(
             if longest_accepted_len == 0:
                 # No draft tokens are accepted.
                 # Take the top-1 token of the first layer as the next new token.
-                new_token = add_token(request, new_tokens, beam=0, step=0)
+                new_token = add_token(request, new_tokens_list, beam=0, step=0)
                 return 0
             else:
                 # Take the longest accepted path as the next new token.
                 num_accepted_draft_tokens = 0
                 for idx in eagle_paths[longest_match_path_idx][:longest_accepted_len]:
-                    new_token = add_token(request, new_tokens, beam=0, step=cast(int, idx.item()))
+                    new_token = add_token(
+                        request, new_tokens_list, beam=0, step=cast(int, idx.item())
+                    )
                     num_accepted_draft_tokens += 1
                     if self._handle_stop_criteria(request, new_token):
                         break

@@ -876,8 +884,10 @@ def _tree_sampling_batch(
     def _process_draft_tokens_rejection_sampling(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
+        new_tokens_tensor: torch.Tensor,
     ) -> int:
+        assert request.py_draft_logits is not None
         # FIXME: Passing a dummy vocab_size could result in unnecessary
         # filtering of vocab_size logits, out of vocab_size in
         # total. The 'sample' below should generally be avoided

@@ -893,7 +903,9 @@ def _process_draft_tokens_rejection_sampling(
             request.py_draft_logits,
             generator=generator,
         )
+        assert draft_probs is not None
         target_probs = request.py_target_probs
+        assert target_probs is not None
         d2t = getattr(request, "d2t", None)
         if d2t is not None:
             vocab_d = draft_probs.shape[-1]

@@ -927,26 +939,27 @@ def _process_draft_tokens_rejection_sampling(
         num_accepted = num_initially_accepted
         for i in range(num_accepted):
             new_token = request.py_draft_tokens[i]
-            new_tokens[i, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
             stop = self._handle_stop_criteria(request, new_token)
             if stop:
                 num_accepted = i + 1
                 return num_accepted
         if sample_last:
             new_token = sample_rejected(draft_probs, target_probs, generator, num_accepted)
-            new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[num_accepted, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
         else:
-            new_token = add_token(request, new_tokens, beam=self.BEAM, step=num_accepted)
+            new_token = add_token(request, new_tokens_list, beam=self.BEAM, step=num_accepted)
             stop = self._handle_stop_criteria(request, new_token)

         return num_accepted

     def process_draft_tokens(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         resource_manager: Optional[ResourceManager] = None,
     ) -> int:
         if (

@@ -957,14 +970,19 @@ def process_draft_tokens(
             if spec_tree_manager is not None:
                 num_accepted = self._process_draft_tokens_tree(
                     request,
-                    new_tokens=new_tokens,
+                    new_tokens_tensor=new_tokens_tensor,
+                    new_tokens_list=new_tokens_list,
                     spec_tree_manager=spec_tree_manager,
                 )
             else:
-                num_accepted = self._process_draft_tokens_greedy(request, new_tokens=new_tokens)
+                num_accepted = self._process_draft_tokens_greedy(
+                    request, new_tokens=new_tokens_list
+                )
             return num_accepted
         else:
-            return self._process_draft_tokens_rejection_sampling(request, new_tokens)
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
+            )

     @override
     def update_requests(

@@ -976,15 +994,17 @@ def update_requests(
         if state.sampler_event:
             state.sampler_event.synchronize()

+        assert state.host is not None
         new_tokens = state.host.new_tokens
+        new_tokens_list = new_tokens.tolist()

         for req in state.scheduled_requests.context_requests:
             if (
                 req.state == LlmRequestState.GENERATION_COMPLETE
                 or req.context_remaining_length != 0
             ):
                 continue
-            new_token = add_token(req, new_tokens, beam=self.BEAM)
+            new_token = add_token(req, new_tokens_list, beam=self.BEAM)
             self._handle_stop_criteria(req, new_token)
             self.handle_logprobs(req, state, beam=self.BEAM, count=1)
             req.py_decoding_iter += 1

@@ -993,7 +1013,12 @@ def update_requests(
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             processed = 1
-            num_accepted = self.process_draft_tokens(req, new_tokens, resource_manager)
+            num_accepted = self.process_draft_tokens(
+                req,
+                new_tokens_tensor=new_tokens,
+                new_tokens_list=new_tokens_list,
+                resource_manager=resource_manager,
+            )
             if get_draft_token_length(req) > 0:
                 req.py_num_accepted_draft_tokens = num_accepted
                 req.py_rewind_len = req.py_draft_pages_allocated - num_accepted

@@ -1911,7 +1936,7 @@ def update_requests_multiple_beams_or_drafting(
         state: SampleStateTRTLLM,
         beam_width: int,
     ):
-        new_tokens_host = state.host.new_tokens
+        new_tokens_host = state.host.new_tokens.tolist()
         finished_sum_host = state.host.finished_sum.tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
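Taken together, the sampler.py changes establish a split convention: update_requests converts the host tensor to nested lists once per batch, the list is passed wherever only scalar reads happen (add_token), and the tensor is kept wherever slicing or in-place writes are still needed. A condensed, runnable sketch of that pattern; FakeRequest is a hypothetical stand-in for LlmRequest:

from dataclasses import dataclass, field

import torch

@dataclass
class FakeRequest:
    py_seq_slot: int
    tokens: list[int] = field(default_factory=list)

    def add_new_token(self, token: int, beam: int) -> None:
        self.tokens.append(token)

def add_token(
    request: FakeRequest, new_tokens: list[list[list[int]]], *, beam: int, step: int = 0
) -> int:
    # Fast path: nested-list access, as in the updated helper above.
    new_token = new_tokens[step][request.py_seq_slot][beam]
    request.add_new_token(new_token, beam)
    return new_token

new_tokens_tensor = torch.randint(0, 32_000, (4, 8, 1))  # [steps, slots, beams]
new_tokens_list = new_tokens_tensor.tolist()  # one conversion per batch

req = FakeRequest(py_seq_slot=3)
print(add_token(req, new_tokens_list, beam=0))  # scalar read via the list
print(new_tokens_tensor[:, req.py_seq_slot, :].squeeze(-1))  # slicing still uses the tensor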
tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 1 addition & 1 deletion

@@ -256,7 +256,7 @@ def update_requests(
         assert isinstance(state, SampleStateMTP)

         state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens.tolist()
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = self.BEAM
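The mtp.py change applies the same conversion at the same boundary. A quick check (assumed shape, not from the commit) that Tensor.tolist() on a 3-D integer tensor yields exactly the list[list[list[int]]] the new annotations promise:

import torch

t = torch.randint(0, 100, (2, 3, 1))  # [steps, slots, beams]
nested = t.tolist()
assert isinstance(nested, list) and isinstance(nested[0], list)
assert isinstance(nested[0][0][0], int)  # leaves are plain Python ints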

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py

Lines changed: 3 additions & 1 deletion

@@ -45,9 +45,11 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
             max_beam_width=beam_width,
         ))

+    input_new_tokens_list = input_new_tokens.tolist()
     num_accepted_draft_tokens = torch_sampler._process_draft_tokens_tree(
         request=input_request,
-        new_tokens=input_new_tokens,
+        new_tokens_tensor=input_new_tokens,
+        new_tokens_list=input_new_tokens_list,
         spec_tree_manager=spec_tree_manager)

     print(f"num_accepted_draft_tokens: {num_accepted_draft_tokens}")

tests/unittest/_torch/speculative/test_torch_rejection_sampling.py

Lines changed: 6 additions & 2 deletions

@@ -1,4 +1,5 @@
 import unittest
+from typing import cast

 import numpy as np
 import torch

@@ -24,8 +25,11 @@ def test_get_rejected_indices():
     sampled_regular = []
     for _ in range(num_iter):
         draft_tokens = [
-            torch.multinomial(draft_probs, num_samples=1,
-                              generator=generator).item()
+            cast(
+                int,
+                torch.multinomial(draft_probs,
+                                  num_samples=1,
+                                  generator=generator).item())
         ]
         rejected_indices = get_rejected_indices(draft_probs, target_probs,
                                                 generator, draft_tokens)
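The cast added above exists for static type checkers only: torch's current type stubs annotate Tensor.item() as returning a general Number, so a list built from bare .item() results is not inferred as list[int]. A tiny illustration of the narrowing (assumed values; cast has no runtime effect):

from typing import cast

import torch

probs = torch.tensor([0.5, 0.5])
tok = torch.multinomial(probs, num_samples=1).item()  # statically a Number
draft_tokens: list[int] = [cast(int, tok)]  # narrowed for the type checker
print(draft_tokens)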
