4040
4141class LogitsStorage :
4242
43- def __init__ (self ,
44- seq_length :int ,
45- use_device_memory = True ,
46- should_exclude_last = False ):
43+ def __init__ (
44+ self ,
45+ seq_length :int ,
46+ use_device_memory = True ,
47+ should_exclude_last = False ,
48+ use_chunked_generation_logits = False ,
49+ chunk_size = 8
50+ ):# logic adpted from HandleGenerationLogits.cpp to use chunked transfer
4751if should_exclude_last :
4852# Exclude last logits is used when overlap scheduler is used, that generates one extra token,
4953# so we should make sure there's memory for that extra +1.
5054seq_length += 1
5155self .seq_length = seq_length
5256self .use_device_memory = use_device_memory
5357self ._should_exclude_last = should_exclude_last
58+ self .use_chunked_generation_logits = use_chunked_generation_logits
59+ self .chunk_size = chunk_size
5460self ._logits_indices = []
5561
5662# Lazily initialized by _init() upon first append()
5763self ._storage :torch .Tensor | None = None
5864self .beam_width = - 1
5965self .vocab_size = - 1
6066
67+ # Chunked mode: device-side fragments
68+ if use_chunked_generation_logits :
69+ self ._device_fragments :List [torch .Tensor ]= []
70+ self ._current_position = 0
71+
6172def _init (self ,logits :torch .Tensor ):
6273_ ,self .beam_width ,self .vocab_size = logits .shape
6374
@@ -75,30 +86,51 @@ def _init(self, logits: torch.Tensor):
7586pin_memory = True ,
7687requires_grad = False )
7788
def _init_chunked_storage(self, logits: torch.Tensor):
    """Allocate the pinned-host buffer used in chunked mode.

    In chunked mode the consolidated storage always lives in CPU (pinned)
    memory; beam width and vocab size are taken from the first fragment.
    """
    shape = logits.shape
    self.beam_width = shape[1]
    self.vocab_size = shape[2]
    full_shape = (self.seq_length, self.beam_width, self.vocab_size)
    self._storage = torch.empty(full_shape,
                                dtype=logits.dtype,
                                device='cpu',
                                pin_memory=True,
                                requires_grad=False)
78100def append (self ,logits :torch .Tensor ):
79101if logits .ndim == 2 :
80102logits = logits .unsqueeze (1 )
81103assert logits .ndim == 3 ,f"Bad logits shape, expect [num_tokens, beam_width, vocab_size], got{ logits .shape } "
82104
83- if self .beam_width == - 1 :
84- self ._init (logits )
105+ if self .use_chunked_generation_logits :
106+ if self .beam_width == - 1 :
107+ self ._init_chunked_storage (logits )
108+ self ._add_fragment (logits )
109+ else :
110+ if self .beam_width == - 1 :
111+ self ._init (logits )
85112
86- assert logits .size (1 )== self .beam_width ,"Beam width mismatch"
113+ assert logits .size (1 )== self .beam_width ,"Beam width mismatch"
87114
88- position = 0 if not self ._logits_indices else self ._logits_indices [- 1 ][1 ]
89- new_position = logits .size (0 )+ position
90- if new_position > self .seq_length :
91- raise ValueError (
92- f"LogitsStorage overflow. This storage can only hold{ self .seq_length } logits "
93- f"({ position } already filled) but trying to append{ logits .size (0 )} more logits"
94- )
115+ position = 0 if not self ._logits_indices else self ._logits_indices [
116+ - 1 ][1 ]
117+ new_position = logits .size (0 )+ position
118+ if new_position > self .seq_length :
119+ raise ValueError (
120+ f"LogitsStorage overflow. This storage can only hold{ self .seq_length } logits "
121+ f"({ position } already filled) but trying to append{ logits .size (0 )} more logits"
122+ )
95123
96- self ._storage [position :new_position ].copy_ (logits ,non_blocking = True )
97- self ._logits_indices .append ((position ,new_position ))
124+ self ._storage [position :new_position ].copy_ (logits ,
125+ non_blocking = True )
126+ self ._logits_indices .append ((position ,new_position ))
98127
99128def get (self ,all_logits :bool )-> torch .Tensor | None :
100129"""Returns the used logits storage if there are any, otherwise, returns None.
101130 When all_logits is True then all set logits are returned, otherwise, only the last logits are returned."""
131+ if self ._storage is None :
132+ return None
133+
102134try :
103135last = - 2 if self ._should_exclude_last else - 1
104136start = 0 if all_logits else self ._logits_indices [last ][0 ]
@@ -107,6 +139,41 @@ def get(self, all_logits: bool) -> torch.Tensor | None:
107139except IndexError :
108140return None
109141
142+ def _add_fragment (self ,logits :torch .Tensor ):
143+ """Add a logits fragment to device storage"""
144+ self ._device_fragments .append (logits .clone ())
145+
146+ # Streaming mode: transfer immediately after each fragment (self.chunk_size=1).
147+ # Non-streaming mode: batch transfer every chunk_size steps.
148+ if len (self ._device_fragments )== self .chunk_size :
149+ self ._transfer_chunk_to_host ()
150+
151+ def _transfer_chunk_to_host (self ):
152+ """Transfer accumulated fragments to host"""
153+ if not self ._device_fragments :
154+ return
155+
156+ # Allocate host storage if needed
157+ assert self ._storage is not None ,"Storage should be initialized"
158+
159+ # Merge fragments on device first
160+ merged_logits = torch .cat (self ._device_fragments ,dim = 0 )
161+
162+ # Copy to host (device-to-host transfer)
163+ end_pos = self ._current_position + len (self ._device_fragments )
164+ self ._storage [self ._current_position :end_pos ].copy_ (merged_logits ,
165+ non_blocking = True )
166+
167+ # Update position and clear fragments
168+ self ._logits_indices .append ((self ._current_position ,end_pos ))
169+ self ._current_position = end_pos
170+ self ._device_fragments .clear ()
171+
172+ def finalize_chunked_transfer (self ):
173+ """Force transfer of any remaining fragments to host (for chunked mode)"""
174+ if self .use_chunked_generation_logits and self ._device_fragments :
175+ self ._transfer_chunk_to_host ()
176+
110177def set_exclude_last (self ,should_exclude_last :bool )-> None :
111178self ._should_exclude_last = should_exclude_last
112179
@@ -164,13 +231,25 @@ def __init__(self,
164231return_log_probs :bool = False ,
165232return_context_logits :bool = False ,
166233return_generation_logits :bool = False ,
167- exclude_last_generation_logits :bool = False ):
234+ exclude_last_generation_logits :bool = False ,
235+ use_chunked_generation_logits :bool = True ,
236+ chunk_size :int = 8 ):
237+ if streaming and use_chunked_generation_logits :
238+ assert chunk_size == 1 ,"chunk_size must be 1 in streaming mode"
168239self ._streaming = streaming
240+ self ._chunk_size = chunk_size
241+
242+ # Note that in C++ implementation both context logits and generation logits are stored on host memory.
243+ # Here we only use host memory for generation logits if in chunked mode.
169244self ._context_logits = LogitsStorage (
170- prompt_len ,use_device_memory )if return_context_logits else None
245+ prompt_len ,use_device_memory ,use_chunked_generation_logits = False
246+ )if return_context_logits else None
171247self ._generation_logits = LogitsStorage (
172- max_new_tokens ,use_device_memory ,exclude_last_generation_logits
173- )if return_generation_logits else None
248+ max_new_tokens ,
249+ use_device_memory ,
250+ exclude_last_generation_logits ,
251+ use_chunked_generation_logits = use_chunked_generation_logits ,
252+ chunk_size = self ._chunk_size )if return_generation_logits else None
174253self ._log_probs = LogProbStorage ()if return_log_probs else None
175254
176255def append_context_logits (self ,context_logits :torch .Tensor ):
@@ -187,6 +266,11 @@ def append_log_probs(self,
187266if self ._log_probs :
188267self ._log_probs .append (log_probs ,cum_log_probs )
189268
269+ def transfer_remaining_device_logits (self ):
270+ """Finalize any remaining generation logits transfers (for chunked mode)"""
271+ if self ._generation_logits :
272+ self ._generation_logits .finalize_chunked_transfer ()
273+
190274def set_log_probs (self ,log_probs :list [TokenLogprobs ],
191275cum_log_probs :list [float ]):
192276"""
@@ -292,6 +376,8 @@ def __init__(
292376llm_request :Optional [
293377tensorrt_llm .bindings .internal .batch_manager .LlmRequest ]= None ,
294378is_draft :bool = False ,
379+ use_chunked_generation_logits :bool = True ,
380+ logits_chunk_size :int = 8 ,
295381** kwargs ):
296382
297383self .py_logits_post_processors = kwargs .pop ("py_logits_post_processors" ,
@@ -339,15 +425,25 @@ def __init__(
339425self .py_is_draft = is_draft
340426self .py_seq_slot = None
341427
428+ # Chunked logits parameters
429+ self .py_use_chunked_generation_logits = use_chunked_generation_logits
430+ self .py_logits_chunk_size = logits_chunk_size if not self .streaming else 1
431+
342432# TODO: remove this when use DynamicDecodeOp in pytorch flow.
343433# currently, keep py_stop_words_list as python list, rather than tensor.
344434self .py_stop_words_list = stop_words_list
345435
346- self .py_result = PyResult (self .py_prompt_len ,self .py_max_new_tokens ,
347- return_logits_device_memory ,self .streaming ,
348- return_log_probs ,return_context_logits ,
349- return_generation_logits ,
350- exclude_last_generation_logits )
436+ self .py_result = PyResult (
437+ self .py_prompt_len ,
438+ self .py_max_new_tokens ,
439+ return_logits_device_memory ,
440+ self .streaming ,
441+ return_log_probs ,
442+ return_context_logits ,
443+ return_generation_logits ,
444+ exclude_last_generation_logits ,
445+ use_chunked_generation_logits = self .py_use_chunked_generation_logits ,
446+ chunk_size = self .py_logits_chunk_size )
351447self .child_requests = []
352448
353449self ._py_embedding_bias_1d = None