Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 10fbddd

Browse files
committed
feat: add cache_salt in LLM.generate and refactor test_return_logits.py
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent 72d65d0 · commit 10fbddd

File tree

4 files changed

+193
-87
lines changed

4 files changed

+193
-87
lines changed

‎tensorrt_llm/llmapi/llm.py‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def generate(
262262
DisaggregatedParams,Sequence[DisaggregatedParams]]]=None,
263263
scheduling_params:Optional[Union[SchedulingParams,
264264
List[SchedulingParams]]]=None,
265+
cache_salt:Optional[Union[str,Sequence[str]]]=None,
265266
)->Union[RequestOutput,List[RequestOutput]]:
266267
"""Generate output for the given prompts in the synchronous mode.
267268
Synchronous generation accepts either single prompt or batched prompts.
@@ -282,6 +283,7 @@ def generate(
282283
Disaggregated parameters. Defaults to None.
283284
scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, List[tensorrt_llm.scheduling_params.SchedulingParams], optional):
284285
Scheduling parameters. Defaults to None.
286+
cache_salt (str, optional): If specified, KV cache will be salted with the provided string to limit the kv cache reuse to the requests with the same string. Defaults to None.
285287
Returns:
286288
Union[tensorrt_llm.llmapi.RequestOutput, List[tensorrt_llm.llmapi.RequestOutput]]: The output data of the completion request to the LLM.
287289
"""
@@ -312,7 +314,9 @@ def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
312314
i),
313315
disaggregated_params=_item_at(disaggregated_params,i),
314316
scheduling_params=_item_at(scheduling_params,i),
315-
streaming=False)
317+
cache_salt=_item_at(cache_salt,i),
318+
streaming=False,
319+
)
316320
futures.append(future)
317321

318322
forfutureintqdm(futures,

‎tests/integration/test_lists/test-db/l0_a30.yml‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ l0_a30:
2121
-unittest/_torch/modeling -k "modeling_out_of_tree"
2222
-unittest/_torch/auto_deploy/unit/singlegpu
2323
-unittest/_torch/sampler/test_beam_search.py
24+
-unittest/_torch/sampler/test_return_logits.py
2425
-test_e2e.py::test_openai_completions_with_logit_bias[torch_sampler]
2526
-test_e2e.py::test_openai_chat_with_logit_bias[torch_sampler]
2627
-test_e2e.py::test_openai_completions_with_logit_bias[trtllm_sampler]

‎tests/unittest/_torch/sampler/test_return_logits.py‎

Lines changed: 184 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,65 @@
99
fromtensorrt_llm.llmapi.llm_utilsimportBuildConfig,KvCacheConfig
1010

1111
prompts= ["A B C"]
12-
global_kvcache_config=KvCacheConfig(max_tokens=10000)
12+
global_kvcache_config=KvCacheConfig(
13+
max_tokens=10000,
14+
enable_block_reuse=True,
15+
)
1316

1417

15-
@force_ampere# Save H100 resource
16-
@pytest.mark.parametrize("return_log_probs", [False,True])
17-
@pytest.mark.parametrize("gather_generation_logits", [False,True])
18-
@pytest.mark.parametrize("gather_context_logits", [False,True])
19-
@pytest.mark.parametrize("sampler_type", ["TRTLLMSampler","TorchSampler"])
20-
@pytest.mark.parametrize("disable_overlap_scheduler", [False,True])
21-
deftest_generate_with_return_logits(disable_overlap_scheduler:bool,
22-
sampler_type:str,
23-
gather_context_logits:bool,
24-
gather_generation_logits:bool,
25-
return_log_probs:bool):
26-
ifnot (gather_context_logitsorgather_generation_logits
27-
orreturn_log_probs):# prune space
28-
pytest.skip("Nothing to test")
18+
@pytest.fixture(scope="module",params=[False,True])
19+
defgather_generation_logits_fixture(request)->bool:
20+
returnrequest.param
21+
22+
23+
@pytest.fixture(scope="module",params=[False,True])
24+
defgather_context_logits_fixture(request)->bool:
25+
returnrequest.param
26+
27+
28+
@pytest.fixture(scope="module",params=[False,True])
29+
defdisable_overlap_scheduler_fixture(request)->bool:
30+
returnrequest.param
31+
32+
33+
@pytest.fixture(scope="module",params=["TRTLLMSampler","TorchSampler"])
34+
defsampler_type_fixture(request)->str:
35+
returnrequest.param
36+
37+
38+
classCacheSalter:
39+
40+
_salt=0
41+
42+
@classmethod
43+
defget_salt_unique(cls)->str:
44+
cls._salt+=1
45+
returnstr(cls._salt)
46+
47+
@classmethod
48+
defget_salt_shared(cls)->str:
49+
returnstr(0)
50+
51+
@classmethod
52+
defget_salt(cls,reuse_cache:bool)->str:
53+
ifreuse_cache:
54+
salt=cls.get_salt_shared()
55+
else:
56+
salt=cls.get_salt_unique()
57+
returnsalt
58+
59+
60+
@pytest.fixture(scope="module")
61+
defllm(
62+
gather_context_logits_fixture:bool,
63+
gather_generation_logits_fixture:bool,
64+
sampler_type_fixture:str,
65+
disable_overlap_scheduler_fixture:bool,
66+
):
67+
gather_context_logits=gather_context_logits_fixture
68+
gather_generation_logits=gather_generation_logits_fixture
69+
sampler_type=sampler_type_fixture
70+
disable_overlap_scheduler=disable_overlap_scheduler_fixture
2971

3072
build_config=BuildConfig()
3173
build_config.gather_context_logits=gather_context_logits
@@ -42,100 +84,156 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
4284
disable_overlap_scheduler=disable_overlap_scheduler,
4385
)
4486

87+
# FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178.
88+
# Remove patch below once fixed.
89+
old_exit=LLM.__exit__
90+
91+
def_exit_with_xfail_on_timeout(self,exc_type,exc_value,
92+
traceback)->bool:
93+
import_pytest.outcomes
94+
try:
95+
returnold_exit(self,exc_type,exc_value,traceback)
96+
except_pytest.outcomes.Failedase:
97+
ife.msgand"pytest-timeout"ine.msg.lower():
98+
pytest.xfail(
99+
"Known LLM shutdown issue (https://nvbugs/5577178).")
100+
else:
101+
raise
102+
103+
withpytest.MonkeyPatch.context()aspatch:
104+
patch.setattr(LLM,"__exit__",_exit_with_xfail_on_timeout)
105+
106+
withllm:
107+
yieldllm
108+
109+
110+
@force_ampere# Save H100 resource
111+
@pytest.mark.parametrize("reuse_cache", [False,True])
112+
@pytest.mark.parametrize("return_log_probs", [False,True])
113+
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
114+
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
115+
@pytest.mark.timeout(120,method="signal")
116+
@pytest.mark.threadleak(enabled=False)
117+
deftest_generate_with_return_logits(
118+
llm,
119+
gather_context_logits_fixture:bool,
120+
gather_generation_logits_fixture:bool,
121+
reuse_cache:bool,
122+
return_log_probs:bool,
123+
):
124+
gather_context_logits=gather_context_logits_fixture
125+
gather_generation_logits=gather_generation_logits_fixture
126+
127+
ifnot (gather_context_logitsorgather_generation_logits
128+
orreturn_log_probs):# prune space
129+
pytest.skip("Nothing to test")
130+
45131
sampling_params=SamplingParams(
46132
max_tokens=8,
47133
return_context_logits=gather_context_logits,
48134
return_generation_logits=gather_generation_logits,
49135
logprobs=return_log_probs,
50136
)
51137

52-
withllm:
53-
foroutputinllm.generate(prompts,sampling_params=sampling_params):
54-
ifgather_context_logits:
55-
assertoutput.context_logitsisnotNone
56-
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
57-
expected_len=len(prompts[0].split())+1
138+
foroutputinllm.generate(
139+
prompts,
140+
sampling_params=sampling_params,
141+
cache_salt=CacheSalter.get_salt(reuse_cache),
142+
):
143+
ifgather_context_logits:
144+
assertoutput.context_logitsisnotNone
145+
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
146+
expected_len=len(prompts[0].split())+1
147+
try:
58148
assertexpected_len==output.context_logits.shape[0]
59-
else:
60-
assertoutput.context_logitsisNone
61-
62-
ifgather_generation_logits:
63-
gen_logits=output.outputs[0].generation_logits
64-
assertgen_logitsisnotNone
65-
assertgen_logits.ndim==2
66-
assertgen_logits.shape[0]==sampling_params.max_tokens
67-
asserttorch.argmax(
68-
gen_logits,dim=1).tolist()==output.outputs[0].token_ids
69-
else:
70-
assertoutput.outputs[0].generation_logitsisNone
71-
72-
ifreturn_log_probs:
73-
assertlen(
74-
output.outputs[0].logprobs)==sampling_params.max_tokens
75-
else:
76-
assertlen(output.outputs[0].logprobs)==0
149+
exceptAssertionError:
150+
# FIXME: Remove this once the bug has been fixed
151+
ifgather_context_logitsandreuse_cache:
152+
pytest.xfail("Known bug: https://nvbugs/5577178")
153+
raise
154+
else:
155+
assertoutput.context_logitsisNone
156+
157+
ifgather_generation_logits:
158+
gen_logits=output.outputs[0].generation_logits
159+
assertgen_logitsisnotNone
160+
assertgen_logits.ndim==2
161+
assertgen_logits.shape[0]==sampling_params.max_tokens
162+
asserttorch.argmax(gen_logits,
163+
dim=1).tolist()==output.outputs[0].token_ids
164+
else:
165+
assertoutput.outputs[0].generation_logitsisNone
166+
167+
ifreturn_log_probs:
168+
assertlen(output.outputs[0].logprobs)==sampling_params.max_tokens
169+
else:
170+
assertlen(output.outputs[0].logprobs)==0
77171

78172

79173
@force_ampere# Save H100 resource
174+
@pytest.mark.parametrize("reuse_cache", [False,True])
80175
@pytest.mark.parametrize("return_log_probs", [False,True])
81-
@pytest.mark.parametrize("gather_generation_logits", [False,True])
82-
@pytest.mark.parametrize("gather_context_logits", [False,True])
83-
@pytest.mark.parametrize("sampler_type", ["TRTLLMSampler","TorchSampler"])
84-
@pytest.mark.parametrize("disable_overlap_scheduler", [False,True])
85-
deftest_generate_async_with_return_logits(disable_overlap_scheduler:bool,
86-
sampler_type:str,
87-
gather_context_logits:bool,
88-
gather_generation_logits:bool,
89-
return_log_probs:bool):
176+
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
177+
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
178+
@pytest.mark.timeout(120,method="signal")
179+
@pytest.mark.threadleak(enabled=False)
180+
deftest_generate_async_with_return_logits(
181+
llm,
182+
gather_context_logits_fixture:bool,
183+
gather_generation_logits_fixture:bool,
184+
reuse_cache:bool,
185+
return_log_probs:bool,
186+
):
187+
gather_context_logits=gather_context_logits_fixture
188+
gather_generation_logits=gather_generation_logits_fixture
189+
90190
ifnot (gather_context_logitsorgather_generation_logits
91191
orreturn_log_probs):# prune space
92192
pytest.skip("Nothing to test")
93193

94-
build_config=BuildConfig()
95-
build_config.gather_context_logits=gather_context_logits
96-
97-
llm=LLM(
98-
model=os.path.join(llm_models_root(),"llama-models-v2",
99-
"TinyLlama-1.1B-Chat-v1.0"),
100-
kv_cache_config=global_kvcache_config,
101-
build_config=build_config,
102-
gather_generation_logits=gather_generation_logits,
103-
max_batch_size=
104-
128,# reduce buffer sizes, specially for generation logits
105-
sampler_type=sampler_type,
106-
disable_overlap_scheduler=disable_overlap_scheduler,
107-
)
108194
sampling_params=SamplingParams(
109195
max_tokens=8,
110196
return_context_logits=gather_context_logits,
111197
return_generation_logits=gather_generation_logits,
112198
logprobs=return_log_probs)
113199

114-
withllm:
115-
foridx,outputinenumerate(
116-
llm.generate_async(prompts[0],
117-
sampling_params=sampling_params,
118-
streaming=True)):
119-
ifgather_context_logits:
120-
assertoutput.context_logitsisnotNone
121-
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
122-
expected_len=len(prompts[0].split())+1
200+
foridx,outputinenumerate(
201+
llm.generate_async(
202+
prompts[0],
203+
sampling_params=sampling_params,
204+
streaming=True,
205+
cache_salt=CacheSalter.get_salt(reuse_cache),
206+
)):
207+
ifgather_context_logits:
208+
assertoutput.context_logitsisnotNone
209+
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
210+
expected_len=len(prompts[0].split())+1
211+
try:
123212
assertexpected_len==output.context_logits.shape[0]
124-
else:
125-
assertoutput.context_logitsisNone
126-
127-
ifgather_generation_logits:
128-
gen_logits=output.outputs[0].generation_logits
129-
assertgen_logitsisnotNone
130-
assertgen_logits.ndim==2
131-
assertgen_logits.shape[0]==1
213+
exceptAssertionError:
214+
# FIXME: Remove this once the bug has been fixed
215+
ifgather_context_logitsandreuse_cache:
216+
pytest.xfail("Known bug: https://nvbugs/5577178")
217+
raise
218+
else:
219+
assertoutput.context_logitsisNone
220+
221+
ifgather_generation_logits:
222+
gen_logits=output.outputs[0].generation_logits
223+
assertgen_logitsisnotNone
224+
assertgen_logits.ndim==2
225+
assertgen_logits.shape[0]==1
226+
try:
132227
asserttorch.argmax(
133228
gen_logits,
134229
dim=1).tolist()[0]==output.outputs[0].token_ids[-1]
135-
else:
136-
assertoutput.outputs[0].generation_logitsisNone
137-
138-
ifreturn_log_probs:
139-
assertlen(output.outputs[0].logprobs)==idx+1
140-
else:
141-
assertlen(output.outputs[0].logprobs)==0
230+
exceptAssertionError:
231+
# FIXME: Remove xfail once the bug is fixed
232+
pytest.xfail("Known bug: https://nvbugs/5573238")
233+
else:
234+
assertoutput.outputs[0].generation_logitsisNone
235+
236+
ifreturn_log_probs:
237+
assertlen(output.outputs[0].logprobs)==idx+1
238+
else:
239+
assertlen(output.outputs[0].logprobs)==0

‎tests/unittest/api_stability/references/llm.yaml‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ methods:
195195
scheduling_params:
196196
annotation:Union[tensorrt_llm.scheduling_params.SchedulingParams, List[tensorrt_llm.scheduling_params.SchedulingParams], NoneType]
197197
default:null
198+
cache_salt:
199+
annotation:Optional[str]
200+
default:null
198201
return_annotation:Union[tensorrt_llm.llmapi.llm.RequestOutput, List[tensorrt_llm.llmapi.llm.RequestOutput]]
199202
generate_async:
200203
parameters:

0 commit comments

Comments (0)

[8]ページ先頭

©2009-2025 Movatter.jp