 from typing import Any, List, Literal, Optional, cast
 
 import torch
+import torch.nn.functional as F
 
 from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
     MakeDecodingBatchInputOutput
@@ -852,15 +853,19 @@ def _handle_stop_criteria(self, request: LlmRequest,
 
     def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
                         beam: int, count: int):
-        current_slice = slice(0, count), request.py_seq_slot, beam
         if request.py_return_log_probs:
-            assert state.host.log_probs is not None
-            log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
-            current_tokens = state.host.new_tokens[current_slice]
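+            # Top-k logprob values/indices were attached to the request (host
+            # copies) by _sample_batched_by_strategy; one row per generated token.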
+            topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+            topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
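+            # Build one {token_id: Logprob} dict per generated token; the rank is
+            # the 1-based position of the token within the request's top-k list.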
             token_log_probs = [{
-                int(token): Logprob(logprob=logprob, rank=1)
-            } for token, logprob in zip(current_tokens, log_probs.tolist())]
+                int(token):
+                Logprob(logprob=logprob, rank=rank + 1)
+                for rank, (token, logprob) in enumerate(
+                    zip(topk_token, topk_logprob.tolist()))
+            }
+                               for topk_token, topk_logprob in zip(
+                                   topk_log_probs_indices, topk_log_probs_vals)]
+
             assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
             request.py_result.append_log_probs([token_log_probs])
 
@@ -970,13 +975,8 @@ def log_probs_host(
             self,
             scheduled_requests: ScheduledRequests) -> Optional[torch.Tensor]:
         """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
-        if any(req.py_return_log_probs
-               for req in scheduled_requests.all_requests()):
-            return torch.empty(
-                (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
-                device="cpu",
-                pin_memory=True)
-        return None
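+        # Only a flag is needed here: top-k logprobs are attached to each request
+        # instead of being scattered into a preallocated dense host tensor.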
+        return any(req.py_return_log_probs
+                   for req in scheduled_requests.all_requests())
 
     @override
     @torch.inference_mode()
@@ -1001,8 +1001,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         sampler_event.record()
         return SampleState(scheduled_requests=scheduled_requests,
                            device=SampleStateTensors(new_tokens=new_tokens),
-                           host=SampleStateTensors(new_tokens=new_tokens_host,
-                                                   log_probs=log_probs_host),
+                           host=SampleStateTensors(new_tokens=new_tokens_host),
                            sampler_event=sampler_event)
 
     @staticmethod
@@ -1111,12 +1110,24 @@ def _sample_batched_by_strategy(
         model_outputs: dict[str, torch.Tensor],
         *,
         cuda_device: torch.device,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
         req_num_steps: torch.Tensor,
         req_offsets: torch.Tensor,
         steps_dim_size: int,
         token_dtype: torch.dtype,
     ) -> _BatchedSamplingResult:
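+        # Compute log-softmax over the vocabulary once for the whole batch and keep
+        # the largest k requested by any request; per-request slices are attached to
+        # the requests further below and copied to the host asynchronously.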
+        if log_probs_host:
+            assert logits_cuda.dim() == 2, "logits should be 2D"
+            logprobs = F.log_softmax(logits_cuda.to("cuda",
+                                                    dtype=torch.float32),
+                                     dim=-1)
+            topk_vals, topk_indices = torch.topk(logprobs,
+                                                 k=max(req.py_num_logprobs
+                                                       for req in requests),
+                                                 dim=-1)
+            topk_vals = topk_vals.to(device="cpu", non_blocking=True)
+            topk_indices = topk_indices.to(device="cpu", non_blocking=True)
+
         requests_by_strategy = _group_requests_by_sampling_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
@@ -1160,12 +1171,18 @@ def _sample_batched_by_strategy(
             # softmax_grp_indices: Indices of 'speculation_group_indices' entries requesting probs
             # speculation_softmax_indices: Indices of 'softmax_grp_indices' entries corresponding
             #     to requests with draft logits.
-            if log_probs_host is not None:
+            if log_probs_host:
                 softmax_req_indices = group_req_indices
                 softmax_grp_indices = torch.arange(len(group_req_indices),
                                                    dtype=torch.int32)
                 speculation_softmax_indices = torch.tensor(
                     speculation_group_indices, dtype=torch.int32)
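+                # Attach each request's top-k slice, trimmed to its own py_num_logprobs.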
+                for req_id in group_req_indices:
+                    req = requests[req_id]
+                    req.py_topk_logprobs_vals = topk_vals[
+                        logits_cuda_indexer[req_id], :req.py_num_logprobs]
+                    req.py_topk_logprobs_indices = topk_indices[
+                        logits_cuda_indexer[req_id], :req.py_num_logprobs]
             else:
                 speculation_group_indices_tensor = torch.tensor(
                     speculation_group_indices, dtype=torch.int32)
@@ -1257,7 +1274,7 @@ def _unbatch_sampling_results(
         new_tokens_cuda: torch.Tensor,
         req_num_steps: torch.Tensor,
         seq_slots: torch.Tensor,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
     ) -> torch.Tensor:
         beam = self.BEAM
         assert beam == 0, "beam_width != 1 not supported"
@@ -1274,17 +1291,6 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
         # Assert destination tensor dimensions are canonically ordered ("row"-major); this
         # matters for element ordering in the .view(...).scatter_(...) calls below.
         assert _dims_canonically_ordered(new_tokens_cuda)
-        assert log_probs_host is None or _dims_canonically_ordered(
-            log_probs_host)
-
-        # new_tokens_cuda indexed by
-        #   slice(0, steps), slot, beam
-        # log_probs_host indexed by
-        #   slot, beam, slice(0, steps)
-        # batch_... tensors indexed by slice(batch_req_index, batch_req_index + steps)
-        #
-        if log_probs_host is not None:
-            assert new_tokens_cuda.size(0) == log_probs_host.size(-2)
 
         # Construct index mapping from slice indices of computed tensors
         # (packed request_idx and step dimensions) to linearized indices
@@ -1306,39 +1312,7 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
                 0, batch_dest_indices_1d_cuda,
                 batch_next_tokens_cuda_int)
         new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True)
-        # NB: In order to avoid a scatter_ on the host and the necessary D2H copy + synchronization,
-        #     the 'step' and 'seq_slot' dimensions are unpacked on GPU and later asynchronously
-        #     copied into the destination buffer. Note that this overwrites all 'step' and token slots for the
-        #     requests in 'requests' (passed to _process_requests). In fact, the current implementation
-        #     even overwrites the destination tensors completely (including slices corresponding to request
-        #     slots not present in 'requests', cf. 'FIXME' below).
-        if log_probs_host is not None:
-            # FIXME: If log_probs_host were indexed by request indices, rather than request slots, this
-            #        tensor could be packed densely along the request axis.
-            log_probs_cuda = torch.empty_like(
-                log_probs_host, device=batch_dest_indices_1d_cuda.device)
-            # FIXME: Needs a separate indexer because tensor layout differs from new_tokens_cuda
-            batch_dest_probs_cuda_indexer = _UnpackedStepIndexer(
-                seq_slots=seq_slots[batch_req_indices],
-                num_steps=req_num_steps[batch_req_indices],
-                steps_dim_size=new_tokens_cuda.size(0),
-                slots_dim_size=new_tokens_cuda.size(1),
-                dim_order=_UnpackedStepIndexer.DimOrder.SLOT_MAJOR,
-                index_dtype=torch.int64,  # enforced by Tensor.scatter_
-            )
-            batch_dest_probs_indices_cuda = batch_dest_probs_cuda_indexer[:].to(
-                batch_softmax_cuda.device, non_blocking=True)
-            # NB: torch.arange is needed to enable "advanced indexing",
-            #     cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing
-            batch_token_probs = batch_softmax_cuda[
-                torch.arange(batch_softmax_cuda.size(0),
-                             device=batch_softmax_cuda.device,
-                             dtype=torch.int32), batch_next_tokens_cuda_int]
-            log_probs_cuda[:, beam,
-                           ...].view(-1, *log_probs_cuda.shape[3:]).scatter_(
-                               0, batch_dest_probs_indices_cuda,
-                               torch.log(batch_token_probs))
-            log_probs_host.copy_(log_probs_cuda, non_blocking=True)
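+        # Per-request top-k logprobs are gathered in _sample_batched_by_strategy,
+        # so no dense log_probs tensor needs to be scattered or copied here.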
+
         # For requests with LlmRequest.py_draft_logits, return py_target_probs
         for request, batch_softmax_index_cuda in py_draft_logits_indices:
             request.py_target_probs = batch_softmax_cuda[
@@ -1481,7 +1455,6 @@ def _process_requests(
 
         logits_cuda = self._apply_min_length_penalty(logits_cuda, requests,
                                                      req_num_steps_list)
-
         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
             logits_cuda,