@@ -45,7 +45,7 @@ def __init__(self,
4545seq_length :int ,
4646use_device_memory = True ,
4747should_exclude_last = False ,
48- chunked_mode = False ,
48+ use_chunked_logits = False ,
4949streaming = False ,
5050chunk_size = 8 ):
5151if should_exclude_last :
@@ -55,7 +55,7 @@ def __init__(self,
5555self .seq_length = seq_length
5656self .use_device_memory = use_device_memory
5757self ._should_exclude_last = should_exclude_last
58- self .chunked_mode = chunked_mode
58+ self .use_chunked_logits = use_chunked_logits
5959self .chunk_size = chunk_size
6060self .streaming = streaming
6161self ._logits_indices = []
@@ -66,7 +66,7 @@ def __init__(self,
6666self .vocab_size = - 1
6767
6868# Chunked mode: device-side fragments
69- if chunked_mode :
69+ if use_chunked_logits :
7070self ._device_fragments :List [torch .Tensor ]= []
7171self ._current_position = 0
7272
@@ -103,7 +103,7 @@ def append(self, logits: torch.Tensor):
103103logits = logits .unsqueeze (1 )
104104assert logits .ndim == 3 ,f"Bad logits shape, expect [num_tokens, beam_width, vocab_size], got{ logits .shape } "
105105
106- if self .chunked_mode :
106+ if self .use_chunked_logits :
107107if self .beam_width == - 1 :
108108self ._init_chunked_storage (logits )
109109self ._add_fragment (logits )
@@ -179,7 +179,7 @@ def _transfer_chunk_to_host(self):
179179
180180def finalize_transfer (self ):
181181"""Force transfer of any remaining fragments to host (for chunked mode)"""
182- if self .chunked_mode and hasattr (
182+ if self .use_chunked_logits and hasattr (
183183self ,'_device_fragments' )and self ._device_fragments :
184184self ._transfer_chunk_to_host ()
185185
@@ -241,20 +241,20 @@ def __init__(self,
241241return_context_logits :bool = False ,
242242return_generation_logits :bool = False ,
243243exclude_last_generation_logits :bool = False ,
244- chunked_mode :bool = False ,
244+ use_chunked_logits :bool = True ,
245245chunk_size :int = 8 ):
246246self ._streaming = streaming
247247self ._context_logits = LogitsStorage (
248248prompt_len ,
249249use_device_memory ,
250- chunked_mode = chunked_mode ,
250+ use_chunked_logits = use_chunked_logits ,
251251streaming = streaming ,
252252chunk_size = chunk_size )if return_context_logits else None
253253self ._generation_logits = LogitsStorage (
254254max_new_tokens ,
255255use_device_memory ,
256256exclude_last_generation_logits ,
257- chunked_mode = chunked_mode ,
257+ use_chunked_logits = use_chunked_logits ,
258258streaming = streaming ,
259259chunk_size = chunk_size )if return_generation_logits else None
260260self ._log_probs = LogProbStorage ()if return_log_probs else None
@@ -392,7 +392,7 @@ def __init__(
392392is_draft :bool = False ,
393393seq_slot :Optional [int ]= None ,
394394target_seq_slot :Optional [int ]= None ,
395- use_chunked_logits :bool = False ,
395+ use_chunked_logits :bool = True ,
396396logits_chunk_size :int = 8 ,
397397** kwargs ):
398398
@@ -464,7 +464,7 @@ def __init__(
464464return_context_logits ,
465465return_generation_logits ,
466466exclude_last_generation_logits ,
467- chunked_mode = use_chunked_logits ,
467+ use_chunked_logits = use_chunked_logits ,
468468chunk_size = logits_chunk_size )
469469self .child_requests = []
470470