1- #!/usr/bin/env python3
21"""
32Unit tests for chunked logits functionality in TensorRT-LLM.
43
@@ -30,7 +29,7 @@ def chunked_request():
3029sampling_config = SamplingConfig (),
3130is_streaming = False ,
3231return_generation_logits = True ,
33- use_chunked_logits = True ,
32+ use_chunked_generation_logits = True ,
3433logits_chunk_size = 4 )
3534
3635
@@ -43,7 +42,7 @@ def non_chunked_request():
4342sampling_config = SamplingConfig (),
4443is_streaming = False ,
4544return_generation_logits = True ,
46- use_chunked_logits = False )
45+ use_chunked_generation_logits = False )
4746
4847
4948# Test parameters
@@ -62,13 +61,13 @@ def test_initialization(self):
6261storage = LogitsStorage (seq_length = 10 ,
6362use_device_memory = True ,
6463should_exclude_last = False ,
65- use_chunked_logits = False ,
64+ use_chunked_generation_logits = False ,
6665chunk_size = 8 )
6766
6867assert storage .seq_length == 10
6968assert storage .use_device_memory is True
7069assert storage ._should_exclude_last is False
71- assert storage .use_chunked_logits is False
70+ assert storage .use_chunked_generation_logits is False
7271assert storage .chunk_size == 8
7372assert storage ._logits_indices == []
7473assert storage .beam_width == - 1
@@ -77,10 +76,10 @@ def test_initialization(self):
7776def test_initialization_chunked_mode (self ):
7877"""Test LogitsStorage initialization in chunked mode"""
7978storage = LogitsStorage (seq_length = 10 ,
80- use_chunked_logits = True ,
79+ use_chunked_generation_logits = True ,
8180chunk_size = 4 )
8281
83- assert storage .use_chunked_logits is True
82+ assert storage .use_chunked_generation_logits is True
8483assert storage .chunk_size == 4
8584assert hasattr (storage ,'_device_fragments' )
8685assert hasattr (storage ,'_current_position' )
@@ -89,23 +88,23 @@ def test_initialization_chunked_mode(self):
8988
9089def test_append_3d_logits (self ,sample_logits ):
9190"""Test appending 3D logits"""
92- storage = LogitsStorage (seq_length = 10 ,use_chunked_logits = False )
91+ storage = LogitsStorage (seq_length = 10 ,use_chunked_generation_logits = False )
9392storage .append (sample_logits )
9493
9594assert storage .beam_width == 1
9695assert storage .vocab_size == 1000
9796
9897def test_append_invalid_shape (self ):
9998"""Test appending logits with invalid shape"""
100- storage = LogitsStorage (seq_length = 10 ,use_chunked_logits = False )
99+ storage = LogitsStorage (seq_length = 10 ,use_chunked_generation_logits = False )
101100
102101with pytest .raises (AssertionError ):
103102storage .append (torch .randn (1000 ))# 1D - should fail
104103
105104def test_append_chunked_mode_streaming (self ,sample_logits ):
106105"""Test append behavior in chunked streaming mode"""
107106storage = LogitsStorage (seq_length = 10 ,
108- use_chunked_logits = True ,
107+ use_chunked_generation_logits = True ,
109108chunk_size = 1 )
110109storage .append (sample_logits )
111110
@@ -116,7 +115,7 @@ def test_append_chunked_mode_streaming(self, sample_logits):
116115def test_append_chunked_mode_non_streaming (self ,sample_logits ):
117116"""Test append behavior in chunked non-streaming mode"""
118117storage = LogitsStorage (seq_length = 10 ,
119- use_chunked_logits = True ,
118+ use_chunked_generation_logits = True ,
120119chunk_size = 2 )
121120
122121# Add first fragment
@@ -131,7 +130,7 @@ def test_append_chunked_mode_non_streaming(self, sample_logits):
131130def test_finalize_transfer_chunked_mode (self ,sample_logits ):
132131"""Test finalize_transfer in chunked mode"""
133132storage = LogitsStorage (seq_length = 10 ,
134- use_chunked_logits = True ,
133+ use_chunked_generation_logits = True ,
135134chunk_size = 5 )
136135storage .append (sample_logits )
137136
@@ -145,14 +144,14 @@ def test_finalize_transfer_chunked_mode(self, sample_logits):
145144
146145def test_finalize_transfer_non_chunked_mode (self ):
147146"""Test finalize_transfer in non-chunked mode (should be no-op)"""
148- storage = LogitsStorage (seq_length = 10 ,use_chunked_logits = False )
147+ storage = LogitsStorage (seq_length = 10 ,use_chunked_generation_logits = False )
149148
150149# Should not raise any errors
151150storage .finalize_transfer ()
152151
153152def test_storage_overflow (self ,sample_logits ):
154153"""Test storage overflow handling"""
155- storage = LogitsStorage (seq_length = 2 ,use_chunked_logits = False )
154+ storage = LogitsStorage (seq_length = 2 ,use_chunked_generation_logits = False )
156155storage .append (sample_logits )
157156storage .append (sample_logits )
158157
@@ -173,7 +172,7 @@ def test_initialization(self):
173172return_context_logits = True ,
174173return_generation_logits = True ,
175174exclude_last_generation_logits = False ,
176- use_chunked_logits = True ,
175+ use_chunked_generation_logits = True ,
177176chunk_size = 4 )
178177
179178assert result ._streaming is False
@@ -198,7 +197,7 @@ def test_post_processing_transfer(self, sample_logits):
198197result = PyResult (prompt_len = 5 ,
199198max_new_tokens = 10 ,
200199return_generation_logits = True ,
201- use_chunked_logits = True )
200+ use_chunked_generation_logits = True )
202201
203202result .append_generation_logits (sample_logits )
204203result .post_processing_transfer ()
@@ -210,7 +209,8 @@ def test_context_generation_logits_property(self, sample_logits):
210209result = PyResult (prompt_len = 5 ,
211210max_new_tokens = 10 ,
212211return_context_logits = True ,
213- use_chunked_logits = False )
212+ return_generation_logits = True ,
213+ use_chunked_generation_logits = False )
214214
215215result .append_context_logits (sample_logits )
216216context_logits = result .context_logits
@@ -225,20 +225,6 @@ def test_context_generation_logits_property(self, sample_logits):
225225assert generation_logits .shape == (1 ,1 ,1000
226226 )# Should transpose dimensions
227227
228- def test_generation_logits_property_streaming (self ,sample_logits ):
229- """Test generation_logits property in streaming mode"""
230- result = PyResult (prompt_len = 5 ,
231- max_new_tokens = 10 ,
232- return_generation_logits = True ,
233- use_chunked_logits = False ,
234- streaming = True )
235-
236- result .append_generation_logits (sample_logits )
237- generation_logits = result .generation_logits
238-
239- assert generation_logits is not None
240- assert generation_logits .shape == (1 ,1 ,1000 )
241-
242228
243229class TestLlmRequest :
244230"""Unit tests for LlmRequest class"""
@@ -278,7 +264,7 @@ def test_chunked_vs_non_chunked_equivalence(self, sample_logits):
278264sampling_config = SamplingConfig (),
279265is_streaming = False ,
280266return_generation_logits = True ,
281- use_chunked_logits = True ,
267+ use_chunked_generation_logits = True ,
282268logits_chunk_size = 2 )
283269
284270# Create non-chunked request
@@ -288,7 +274,7 @@ def test_chunked_vs_non_chunked_equivalence(self, sample_logits):
288274sampling_config = SamplingConfig (),
289275is_streaming = False ,
290276return_generation_logits = True ,
291- use_chunked_logits = False )
277+ use_chunked_generation_logits = False )
292278
293279# Add same logits to both
294280for _ in range (5 ):
@@ -319,7 +305,7 @@ def test_streaming_vs_non_streaming_behavior(self, sample_logits):
319305sampling_config = SamplingConfig (),
320306is_streaming = True ,
321307return_generation_logits = True ,
322- use_chunked_logits = True ,
308+ use_chunked_generation_logits = True ,
323309logits_chunk_size = 3 )
324310
325311# Create non-streaming request
@@ -329,7 +315,7 @@ def test_streaming_vs_non_streaming_behavior(self, sample_logits):
329315sampling_config = SamplingConfig (),
330316is_streaming = False ,
331317return_generation_logits = True ,
332- use_chunked_logits = True ,
318+ use_chunked_generation_logits = True ,
333319logits_chunk_size = 3 )
334320
335321# Add logits one by one
@@ -375,7 +361,7 @@ def test_memory_management(self, sample_logits):
375361sampling_config = SamplingConfig (),
376362is_streaming = False ,
377363return_generation_logits = True ,
378- use_chunked_logits = True ,
364+ use_chunked_generation_logits = True ,
379365logits_chunk_size = 2 ,
380366return_logits_device_memory = False # Use host memory
381367 )
@@ -402,7 +388,7 @@ def test_large_sequence_handling(self):
402388sampling_config = SamplingConfig (),
403389is_streaming = False ,
404390return_generation_logits = True ,
405- use_chunked_logits = True ,
391+ use_chunked_generation_logits = True ,
406392logits_chunk_size = 10 )
407393
408394# Add many logits
@@ -447,7 +433,7 @@ def get_memory_usage():
447433sampling_config = SamplingConfig (),
448434is_streaming = False ,
449435return_generation_logits = True ,
450- use_chunked_logits = True ,
436+ use_chunked_generation_logits = True ,
451437logits_chunk_size = 5 ,
452438return_logits_device_memory = False )
453439
@@ -464,7 +450,7 @@ def get_memory_usage():
464450sampling_config = SamplingConfig (),
465451is_streaming = False ,
466452return_generation_logits = True ,
467- use_chunked_logits = False ,
453+ use_chunked_generation_logits = False ,
468454return_logits_device_memory = False )
469455
470456for _ in range (50 ):