Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 88a9c4f

Browse files
yibinl-nvidia and tijyojwad
authored and committed
[cherry-pick] [TRTLLM-8031][feat] Add chunked return_generation_logits logic (#7831)
Signed-off-by: Yibin Li <yibinl@nvidia.com>
1 parent 394c2ab commit 88a9c4f

File tree

4 files changed

+451
-27
lines changed

4 files changed

+451
-27
lines changed

‎tensorrt_llm/_torch/pyexecutor/handle_logits.py‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
fromitertoolsimportchain
12
fromtypingimportList
23

34
importtorch
@@ -64,3 +65,8 @@ def __call__(
6465
logits_view=logits[logits_begin:logits_end].reshape(
6566
1,beam_width,-1)
6667
llm_req.py_result.append_generation_logits(logits_view)
68+
69+
# Finalize any remaining logits transfers for all requests in chunked mode
70+
forllm_reqinchain(context_requests,generation_requests):
71+
ifllm_req.py_use_chunked_generation_logitsandllm_req.py_return_generation_logits:
72+
llm_req.py_result.transfer_remaining_device_logits()

‎tensorrt_llm/_torch/pyexecutor/llm_request.py‎

Lines changed: 121 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,35 @@
4040

4141
classLogitsStorage:
4242

43-
def__init__(self,
44-
seq_length:int,
45-
use_device_memory=True,
46-
should_exclude_last=False):
43+
def__init__(
44+
self,
45+
seq_length:int,
46+
use_device_memory=True,
47+
should_exclude_last=False,
48+
use_chunked_generation_logits=False,
49+
chunk_size=8
50+
):# logic adapted from HandleGenerationLogits.cpp to use chunked transfer
4751
ifshould_exclude_last:
4852
# Exclude last logits is used when overlap scheduler is used, that generates one extra token,
4953
# so we should make sure there's memory for that extra +1.
5054
seq_length+=1
5155
self.seq_length=seq_length
5256
self.use_device_memory=use_device_memory
5357
self._should_exclude_last=should_exclude_last
58+
self.use_chunked_generation_logits=use_chunked_generation_logits
59+
self.chunk_size=chunk_size
5460
self._logits_indices= []
5561

5662
# Lazily initialized by _init() upon first append()
5763
self._storage:torch.Tensor|None=None
5864
self.beam_width=-1
5965
self.vocab_size=-1
6066

67+
# Chunked mode: device-side fragments
68+
ifuse_chunked_generation_logits:
69+
self._device_fragments:List[torch.Tensor]= []
70+
self._current_position=0
71+
6172
def_init(self,logits:torch.Tensor):
6273
_,self.beam_width,self.vocab_size=logits.shape
6374

@@ -75,30 +86,51 @@ def _init(self, logits: torch.Tensor):
7586
pin_memory=True,
7687
requires_grad=False)
7788

89+
def_init_chunked_storage(self,logits:torch.Tensor):
90+
# with chunked mode, we only use cpu memory
91+
_,self.beam_width,self.vocab_size=logits.shape
92+
93+
self._storage=torch.empty(
94+
(self.seq_length,self.beam_width,self.vocab_size),
95+
dtype=logits.dtype,
96+
device='cpu',
97+
pin_memory=True,
98+
requires_grad=False)
99+
78100
defappend(self,logits:torch.Tensor):
79101
iflogits.ndim==2:
80102
logits=logits.unsqueeze(1)
81103
assertlogits.ndim==3,f"Bad logits shape, expect [num_tokens, beam_width, vocab_size], got{logits.shape}"
82104

83-
ifself.beam_width==-1:
84-
self._init(logits)
105+
ifself.use_chunked_generation_logits:
106+
ifself.beam_width==-1:
107+
self._init_chunked_storage(logits)
108+
self._add_fragment(logits)
109+
else:
110+
ifself.beam_width==-1:
111+
self._init(logits)
85112

86-
assertlogits.size(1)==self.beam_width,"Beam width mismatch"
113+
assertlogits.size(1)==self.beam_width,"Beam width mismatch"
87114

88-
position=0ifnotself._logits_indiceselseself._logits_indices[-1][1]
89-
new_position=logits.size(0)+position
90-
ifnew_position>self.seq_length:
91-
raiseValueError(
92-
f"LogitsStorage overflow. This storage can only hold{self.seq_length} logits "
93-
f"({position} already filled) but trying to append{logits.size(0)} more logits"
94-
)
115+
position=0ifnotself._logits_indiceselseself._logits_indices[
116+
-1][1]
117+
new_position=logits.size(0)+position
118+
ifnew_position>self.seq_length:
119+
raiseValueError(
120+
f"LogitsStorage overflow. This storage can only hold{self.seq_length} logits "
121+
f"({position} already filled) but trying to append{logits.size(0)} more logits"
122+
)
95123

96-
self._storage[position:new_position].copy_(logits,non_blocking=True)
97-
self._logits_indices.append((position,new_position))
124+
self._storage[position:new_position].copy_(logits,
125+
non_blocking=True)
126+
self._logits_indices.append((position,new_position))
98127

99128
defget(self,all_logits:bool)->torch.Tensor|None:
100129
"""Returns the used logits storage if there are any, otherwise, returns None.
101130
When all_logits is True then all set logits are returned, otherwise, only the last logits are returned."""
131+
ifself._storageisNone:
132+
returnNone
133+
102134
try:
103135
last=-2ifself._should_exclude_lastelse-1
104136
start=0ifall_logitselseself._logits_indices[last][0]
@@ -107,6 +139,41 @@ def get(self, all_logits: bool) -> torch.Tensor | None:
107139
exceptIndexError:
108140
returnNone
109141

142+
def_add_fragment(self,logits:torch.Tensor):
143+
"""Add a logits fragment to device storage"""
144+
self._device_fragments.append(logits.clone())
145+
146+
# Streaming mode: transfer immediately after each fragment (self.chunk_size=1).
147+
# Non-streaming mode: batch transfer every chunk_size steps.
148+
iflen(self._device_fragments)==self.chunk_size:
149+
self._transfer_chunk_to_host()
150+
151+
def_transfer_chunk_to_host(self):
152+
"""Transfer accumulated fragments to host"""
153+
ifnotself._device_fragments:
154+
return
155+
156+
# Allocate host storage if needed
157+
assertself._storageisnotNone,"Storage should be initialized"
158+
159+
# Merge fragments on device first
160+
merged_logits=torch.cat(self._device_fragments,dim=0)
161+
162+
# Copy to host (device-to-host transfer)
163+
end_pos=self._current_position+len(self._device_fragments)
164+
self._storage[self._current_position:end_pos].copy_(merged_logits,
165+
non_blocking=True)
166+
167+
# Update position and clear fragments
168+
self._logits_indices.append((self._current_position,end_pos))
169+
self._current_position=end_pos
170+
self._device_fragments.clear()
171+
172+
deffinalize_chunked_transfer(self):
173+
"""Force transfer of any remaining fragments to host (for chunked mode)"""
174+
ifself.use_chunked_generation_logitsandself._device_fragments:
175+
self._transfer_chunk_to_host()
176+
110177
defset_exclude_last(self,should_exclude_last:bool)->None:
111178
self._should_exclude_last=should_exclude_last
112179

@@ -164,13 +231,25 @@ def __init__(self,
164231
return_log_probs:bool=False,
165232
return_context_logits:bool=False,
166233
return_generation_logits:bool=False,
167-
exclude_last_generation_logits:bool=False):
234+
exclude_last_generation_logits:bool=False,
235+
use_chunked_generation_logits:bool=True,
236+
chunk_size:int=8):
237+
ifstreaminganduse_chunked_generation_logits:
238+
assertchunk_size==1,"chunk_size must be 1 in streaming mode"
168239
self._streaming=streaming
240+
self._chunk_size=chunk_size
241+
242+
# Note that in the C++ implementation both context logits and generation logits are stored in host memory.
243
# Here we only use host memory for generation logits if in chunked mode.
169244
self._context_logits=LogitsStorage(
170-
prompt_len,use_device_memory)ifreturn_context_logitselseNone
245+
prompt_len,use_device_memory,use_chunked_generation_logits=False
246+
)ifreturn_context_logitselseNone
171247
self._generation_logits=LogitsStorage(
172-
max_new_tokens,use_device_memory,exclude_last_generation_logits
173-
)ifreturn_generation_logitselseNone
248+
max_new_tokens,
249+
use_device_memory,
250+
exclude_last_generation_logits,
251+
use_chunked_generation_logits=use_chunked_generation_logits,
252+
chunk_size=self._chunk_size)ifreturn_generation_logitselseNone
174253
self._log_probs=LogProbStorage()ifreturn_log_probselseNone
175254

176255
defappend_context_logits(self,context_logits:torch.Tensor):
@@ -187,6 +266,11 @@ def append_log_probs(self,
187266
ifself._log_probs:
188267
self._log_probs.append(log_probs,cum_log_probs)
189268

269+
deftransfer_remaining_device_logits(self):
270+
"""Finalize any remaining generation logits transfers (for chunked mode)"""
271+
ifself._generation_logits:
272+
self._generation_logits.finalize_chunked_transfer()
273+
190274
defset_log_probs(self,log_probs:list[TokenLogprobs],
191275
cum_log_probs:list[float]):
192276
"""
@@ -292,6 +376,8 @@ def __init__(
292376
llm_request:Optional[
293377
tensorrt_llm.bindings.internal.batch_manager.LlmRequest]=None,
294378
is_draft:bool=False,
379+
use_chunked_generation_logits:bool=True,
380+
logits_chunk_size:int=8,
295381
**kwargs):
296382

297383
self.py_logits_post_processors=kwargs.pop("py_logits_post_processors",
@@ -339,15 +425,25 @@ def __init__(
339425
self.py_is_draft=is_draft
340426
self.py_seq_slot=None
341427

428+
# Chunked logits parameters
429+
self.py_use_chunked_generation_logits=use_chunked_generation_logits
430+
self.py_logits_chunk_size=logits_chunk_sizeifnotself.streamingelse1
431+
342432
# TODO: remove this when use DynamicDecodeOp in pytorch flow.
343433
# currently, keep py_stop_words_list as python list, rather than tensor.
344434
self.py_stop_words_list=stop_words_list
345435

346-
self.py_result=PyResult(self.py_prompt_len,self.py_max_new_tokens,
347-
return_logits_device_memory,self.streaming,
348-
return_log_probs,return_context_logits,
349-
return_generation_logits,
350-
exclude_last_generation_logits)
436+
self.py_result=PyResult(
437+
self.py_prompt_len,
438+
self.py_max_new_tokens,
439+
return_logits_device_memory,
440+
self.streaming,
441+
return_log_probs,
442+
return_context_logits,
443+
return_generation_logits,
444+
exclude_last_generation_logits,
445+
use_chunked_generation_logits=self.py_use_chunked_generation_logits,
446+
chunk_size=self.py_logits_chunk_size)
351447
self.child_requests= []
352448

353449
self._py_embedding_bias_1d=None

‎tests/unittest/_torch/modeling/test_modeling_nemotron_h.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@
1111

1212
defget_logprobs(token_ids:torch.Tensor,logits:torch.Tensor)->torch.Tensor:
1313
raw_probs=torch.softmax(logits,dim=-1)
14-
index=token_ids.unsqueeze(1).cuda()
14+
index=token_ids.unsqueeze(1)
15+
assertindex.device==raw_probs.device,f"index and raw_probs should be on the same device, but got index location:{index.device}, raw_probs location:{raw_probs.device}"
1516
token_probs=torch.gather(raw_probs,dim=1,index=index).squeeze(-1)
1617
returntorch.log(token_probs)
1718

1819

1920
defextract_prefill_logprobs(result:RequestOutput)->torch.Tensor:
2021
token_ids=torch.tensor(result.prompt_token_ids[1:])
2122
logits=result.context_logits[:-1, :]
22-
returnget_logprobs(token_ids,logits)
23+
returnget_logprobs(token_ids.cuda(),logits)
2324

2425

2526
defextract_decode_logprobs(result:RequestOutput,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp