@@ -20,52 +20,9 @@ def sample_logits():
2020return torch .randn (1 ,1 ,1000 ,device = 'cuda' )
2121
2222
23- @pytest .fixture
24- def chunked_request ():
25- """Create LlmRequest with chunked logits enabled"""
26- return LlmRequest (request_id = 100 ,
27- max_new_tokens = 10 ,
28- input_tokens = [1 ,2 ,3 ],
29- sampling_config = SamplingConfig (),
30- is_streaming = False ,
31- return_generation_logits = True ,
32- use_chunked_generation_logits = True ,
33- logits_chunk_size = 4 )
34-
35-
36- @pytest .fixture
37- def non_chunked_request ():
38- """Create LlmRequest with chunked logits disabled"""
39- return LlmRequest (request_id = 101 ,
40- max_new_tokens = 10 ,
41- input_tokens = [1 ,2 ,3 ],
42- sampling_config = SamplingConfig (),
43- is_streaming = False ,
44- return_generation_logits = True ,
45- use_chunked_generation_logits = False )
46-
47-
4823class TestLogitsStorage :
4924"""Unit tests for LogitsStorage class"""
5025
51- def test_initialization (self ):
52- """Test LogitsStorage initialization with different parameters"""
53- # Test basic initialization
54- storage = LogitsStorage (seq_length = 10 ,
55- use_device_memory = True ,
56- should_exclude_last = False ,
57- use_chunked_generation_logits = False ,
58- chunk_size = 8 )
59-
60- assert storage .seq_length == 10
61- assert storage .use_device_memory is True
62- assert storage ._should_exclude_last is False
63- assert storage .use_chunked_generation_logits is False
64- assert storage .chunk_size == 8
65- assert storage ._logits_indices == []
66- assert storage .beam_width == - 1
67- assert storage .vocab_size == - 1
68-
6926def test_initialization_chunked_mode (self ):
7027"""Test LogitsStorage initialization in chunked mode"""
7128storage = LogitsStorage (seq_length = 10 ,
@@ -79,23 +36,6 @@ def test_initialization_chunked_mode(self):
7936assert storage ._device_fragments == []
8037assert storage ._current_position == 0
8138
82- def test_append_3d_logits (self ,sample_logits ):
83- """Test appending 3D logits"""
84- storage = LogitsStorage (seq_length = 10 ,
85- use_chunked_generation_logits = False )
86- storage .append (sample_logits )
87-
88- assert storage .beam_width == 1
89- assert storage .vocab_size == 1000
90-
91- def test_append_invalid_shape (self ):
92- """Test appending logits with invalid shape"""
93- storage = LogitsStorage (seq_length = 10 ,
94- use_chunked_generation_logits = False )
95-
96- with pytest .raises (AssertionError ):
97- storage .append (torch .randn (1000 ))# 1D - should fail
98-
9939def test_append_chunked_mode_streaming (self ,sample_logits ):
10040"""Test append behavior in chunked streaming mode"""
10141storage = LogitsStorage (seq_length = 10 ,
@@ -145,17 +85,6 @@ def test_finalize_chunked_transfer_non_chunked_mode(self):
14585# Should not raise any errors
14686storage .finalize_chunked_transfer ()
14787
148- def test_storage_overflow (self ,sample_logits ):
149- """Test storage overflow handling"""
150- storage = LogitsStorage (seq_length = 2 ,
151- use_chunked_generation_logits = False )
152- storage .append (sample_logits )
153- storage .append (sample_logits )
154-
155- # This should cause overflow
156- with pytest .raises (ValueError ,match = "LogitsStorage overflow" ):
157- storage .append (sample_logits )
158-
15988
16089class TestPyResult :
16190"""Unit tests for PyResult class"""
@@ -202,27 +131,6 @@ def test_transfer_remaining_device_logits(self, sample_logits):
202131
203132# Should not raise errors
204133
205- def test_context_generation_logits_property (self ,sample_logits ):
206- """Test context_logits property"""
207- result = PyResult (prompt_len = 5 ,
208- max_new_tokens = 10 ,
209- return_context_logits = True ,
210- return_generation_logits = True ,
211- use_chunked_generation_logits = False )
212-
213- result .append_context_logits (sample_logits )
214- context_logits = result .context_logits
215-
216- assert context_logits is not None
217- assert context_logits .shape == (1 ,1000 )# Should remove beam dimension
218-
219- result .append_generation_logits (sample_logits )
220- generation_logits = result .generation_logits
221-
222- assert generation_logits is not None
223- assert generation_logits .shape == (1 ,1 ,1000
224- )# Should transpose dimensions
225-
226134
227135class TestLlmRequest :
228136"""Unit tests for LlmRequest class"""
@@ -251,7 +159,7 @@ def test_initialization_chunked_logits(self):
251159assert request_streaming .py_logits_chunk_size == 1 # 1 in streaming mode
252160
253161
254- class TestChunkedLogitsIntegration :
162+ class TestChunkedLogitsComplicated :
255163"""Integration tests for chunked logits functionality"""
256164
257165def test_chunked_vs_non_chunked_equivalence (self ,sample_logits ):