@@ -47,7 +47,7 @@ def __init__(self,
                  seq_length: int,
                  use_device_memory=True,
                  should_exclude_last=False,
-                 chunked_mode=False,
+                 use_chunked_logits=False,
                  streaming=False,
                  chunk_size=8):
         if should_exclude_last:
@@ -57,7 +57,7 @@ def __init__(self,
         self.seq_length = seq_length
         self.use_device_memory = use_device_memory
         self._should_exclude_last = should_exclude_last
-        self.chunked_mode = chunked_mode
+        self.use_chunked_logits = use_chunked_logits
         self.chunk_size = chunk_size
         self.streaming = streaming
         self._logits_indices = []
@@ -68,7 +68,7 @@ def __init__(self,
         self.vocab_size = -1

         # Chunked mode: device-side fragments
-        if chunked_mode:
+        if use_chunked_logits:
             self._device_fragments: List[torch.Tensor] = []
             self._current_position = 0

@@ -105,7 +105,7 @@ def append(self, logits: torch.Tensor):
             logits = logits.unsqueeze(1)
         assert logits.ndim == 3, f"Bad logits shape, expect [num_tokens, beam_width, vocab_size], got {logits.shape}"

-        if self.chunked_mode:
+        if self.use_chunked_logits:
             if self.beam_width == -1:
                 self._init_chunked_storage(logits)
             self._add_fragment(logits)
@@ -181,7 +181,7 @@ def _transfer_chunk_to_host(self):

     def finalize_transfer(self):
         """Force transfer of any remaining fragments to host (for chunked mode)"""
-        if self.chunked_mode and hasattr(
+        if self.use_chunked_logits and hasattr(
                 self, '_device_fragments') and self._device_fragments:
             self._transfer_chunk_to_host()

@@ -243,20 +243,20 @@ def __init__(self,
                  return_context_logits: bool = False,
                  return_generation_logits: bool = False,
                  exclude_last_generation_logits: bool = False,
-                 chunked_mode: bool = False,
+                 use_chunked_logits: bool = True,
                  chunk_size: int = 8):
         self._streaming = streaming
         self._context_logits = LogitsStorage(
             prompt_len,
             use_device_memory,
-            chunked_mode=chunked_mode,
+            use_chunked_logits=use_chunked_logits,
             streaming=streaming,
             chunk_size=chunk_size) if return_context_logits else None
         self._generation_logits = LogitsStorage(
             max_new_tokens,
             use_device_memory,
             exclude_last_generation_logits,
-            chunked_mode=chunked_mode,
+            use_chunked_logits=use_chunked_logits,
             streaming=streaming,
             chunk_size=chunk_size) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
@@ -394,7 +394,7 @@ def __init__(
             is_draft: bool = False,
             seq_slot: Optional[int] = None,
             target_seq_slot: Optional[int] = None,
-            use_chunked_logits: bool = False,
+            use_chunked_logits: bool = True,
             logits_chunk_size: int = 8,
             **kwargs):

@@ -466,7 +466,7 @@ def __init__(
             return_context_logits,
             return_generation_logits,
             exclude_last_generation_logits,
-            chunked_mode=use_chunked_logits,
+            use_chunked_logits=use_chunked_logits,
             chunk_size=logits_chunk_size)
         self.child_requests = []
