 from tensorrt_llm.llmapi.llm_utils import BuildConfig, KvCacheConfig
 
 prompts = ["A B C"]
-global_kvcache_config = KvCacheConfig(max_tokens=10000)
+global_kvcache_config = KvCacheConfig(
+    max_tokens=10000,
+    enable_block_reuse=True,
+)
 
 
-@force_ampere  # Save H100 resource
-@pytest.mark.parametrize("return_log_probs", [False, True])
-@pytest.mark.parametrize("gather_generation_logits", [False, True])
-@pytest.mark.parametrize("gather_context_logits", [False, True])
-@pytest.mark.parametrize("sampler_type", ["TRTLLMSampler", "TorchSampler"])
-@pytest.mark.parametrize("disable_overlap_scheduler", [False, True])
-def test_generate_with_return_logits(disable_overlap_scheduler: bool,
-                                     sampler_type: str,
-                                     gather_context_logits: bool,
-                                     gather_generation_logits: bool,
-                                     return_log_probs: bool):
-    if not (gather_context_logits or gather_generation_logits
-            or return_log_probs):  # prune space
-        pytest.skip("Nothing to test")
+@pytest.fixture(scope="module", params=[False, True])
+def gather_generation_logits_fixture(request) -> bool:
+    return request.param
+
+
+@pytest.fixture(scope="module", params=[False, True])
+def gather_context_logits_fixture(request) -> bool:
+    return request.param
+
+
+@pytest.fixture(scope="module", params=[False, True])
+def disable_overlap_scheduler_fixture(request) -> bool:
+    return request.param
+
+
+@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"])
+def sampler_type_fixture(request) -> str:
+    return request.param
+
+
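+# Supplies cache_salt values: requests sharing a salt may reuse each other's
+# KV-cache blocks, while a unique salt keeps a request's cache isolated.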
+class CacheSalter:
+
+    _salt = 0
+
+    @classmethod
+    def get_salt_unique(cls) -> str:
+        cls._salt += 1
+        return str(cls._salt)
+
+    @classmethod
+    def get_salt_shared(cls) -> str:
+        return str(0)
+
+    @classmethod
+    def get_salt(cls, reuse_cache: bool) -> str:
+        if reuse_cache:
+            salt = cls.get_salt_shared()
+        else:
+            salt = cls.get_salt_unique()
+        return salt
+
+
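+# Module-scoped LLM shared by all tests: one engine is built per combination of
+# the module-scoped fixture parameters instead of one per test invocation.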
+@pytest.fixture(scope="module")
+def llm(
+    gather_context_logits_fixture: bool,
+    gather_generation_logits_fixture: bool,
+    sampler_type_fixture: str,
+    disable_overlap_scheduler_fixture: bool,
+):
+    gather_context_logits = gather_context_logits_fixture
+    gather_generation_logits = gather_generation_logits_fixture
+    sampler_type = sampler_type_fixture
+    disable_overlap_scheduler = disable_overlap_scheduler_fixture
 
     build_config = BuildConfig()
     build_config.gather_context_logits = gather_context_logits
@@ -42,100 +84,156 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
         disable_overlap_scheduler=disable_overlap_scheduler,
     )
 
+    # FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178.
+    # Remove the patch below once fixed.
+    old_exit = LLM.__exit__
+
+    def _exit_with_xfail_on_timeout(self, exc_type, exc_value,
+                                    traceback) -> bool:
+        import _pytest.outcomes
+        try:
+            return old_exit(self, exc_type, exc_value, traceback)
+        except _pytest.outcomes.Failed as e:
+            if e.msg and "pytest-timeout" in e.msg.lower():
+                pytest.xfail(
+                    "Known LLM shutdown issue (https://nvbugs/5577178).")
+            else:
+                raise
+
+    with pytest.MonkeyPatch.context() as patch:
+        patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout)
+
+        with llm:
+            yield llm
+
+
+@force_ampere  # Save H100 resource
+@pytest.mark.parametrize("reuse_cache", [False, True])
+@pytest.mark.parametrize("return_log_probs", [False, True])
+# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
+# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
+@pytest.mark.timeout(120, method="signal")
+@pytest.mark.threadleak(enabled=False)
+def test_generate_with_return_logits(
+    llm,
+    gather_context_logits_fixture: bool,
+    gather_generation_logits_fixture: bool,
+    reuse_cache: bool,
+    return_log_probs: bool,
+):
+    gather_context_logits = gather_context_logits_fixture
+    gather_generation_logits = gather_generation_logits_fixture
+
+    if not (gather_context_logits or gather_generation_logits
+            or return_log_probs):  # prune space
+        pytest.skip("Nothing to test")
+
     sampling_params = SamplingParams(
         max_tokens=8,
         return_context_logits=gather_context_logits,
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs,
     )
 
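+    # With a shared salt (reuse_cache=True) this request may hit KV blocks
+    # cached by earlier requests; a unique salt forces a cold cache.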
-    with llm:
-        for output in llm.generate(prompts, sampling_params=sampling_params):
-            if gather_context_logits:
-                assert output.context_logits is not None
-                # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
-                expected_len = len(prompts[0].split()) + 1
+    for output in llm.generate(
+            prompts,
+            sampling_params=sampling_params,
+            cache_salt=CacheSalter.get_salt(reuse_cache),
+    ):
+        if gather_context_logits:
+            assert output.context_logits is not None
+            # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
+            expected_len = len(prompts[0].split()) + 1
+            try:
                 assert expected_len == output.context_logits.shape[0]
-            else:
-                assert output.context_logits is None
-
-            if gather_generation_logits:
-                gen_logits = output.outputs[0].generation_logits
-                assert gen_logits is not None
-                assert gen_logits.ndim == 2
-                assert gen_logits.shape[0] == sampling_params.max_tokens
-                assert torch.argmax(
-                    gen_logits, dim=1).tolist() == output.outputs[0].token_ids
-            else:
-                assert output.outputs[0].generation_logits is None
-
-            if return_log_probs:
-                assert len(
-                    output.outputs[0].logprobs) == sampling_params.max_tokens
-            else:
-                assert len(output.outputs[0].logprobs) == 0
+            except AssertionError:
+                # FIXME: Remove this once the bug has been fixed
+                if gather_context_logits and reuse_cache:
+                    pytest.xfail("Known bug: https://nvbugs/5577178")
+                raise
+        else:
+            assert output.context_logits is None
+
+        if gather_generation_logits:
+            gen_logits = output.outputs[0].generation_logits
+            assert gen_logits is not None
+            assert gen_logits.ndim == 2
+            assert gen_logits.shape[0] == sampling_params.max_tokens
+            assert torch.argmax(gen_logits,
+                                dim=1).tolist() == output.outputs[0].token_ids
+        else:
+            assert output.outputs[0].generation_logits is None
+
+        if return_log_probs:
+            assert len(output.outputs[0].logprobs) == sampling_params.max_tokens
+        else:
+            assert len(output.outputs[0].logprobs) == 0
 
 
 @force_ampere  # Save H100 resource
+@pytest.mark.parametrize("reuse_cache", [False, True])
 @pytest.mark.parametrize("return_log_probs", [False, True])
-@pytest.mark.parametrize("gather_generation_logits", [False, True])
-@pytest.mark.parametrize("gather_context_logits", [False, True])
-@pytest.mark.parametrize("sampler_type", ["TRTLLMSampler", "TorchSampler"])
-@pytest.mark.parametrize("disable_overlap_scheduler", [False, True])
-def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
-                                           sampler_type: str,
-                                           gather_context_logits: bool,
-                                           gather_generation_logits: bool,
-                                           return_log_probs: bool):
+# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
+# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
+@pytest.mark.timeout(120, method="signal")
+@pytest.mark.threadleak(enabled=False)
+def test_generate_async_with_return_logits(
+    llm,
+    gather_context_logits_fixture: bool,
+    gather_generation_logits_fixture: bool,
+    reuse_cache: bool,
+    return_log_probs: bool,
+):
+    gather_context_logits = gather_context_logits_fixture
+    gather_generation_logits = gather_generation_logits_fixture
+
     if not (gather_context_logits or gather_generation_logits
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
 
-    build_config = BuildConfig()
-    build_config.gather_context_logits = gather_context_logits
-
-    llm = LLM(
-        model=os.path.join(llm_models_root(), "llama-models-v2",
-                           "TinyLlama-1.1B-Chat-v1.0"),
-        kv_cache_config=global_kvcache_config,
-        build_config=build_config,
-        gather_generation_logits=gather_generation_logits,
-        max_batch_size=
-        128,  # reduce buffer sizes, specially for generation logits
-        sampler_type=sampler_type,
-        disable_overlap_scheduler=disable_overlap_scheduler,
-    )
     sampling_params = SamplingParams(
         max_tokens=8,
         return_context_logits=gather_context_logits,
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs)
 
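+    # Streaming yields one partial result per generated token, so logprobs
+    # accumulate one entry per iteration.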
-    with llm:
-        for idx, output in enumerate(
-                llm.generate_async(prompts[0],
-                                   sampling_params=sampling_params,
-                                   streaming=True)):
-            if gather_context_logits:
-                assert output.context_logits is not None
-                # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
-                expected_len = len(prompts[0].split()) + 1
+    for idx, output in enumerate(
+            llm.generate_async(
+                prompts[0],
+                sampling_params=sampling_params,
+                streaming=True,
+                cache_salt=CacheSalter.get_salt(reuse_cache),
+            )):
+        if gather_context_logits:
+            assert output.context_logits is not None
+            # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
+            expected_len = len(prompts[0].split()) + 1
+            try:
                 assert expected_len == output.context_logits.shape[0]
-            else:
-                assert output.context_logits is None
-
-            if gather_generation_logits:
-                gen_logits = output.outputs[0].generation_logits
-                assert gen_logits is not None
-                assert gen_logits.ndim == 2
-                assert gen_logits.shape[0] == 1
+            except AssertionError:
+                # FIXME: Remove this once the bug has been fixed
+                if gather_context_logits and reuse_cache:
+                    pytest.xfail("Known bug: https://nvbugs/5577178")
+                raise
+        else:
+            assert output.context_logits is None
+
+        if gather_generation_logits:
+            gen_logits = output.outputs[0].generation_logits
+            assert gen_logits is not None
+            assert gen_logits.ndim == 2
+            assert gen_logits.shape[0] == 1
+            try:
                 assert torch.argmax(
                     gen_logits,
                     dim=1).tolist()[0] == output.outputs[0].token_ids[-1]
-            else:
-                assert output.outputs[0].generation_logits is None
-
-            if return_log_probs:
-                assert len(output.outputs[0].logprobs) == idx + 1
-            else:
-                assert len(output.outputs[0].logprobs) == 0
+            except AssertionError:
+                # FIXME: Remove xfail once the bug is fixed
+                pytest.xfail("Known bug: https://nvbugs/5573238")
+        else:
+            assert output.outputs[0].generation_logits is None
+
+        if return_log_probs:
+            assert len(output.outputs[0].logprobs) == idx + 1
+        else:
+            assert len(output.outputs[0].logprobs) == 0