Commit 2f8dc6f

[None][feat] Return topk logprobs in torch backend (#7756)

Signed-off-by: Dong Cao <docao@nvidia.com>

1 parent 6256376, commit 2f8dc6f

6 files changed: +68 −68 lines changed

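Taken together, the change lets `SamplingParams(logprobs=k)` with k > 1 flow through the PyTorch backend. A minimal usage sketch (the model path and prompt are illustrative placeholders, not part of the commit):

import torch  # noqa: F401  (TensorRT-LLM's torch backend)

from tensorrt_llm import LLM, SamplingParams

# Hypothetical local checkpoint; any model served by the PyTorch backend works.
llm = LLM(model="/path/to/llama-checkpoint")

# Before this commit the torch backend raised ValueError for logprobs > 1.
sampling_params = SamplingParams(max_tokens=10, logprobs=3)
outputs = llm.generate(["The capital of France is"], sampling_params)

# One dict per decoded step, mapping token id -> Logprob(logprob, rank).
for step_logprobs in outputs[0].outputs[0].logprobs:
    print(step_logprobs)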

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 4 additions & 0 deletions
@@ -311,6 +311,7 @@ def __init__(
             is_draft: bool = False,
             seq_slot: Optional[int] = None,
             target_seq_slot: Optional[int] = None,
+            num_logprobs: int = 0,
             **kwargs):

         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@@ -354,6 +355,7 @@ def __init__(
         self.py_lora_task_layer_module_configs: list[
             tensorrt_llm.bindings.internal.runtime.
             TaskLayerModuleConfig] | None = None
+        self.py_num_logprobs = num_logprobs

         self.py_return_log_probs = return_log_probs
         self.py_return_context_logits = return_context_logits
@@ -562,6 +564,8 @@ def executor_request_to_llm_request(
         mrope_position_deltas=mrope_position_deltas,
         lookahead_config=None,
         return_log_probs=executor_request.output_config.return_log_probs,
+        num_logprobs=executor_request.py_num_logprobs if hasattr(
+            executor_request, "py_num_logprobs") else 0,
         return_context_logits=executor_request.output_config.
         return_context_logits,
         return_perf_metrics=executor_request.output_config.return_perf_metrics,

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 36 additions & 63 deletions
@@ -9,6 +9,7 @@
 from typing import Any, List, Literal, Optional, cast

 import torch
+import torch.nn.functional as F

 from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
     MakeDecodingBatchInputOutput
@@ -852,15 +853,19 @@ def _handle_stop_criteria(self, request: LlmRequest,

     def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
                         beam: int, count: int):
-        current_slice = slice(0, count), request.py_seq_slot, beam
         if request.py_return_log_probs:
-            assert state.host.log_probs is not None
-            log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
-            current_tokens = state.host.new_tokens[current_slice]
+            topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+            topk_log_probs_indices = request.py_topk_logprobs_indices[:count]

             token_log_probs = [{
-                int(token): Logprob(logprob=logprob, rank=1)
-            } for token, logprob in zip(current_tokens, log_probs.tolist())]
+                int(token):
+                Logprob(logprob=logprob, rank=rank + 1)
+                for rank, (token, logprob) in enumerate(
+                    zip(topk_token, topk_logprob.tolist()))
+            }
+                               for topk_token, topk_logprob in zip(
+                                   topk_log_probs_indices, topk_log_probs_vals)]
+
             assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
             request.py_result.append_log_probs([token_log_probs])
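To make the comprehension above concrete, a minimal standalone sketch with dummy tensors standing in for the per-request top-k buffers (the Logprob import path is an assumption for this sketch):

import torch

from tensorrt_llm.executor.result import Logprob  # assumed import path

# Stand-ins for request.py_topk_logprobs_indices / ..._vals, shape (steps, k).
topk_log_probs_indices = torch.tensor([[11, 42, 7], [3, 99, 5]])
topk_log_probs_vals = torch.tensor([[-0.1, -1.2, -2.3], [-0.5, -0.9, -3.0]])

# One dict per decoded step: token id -> Logprob(logprob, rank), where
# rank 1 is the most likely token, matching the new handle_logprobs body.
token_log_probs = [{
    int(token): Logprob(logprob=logprob, rank=rank + 1)
    for rank, (token, logprob) in enumerate(
        zip(topk_token, topk_logprob.tolist()))
} for topk_token, topk_logprob in zip(topk_log_probs_indices,
                                      topk_log_probs_vals)]

assert list(token_log_probs[0].keys()) == [11, 42, 7]
assert token_log_probs[0][42].rank == 2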

@@ -970,13 +975,8 @@ def log_probs_host(
         self,
         scheduled_requests: ScheduledRequests) -> Optional[torch.Tensor]:
         """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
-        if any(req.py_return_log_probs
-               for req in scheduled_requests.all_requests()):
-            return torch.empty(
-                (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
-                device="cpu",
-                pin_memory=True)
-        return None
+        return any(req.py_return_log_probs
+                   for req in scheduled_requests.all_requests())

     @override
     @torch.inference_mode()
@@ -1001,8 +1001,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         sampler_event.record()
         return SampleState(scheduled_requests=scheduled_requests,
                            device=SampleStateTensors(new_tokens=new_tokens),
-                           host=SampleStateTensors(new_tokens=new_tokens_host,
-                                                   log_probs=log_probs_host),
+                           host=SampleStateTensors(new_tokens=new_tokens_host),
                            sampler_event=sampler_event)

     @staticmethod
@@ -1111,12 +1110,24 @@ def _sample_batched_by_strategy(
         model_outputs: dict[str, torch.Tensor],
         *,
         cuda_device: torch.device,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
         req_num_steps: torch.Tensor,
         req_offsets: torch.Tensor,
         steps_dim_size: int,
         token_dtype: torch.dtype,
     ) -> _BatchedSamplingResult:
+        if log_probs_host:
+            assert logits_cuda.dim() == 2, "logits should be 2D"
+            logprobs = F.log_softmax(logits_cuda.to("cuda",
+                                                    dtype=torch.float32),
+                                     dim=-1)
+            topk_vals, topk_indices = torch.topk(logprobs,
+                                                 k=max(req.py_num_logprobs
+                                                       for req in requests),
+                                                 dim=-1)
+            topk_vals = topk_vals.to(device="cpu", non_blocking=True)
+            topk_indices = topk_indices.to(device="cpu", non_blocking=True)
+
         requests_by_strategy = _group_requests_by_sampling_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
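The block above is the heart of the feature: one log_softmax over the 2D logits, one topk sized to the largest request in the batch, then asynchronous D2H copies. A self-contained sketch of the same pattern (toy shapes, CPU tensors for illustration only):

import torch
import torch.nn.functional as F

# Toy stand-ins: 4 sampled token positions over a vocab of 32000
# (in the sampler these logits live on the GPU).
logits = torch.randn(4, 32000)

# Normalize in float32 for numerical stability, then take the k most
# likely tokens per row in a single call.
logprobs = F.log_softmax(logits.float(), dim=-1)
topk_vals, topk_indices = torch.topk(logprobs, k=3, dim=-1)

# In the sampler these are non_blocking device-to-host transfers; the
# sampler_event recorded in sample_async is what makes the values safe
# to read later in handle_logprobs.
topk_vals = topk_vals.to(device="cpu", non_blocking=True)
topk_indices = topk_indices.to(device="cpu", non_blocking=True)

print(topk_vals.shape, topk_indices.shape)  # torch.Size([4, 3]) twice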
@@ -1160,12 +1171,18 @@ def _sample_batched_by_strategy(
         # softmax_grp_indices: Indices of 'speculation_group_indices' entries requesting probs
         # speculation_softmax_indices: Indices of 'softmax_grp_indices' entries corresponding
         #   to requests with draft logits.
-        if log_probs_host is not None:
+        if log_probs_host:
             softmax_req_indices = group_req_indices
             softmax_grp_indices = torch.arange(len(group_req_indices),
                                                dtype=torch.int32)
             speculation_softmax_indices = torch.tensor(
                 speculation_group_indices, dtype=torch.int32)
+            for req_id in group_req_indices:
+                req = requests[req_id]
+                req.py_topk_logprobs_vals = topk_vals[
+                    logits_cuda_indexer[req_id], :req.py_num_logprobs]
+                req.py_topk_logprobs_indices = topk_indices[
+                    logits_cuda_indexer[req_id], :req.py_num_logprobs]
         else:
             speculation_group_indices_tensor = torch.tensor(
                 speculation_group_indices, dtype=torch.int32)
@@ -1257,7 +1274,7 @@ def _unbatch_sampling_results(
         new_tokens_cuda: torch.Tensor,
         req_num_steps: torch.Tensor,
         seq_slots: torch.Tensor,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
     ) -> torch.Tensor:
         beam = self.BEAM
         assert beam == 0, "beam_width != 1 not supported"
@@ -1274,17 +1291,6 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
         # Assert destination tensor dimensions are canonically ordered ("row"-major); this
         # matters for element ordering in the .view(...).scatter_(...) calls below.
         assert _dims_canonically_ordered(new_tokens_cuda)
-        assert log_probs_host is None or _dims_canonically_ordered(
-            log_probs_host)
-
-        # new_tokens_cuda indexed by
-        #   slice(0, steps), slot, beam
-        # log_probs_host indexed by
-        #   slot, beam, slice(0, steps)
-        # batch_... tensors indexed by slice(batch_req_index, batch_req_index + steps)
-        #
-        if log_probs_host is not None:
-            assert new_tokens_cuda.size(0) == log_probs_host.size(-2)

         # Construct index mapping from slice indices of computed tensors
         # (packed request_idx and step dimensions) to linearized indices
@@ -1306,39 +1312,7 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
             0, batch_dest_indices_1d_cuda,
             batch_next_tokens_cuda_int)
         new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True)
-        # NB: In order to avoid a scatter_ on the host and the necessary D2H copy + synchronization,
-        # the 'step' and 'seq_slot' dimensions are unpacked on GPU and later asynchronously
-        # copied into the destination buffer. Note that this overwrites all 'step' and token slots for the
-        # requests in 'requests' (passed to _process_requests). In fact, the current implementation
-        # even overwrites the destination tensors completely (including slices corresponding to request
-        # slots not present in 'requests', cf. 'FIXME' below).
-        if log_probs_host is not None:
-            # FIXME: If log_probs_host were indexed by request indices, rather than request slots, this
-            # tensor could be packed densely along the request axis.
-            log_probs_cuda = torch.empty_like(
-                log_probs_host, device=batch_dest_indices_1d_cuda.device)
-            # FIXME: Needs a separate indexer because tensor layout differs from new_tokens_cuda
-            batch_dest_probs_cuda_indexer = _UnpackedStepIndexer(
-                seq_slots=seq_slots[batch_req_indices],
-                num_steps=req_num_steps[batch_req_indices],
-                steps_dim_size=new_tokens_cuda.size(0),
-                slots_dim_size=new_tokens_cuda.size(1),
-                dim_order=_UnpackedStepIndexer.DimOrder.SLOT_MAJOR,
-                index_dtype=torch.int64,  # enforced by Tensor.scatter_
-            )
-            batch_dest_probs_indices_cuda = batch_dest_probs_cuda_indexer[:].to(
-                batch_softmax_cuda.device, non_blocking=True)
-            # NB: torch.arange is needed to enable "advanced indexing",
-            # cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing
-            batch_token_probs = batch_softmax_cuda[
-                torch.arange(batch_softmax_cuda.size(0),
-                             device=batch_softmax_cuda.device,
-                             dtype=torch.int32), batch_next_tokens_cuda_int]
-            log_probs_cuda[:, beam,
-                           ...].view(-1, *log_probs_cuda.shape[3:]).scatter_(
-                               0, batch_dest_probs_indices_cuda,
-                               torch.log(batch_token_probs))
-            log_probs_host.copy_(log_probs_cuda, non_blocking=True)
+
         # For requests with LlmRequest.py_draft_logits, return py_target_probs
         for request, batch_softmax_index_cuda in py_draft_logits_indices:
             request.py_target_probs = batch_softmax_cuda[
@@ -1481,7 +1455,6 @@ def _process_requests(

         logits_cuda = self._apply_min_length_penalty(logits_cuda, requests,
                                                      req_num_steps_list)
-
         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
             logits_cuda,

tensorrt_llm/executor/base_worker.py

Lines changed: 1 addition & 0 deletions
@@ -480,6 +480,7 @@ def _deduce_max_tokens(request: GenerationRequest,
             context_phase_params=context_phase_params,
             type=request_type,
             cache_salt_id=request.cache_salt_id)
+        executor_request.py_num_logprobs = request.sampling_params.logprobs
         executor_request.py_lora_path = py_lora_path

         if self._is_pytorch_backend and request.multimodal_params is not None:

tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 4 deletions
@@ -598,10 +598,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                          is_gen_only: bool) -> None:

         if self.args.backend in ["pytorch", "_autodeploy"]:
-            if sampling_params.logprobs and sampling_params.logprobs > 1:
-                raise ValueError(
-                    f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
-                )
             # Check prompt length and query length against max_num_tokens to filter illegal requests.
             # Skip check for gen-only requests
             if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:

tensorrt_llm/scaffolding/worker.py

Lines changed: 2 additions & 1 deletion
@@ -180,7 +180,8 @@ def convert_task_params(self, task: GenerationTask):
             temperature=task.temperature,
             top_p=task.top_p,
             top_k=task.top_k,
-            return_context_logits=task.return_context_logits)
+            return_context_logits=task.return_context_logits,
+            logprobs=task.num_logprobs)
         return sampling_params

     async def generation_handler(self, task: GenerationTask) -> TaskStatus:

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 25 additions & 0 deletions
@@ -175,6 +175,31 @@ def test_llm_reward_model():
     assert not outputs[0].outputs[0].text


+def test_llm_topk_logprobs():
+    topk_logprobs = 3
+    max_tokens = 10
+    llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config)
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     logprobs=topk_logprobs)
+    outputs = llm.generate(prompts, sampling_params)
+    logprobs = outputs[0].outputs[0].logprobs
+
+    assert len(logprobs) == max_tokens
+    for step_logprobs in logprobs:
+        assert len(step_logprobs) == topk_logprobs
+
+        logprob_items = [(logprob_obj.logprob, logprob_obj.rank)
+                         for logprob_obj in step_logprobs.values()]
+        sorted_by_rank = sorted(logprob_items, key=lambda x: x[1])
+
+        for i in range(len(sorted_by_rank) - 1):
+            current_logprob, current_rank = sorted_by_rank[i]
+            next_logprob, next_rank = sorted_by_rank[i + 1]
+            assert current_logprob >= next_logprob
+            assert current_rank == i + 1
+            assert next_rank == current_rank + 1
+
+
 def test_llm_perf_metrics():
     llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config)
     sampling_params = SamplingParams(max_tokens=10, return_perf_metrics=True)