Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 572ae3f

Browse files
committed
chunked generation logics
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent 0c80d1d commit 572ae3f

File tree

1 file changed

+25
-39
lines changed

1 file changed

+25
-39
lines changed

‎tests/unittest/_torch/executor/test_chunked_logits.py‎

Lines changed: 25 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python3
21
"""
32
Unit tests for chunked logits functionality in TensorRT-LLM.
43
@@ -30,7 +29,7 @@ def chunked_request():
3029
sampling_config=SamplingConfig(),
3130
is_streaming=False,
3231
return_generation_logits=True,
33-
use_chunked_logits=True,
32+
use_chunked_generation_logits=True,
3433
logits_chunk_size=4)
3534

3635

@@ -43,7 +42,7 @@ def non_chunked_request():
4342
sampling_config=SamplingConfig(),
4443
is_streaming=False,
4544
return_generation_logits=True,
46-
use_chunked_logits=False)
45+
use_chunked_generation_logits=False)
4746

4847

4948
# Test parameters
@@ -62,13 +61,13 @@ def test_initialization(self):
6261
storage=LogitsStorage(seq_length=10,
6362
use_device_memory=True,
6463
should_exclude_last=False,
65-
use_chunked_logits=False,
64+
use_chunked_generation_logits=False,
6665
chunk_size=8)
6766

6867
assertstorage.seq_length==10
6968
assertstorage.use_device_memoryisTrue
7069
assertstorage._should_exclude_lastisFalse
71-
assertstorage.use_chunked_logitsisFalse
70+
assertstorage.use_chunked_generation_logitsisFalse
7271
assertstorage.chunk_size==8
7372
assertstorage._logits_indices== []
7473
assertstorage.beam_width==-1
@@ -77,10 +76,10 @@ def test_initialization(self):
7776
deftest_initialization_chunked_mode(self):
7877
"""Test LogitsStorage initialization in chunked mode"""
7978
storage=LogitsStorage(seq_length=10,
80-
use_chunked_logits=True,
79+
use_chunked_generation_logits=True,
8180
chunk_size=4)
8281

83-
assertstorage.use_chunked_logitsisTrue
82+
assertstorage.use_chunked_generation_logitsisTrue
8483
assertstorage.chunk_size==4
8584
asserthasattr(storage,'_device_fragments')
8685
asserthasattr(storage,'_current_position')
@@ -89,23 +88,23 @@ def test_initialization_chunked_mode(self):
8988

9089
deftest_append_3d_logits(self,sample_logits):
9190
"""Test appending 3D logits"""
92-
storage=LogitsStorage(seq_length=10,use_chunked_logits=False)
91+
storage=LogitsStorage(seq_length=10,use_chunked_generation_logits=False)
9392
storage.append(sample_logits)
9493

9594
assertstorage.beam_width==1
9695
assertstorage.vocab_size==1000
9796

9897
deftest_append_invalid_shape(self):
9998
"""Test appending logits with invalid shape"""
100-
storage=LogitsStorage(seq_length=10,use_chunked_logits=False)
99+
storage=LogitsStorage(seq_length=10,use_chunked_generation_logits=False)
101100

102101
withpytest.raises(AssertionError):
103102
storage.append(torch.randn(1000))# 1D - should fail
104103

105104
deftest_append_chunked_mode_streaming(self,sample_logits):
106105
"""Test append behavior in chunked streaming mode"""
107106
storage=LogitsStorage(seq_length=10,
108-
use_chunked_logits=True,
107+
use_chunked_generation_logits=True,
109108
chunk_size=1)
110109
storage.append(sample_logits)
111110

@@ -116,7 +115,7 @@ def test_append_chunked_mode_streaming(self, sample_logits):
116115
deftest_append_chunked_mode_non_streaming(self,sample_logits):
117116
"""Test append behavior in chunked non-streaming mode"""
118117
storage=LogitsStorage(seq_length=10,
119-
use_chunked_logits=True,
118+
use_chunked_generation_logits=True,
120119
chunk_size=2)
121120

122121
# Add first fragment
@@ -131,7 +130,7 @@ def test_append_chunked_mode_non_streaming(self, sample_logits):
131130
deftest_finalize_transfer_chunked_mode(self,sample_logits):
132131
"""Test finalize_transfer in chunked mode"""
133132
storage=LogitsStorage(seq_length=10,
134-
use_chunked_logits=True,
133+
use_chunked_generation_logits=True,
135134
chunk_size=5)
136135
storage.append(sample_logits)
137136

@@ -145,14 +144,14 @@ def test_finalize_transfer_chunked_mode(self, sample_logits):
145144

146145
deftest_finalize_transfer_non_chunked_mode(self):
147146
"""Test finalize_transfer in non-chunked mode (should be no-op)"""
148-
storage=LogitsStorage(seq_length=10,use_chunked_logits=False)
147+
storage=LogitsStorage(seq_length=10,use_chunked_generation_logits=False)
149148

150149
# Should not raise any errors
151150
storage.finalize_transfer()
152151

153152
deftest_storage_overflow(self,sample_logits):
154153
"""Test storage overflow handling"""
155-
storage=LogitsStorage(seq_length=2,use_chunked_logits=False)
154+
storage=LogitsStorage(seq_length=2,use_chunked_generation_logits=False)
156155
storage.append(sample_logits)
157156
storage.append(sample_logits)
158157

@@ -173,7 +172,7 @@ def test_initialization(self):
173172
return_context_logits=True,
174173
return_generation_logits=True,
175174
exclude_last_generation_logits=False,
176-
use_chunked_logits=True,
175+
use_chunked_generation_logits=True,
177176
chunk_size=4)
178177

179178
assertresult._streamingisFalse
@@ -198,7 +197,7 @@ def test_post_processing_transfer(self, sample_logits):
198197
result=PyResult(prompt_len=5,
199198
max_new_tokens=10,
200199
return_generation_logits=True,
201-
use_chunked_logits=True)
200+
use_chunked_generation_logits=True)
202201

203202
result.append_generation_logits(sample_logits)
204203
result.post_processing_transfer()
@@ -210,7 +209,8 @@ def test_context_generation_logits_property(self, sample_logits):
210209
result=PyResult(prompt_len=5,
211210
max_new_tokens=10,
212211
return_context_logits=True,
213-
use_chunked_logits=False)
212+
return_generation_logits=True,
213+
use_chunked_generation_logits=False)
214214

215215
result.append_context_logits(sample_logits)
216216
context_logits=result.context_logits
@@ -225,20 +225,6 @@ def test_context_generation_logits_property(self, sample_logits):
225225
assertgeneration_logits.shape== (1,1,1000
226226
)# Should transpose dimensions
227227

228-
deftest_generation_logits_property_streaming(self,sample_logits):
229-
"""Test generation_logits property in streaming mode"""
230-
result=PyResult(prompt_len=5,
231-
max_new_tokens=10,
232-
return_generation_logits=True,
233-
use_chunked_logits=False,
234-
streaming=True)
235-
236-
result.append_generation_logits(sample_logits)
237-
generation_logits=result.generation_logits
238-
239-
assertgeneration_logitsisnotNone
240-
assertgeneration_logits.shape== (1,1,1000)
241-
242228

243229
classTestLlmRequest:
244230
"""Unit tests for LlmRequest class"""
@@ -278,7 +264,7 @@ def test_chunked_vs_non_chunked_equivalence(self, sample_logits):
278264
sampling_config=SamplingConfig(),
279265
is_streaming=False,
280266
return_generation_logits=True,
281-
use_chunked_logits=True,
267+
use_chunked_generation_logits=True,
282268
logits_chunk_size=2)
283269

284270
# Create non-chunked request
@@ -288,7 +274,7 @@ def test_chunked_vs_non_chunked_equivalence(self, sample_logits):
288274
sampling_config=SamplingConfig(),
289275
is_streaming=False,
290276
return_generation_logits=True,
291-
use_chunked_logits=False)
277+
use_chunked_generation_logits=False)
292278

293279
# Add same logits to both
294280
for_inrange(5):
@@ -319,7 +305,7 @@ def test_streaming_vs_non_streaming_behavior(self, sample_logits):
319305
sampling_config=SamplingConfig(),
320306
is_streaming=True,
321307
return_generation_logits=True,
322-
use_chunked_logits=True,
308+
use_chunked_generation_logits=True,
323309
logits_chunk_size=3)
324310

325311
# Create non-streaming request
@@ -329,7 +315,7 @@ def test_streaming_vs_non_streaming_behavior(self, sample_logits):
329315
sampling_config=SamplingConfig(),
330316
is_streaming=False,
331317
return_generation_logits=True,
332-
use_chunked_logits=True,
318+
use_chunked_generation_logits=True,
333319
logits_chunk_size=3)
334320

335321
# Add logits one by one
@@ -375,7 +361,7 @@ def test_memory_management(self, sample_logits):
375361
sampling_config=SamplingConfig(),
376362
is_streaming=False,
377363
return_generation_logits=True,
378-
use_chunked_logits=True,
364+
use_chunked_generation_logits=True,
379365
logits_chunk_size=2,
380366
return_logits_device_memory=False# Use host memory
381367
)
@@ -402,7 +388,7 @@ def test_large_sequence_handling(self):
402388
sampling_config=SamplingConfig(),
403389
is_streaming=False,
404390
return_generation_logits=True,
405-
use_chunked_logits=True,
391+
use_chunked_generation_logits=True,
406392
logits_chunk_size=10)
407393

408394
# Add many logits
@@ -447,7 +433,7 @@ def get_memory_usage():
447433
sampling_config=SamplingConfig(),
448434
is_streaming=False,
449435
return_generation_logits=True,
450-
use_chunked_logits=True,
436+
use_chunked_generation_logits=True,
451437
logits_chunk_size=5,
452438
return_logits_device_memory=False)
453439

@@ -464,7 +450,7 @@ def get_memory_usage():
464450
sampling_config=SamplingConfig(),
465451
is_streaming=False,
466452
return_generation_logits=True,
467-
use_chunked_logits=False,
453+
use_chunked_generation_logits=False,
468454
return_logits_device_memory=False)
469455

470456
for_inrange(50):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp