@@ -290,10 +290,13 @@ def _group_requests_by_strategy_key(
290290 }
291291
292292
def add_token(
    request: LlmRequest, new_tokens: list[list[list[int]]], *, beam: int, step: int = 0
) -> int:
    """Append one sampled token to ``request`` and return it.

    The token is read from ``new_tokens[step][request.py_seq_slot][beam]`` and
    forwarded to ``request.add_new_token(token, beam)``.

    Args:
        request: Request that receives the token; its ``py_seq_slot`` must be
            set (not ``None``).
        new_tokens: Nested lists indexed as ``[step][seq_slot][beam]``.
        beam: Beam index, used both for the lookup and for ``add_new_token``.
        step: Index along the first axis of ``new_tokens``; defaults to 0.

    Returns:
        The token that was appended, always as a plain ``int``.
    """
    # NB: Accessing nested lists is faster than indexing a torch.Tensor or
    #     numpy.ndarray, which is why callers pass lists here.
    seq_slot = request.py_seq_slot
    assert seq_slot is not None
    new_token = new_tokens[step][seq_slot][beam]
    if not isinstance(new_token, int):
        # Callers that still hand over a tensor/array get a scalar wrapper
        # from the indexing above; normalise it so the declared ``int``
        # contract holds for both the request and the return value.
        new_token = int(new_token)
    request.add_new_token(new_token, beam)
    return new_token
299302
@@ -700,7 +703,7 @@ def handle_logprobs(
700703def _process_draft_tokens_greedy (
701704self ,
702705request :LlmRequest ,
703- new_tokens :torch . Tensor ,
706+ new_tokens :list [ list [ list [ int ]]] ,
704707 )-> int :
705708new_token = add_token (request ,new_tokens ,beam = self .BEAM )
706709stop = self ._handle_stop_criteria (request ,new_token )
@@ -722,7 +725,8 @@ def _process_draft_tokens_greedy(
722725def _process_draft_tokens_tree (
723726self ,
724727request :LlmRequest ,
725- new_tokens :torch .Tensor ,
728+ new_tokens_tensor :torch .Tensor ,
729+ new_tokens_list :list [list [list [int ]]],
726730spec_tree_manager :SpecTreeManager ,
727731 )-> int :
728732"""Tree verification for draft token tree based speculative decoding.
@@ -757,7 +761,7 @@ def _process_draft_tokens_tree(
757761# TODO: For the last layer of the dynamic tree, we need to resampling all the draft tokens.
758762cur_layer_num_nodes = sum (spec_tree_manager .get_top_k_list (cur_draft_layer_idx ))
759763for i in range (cur_layer_num_nodes ):
760- new_token = add_token (request ,new_tokens ,beam = 0 ,step = i )
764+ new_token = add_token (request ,new_tokens_list ,beam = 0 ,step = i )
761765return 0
762766else :
763767# handle the target model request
@@ -767,7 +771,9 @@ def _process_draft_tokens_tree(
767771eagle_paths = spec_tree_manager .get_eagle_paths (seq_slot )
768772
769773all_draft_tokens = request .py_draft_tokens # [max_total_draft_tokens]
770- all_target_tokens = new_tokens [:,seq_slot , :].squeeze (- 1 )# [max_total_draft_tokens]
774+ all_target_tokens = new_tokens_tensor [:,seq_slot , :].squeeze (
775+ - 1
776+ )# [max_total_draft_tokens]
771777assert all_target_tokens .shape [0 ]== spec_tree_manager .max_total_draft_tokens + 1
772778
773779longest_accepted_len = 0
@@ -800,13 +806,15 @@ def _process_draft_tokens_tree(
800806if longest_accepted_len == 0 :
801807# No draft tokens are accepted.
802808# Take the top-1 token of the first layer as the next new token.
803- new_token = add_token (request ,new_tokens ,beam = 0 ,step = 0 )
809+ new_token = add_token (request ,new_tokens_list ,beam = 0 ,step = 0 )
804810return 0
805811else :
806812# Take the longest accepted path as the next new token.
807813num_accepted_draft_tokens = 0
808814for idx in eagle_paths [longest_match_path_idx ][:longest_accepted_len ]:
809- new_token = add_token (request ,new_tokens ,beam = 0 ,step = cast (int ,idx .item ()))
815+ new_token = add_token (
816+ request ,new_tokens_list ,beam = 0 ,step = cast (int ,idx .item ())
817+ )
810818num_accepted_draft_tokens += 1
811819if self ._handle_stop_criteria (request ,new_token ):
812820break
@@ -876,8 +884,10 @@ def _tree_sampling_batch(
876884def _process_draft_tokens_rejection_sampling (
877885self ,
878886request :LlmRequest ,
879- new_tokens :torch .Tensor ,
887+ new_tokens_list :list [list [list [int ]]],
888+ new_tokens_tensor :torch .Tensor ,
880889 )-> int :
890+ assert request .py_draft_logits is not None
881891# FIXME: Passing a dummy vocab_size could result in unnecessary
882892# filtering of vocab_size logits, out of vocab_size in
883893# total. The 'sample' below should generally be avoided
@@ -893,7 +903,9 @@ def _process_draft_tokens_rejection_sampling(
893903request .py_draft_logits ,
894904generator = generator ,
895905 )
906+ assert draft_probs is not None
896907target_probs = request .py_target_probs
908+ assert target_probs is not None
897909d2t = getattr (request ,"d2t" ,None )
898910if d2t is not None :
899911vocab_d = draft_probs .shape [- 1 ]
@@ -927,26 +939,27 @@ def _process_draft_tokens_rejection_sampling(
927939num_accepted = num_initially_accepted
928940for i in range (num_accepted ):
929941new_token = request .py_draft_tokens [i ]
930- new_tokens [i ,request .seq_slot ,self .BEAM ]= new_token
942+ new_tokens_tensor [i ,request .seq_slot ,self .BEAM ]= new_token
931943request .add_new_token (new_token ,self .BEAM )
932944stop = self ._handle_stop_criteria (request ,new_token )
933945if stop :
934946num_accepted = i + 1
935947return num_accepted
936948if sample_last :
937949new_token = sample_rejected (draft_probs ,target_probs ,generator ,num_accepted )
938- new_tokens [num_accepted ,request .seq_slot ,self .BEAM ]= new_token
950+ new_tokens_tensor [num_accepted ,request .seq_slot ,self .BEAM ]= new_token
939951request .add_new_token (new_token ,self .BEAM )
940952else :
941- new_token = add_token (request ,new_tokens ,beam = self .BEAM ,step = num_accepted )
953+ new_token = add_token (request ,new_tokens_list ,beam = self .BEAM ,step = num_accepted )
942954stop = self ._handle_stop_criteria (request ,new_token )
943955
944956return num_accepted
945957
946958def process_draft_tokens (
947959self ,
948960request :LlmRequest ,
949- new_tokens :torch .Tensor ,
961+ new_tokens_tensor :torch .Tensor ,
962+ new_tokens_list :list [list [list [int ]]],
950963resource_manager :Optional [ResourceManager ]= None ,
951964 )-> int :
952965if (
@@ -957,14 +970,19 @@ def process_draft_tokens(
957970if spec_tree_manager is not None :
958971num_accepted = self ._process_draft_tokens_tree (
959972request ,
960- new_tokens = new_tokens ,
973+ new_tokens_tensor = new_tokens_tensor ,
974+ new_tokens_list = new_tokens_list ,
961975spec_tree_manager = spec_tree_manager ,
962976 )
963977else :
964- num_accepted = self ._process_draft_tokens_greedy (request ,new_tokens = new_tokens )
978+ num_accepted = self ._process_draft_tokens_greedy (
979+ request ,new_tokens = new_tokens_list
980+ )
965981return num_accepted
966982else :
967- return self ._process_draft_tokens_rejection_sampling (request ,new_tokens )
983+ return self ._process_draft_tokens_rejection_sampling (
984+ request ,new_tokens_list = new_tokens_list ,new_tokens_tensor = new_tokens_tensor
985+ )
968986
969987@override
970988def update_requests (
@@ -976,15 +994,17 @@ def update_requests(
976994if state .sampler_event :
977995state .sampler_event .synchronize ()
978996
997+ assert state .host is not None
979998new_tokens = state .host .new_tokens
999+ new_tokens_list = new_tokens .tolist ()
9801000
9811001for req in state .scheduled_requests .context_requests :
9821002if (
9831003req .state == LlmRequestState .GENERATION_COMPLETE
9841004or req .context_remaining_length != 0
9851005 ):
9861006continue
987- new_token = add_token (req ,new_tokens ,beam = self .BEAM )
1007+ new_token = add_token (req ,new_tokens_list ,beam = self .BEAM )
9881008self ._handle_stop_criteria (req ,new_token )
9891009self .handle_logprobs (req ,state ,beam = self .BEAM ,count = 1 )
9901010req .py_decoding_iter += 1
@@ -993,7 +1013,12 @@ def update_requests(
9931013if req .state == LlmRequestState .GENERATION_COMPLETE :
9941014continue
9951015processed = 1
996- num_accepted = self .process_draft_tokens (req ,new_tokens ,resource_manager )
1016+ num_accepted = self .process_draft_tokens (
1017+ req ,
1018+ new_tokens_tensor = new_tokens ,
1019+ new_tokens_list = new_tokens_list ,
1020+ resource_manager = resource_manager ,
1021+ )
9971022if get_draft_token_length (req )> 0 :
9981023req .py_num_accepted_draft_tokens = num_accepted
9991024req .py_rewind_len = req .py_draft_pages_allocated - num_accepted
@@ -1911,7 +1936,7 @@ def update_requests_multiple_beams_or_drafting(
19111936state :SampleStateTRTLLM ,
19121937beam_width :int ,
19131938 ):
1914- new_tokens_host = state .host .new_tokens
1939+ new_tokens_host = state .host .new_tokens . tolist ()
19151940finished_sum_host = state .host .finished_sum .tolist ()
19161941finish_reasons = state .host .finish_reasons .flatten ().tolist ()
19171942sequence_lengths_host_data = state .host .sequence_lengths .flatten ().tolist ()