Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 533add5

Browse files
authored
[TRTLLM-8598][feat] enable n > 1 in OpenAI API with PyTorch backend (#8951)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent 6ff82ea commit 533add5

File tree

4 files changed

+101
-19
lines changed

4 files changed

+101
-19
lines changed

‎tensorrt_llm/serve/chat_utils.py‎

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,6 @@ def parse_chat_messages_coroutines(
217217
),mm_placeholder_counts
218218

219219

220-
def check_multiple_response(n: int, backend: Optional[str]):
221-
    if n > 1 and backend == "pytorch":
222-
        raise ValueError(
223-
            "Multiple response is not supported in PyTorch workflow")
224-
225-
226220
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
227221
    if id_type == "kimi_k2":
228222
        return f"functions.{func_name}:{idx}"

‎tensorrt_llm/serve/openai_server.py‎

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@
3434
fromtensorrt_llm.llmapi.llmimportRequestOutput
3535
fromtensorrt_llm.loggerimportlogger
3636
fromtensorrt_llm.metrics.collectorimportMetricsCollector
37-
from tensorrt_llm.serve.chat_utils import (check_multiple_response,
38-
                                           parse_chat_messages_coroutines)
37+
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
3938
fromtensorrt_llm.serve.cluster_storageimportcreate_cluster_storage_client
4039
fromtensorrt_llm.serve.disagg_auto_scalingimportDisaggClusterWorker
4140
fromtensorrt_llm.serve.metadata_serverimportcreate_metadata_server
@@ -484,7 +483,6 @@ async def create_chat_response(
484483
returnchat_response
485484

486485
try:
487-
check_multiple_response(request.n,self.llm.args.backend)
488486
conversation:List[ConversationMessage]= []
489487
tool_dicts=Noneifrequest.toolsisNoneelse [
490488
tool.model_dump()fortoolinrequest.tools
@@ -595,7 +593,6 @@ async def create_mm_embedding_response(promise: RequestOutput):
595593
)
596594

597595
try:
598-
check_multiple_response(request.n,self.llm.args.backend)
599596
conversation:List[ConversationMessage]= []
600597
tool_dicts=Noneifrequest.toolsisNoneelse [
601598
tool.model_dump()fortoolinrequest.tools
@@ -730,7 +727,6 @@ async def generator_wrapper(generator: AsyncIterator[Any]):
730727
yield"data: [DONE]\n\n"
731728

732729
try:
733-
check_multiple_response(request.n,self.llm.args.backend)
734730
ifisinstance(request.prompt,str)or \
735731
(isinstance(request.prompt,list)andisinstance(request.prompt[0],int)):
736732
prompts= [request.prompt]

‎tests/unittest/llmapi/apps/_test_openai_chat.py‎

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
6868
temp_extra_llm_api_options_file:str,num_postprocess_workers:int):
6969
model_path=get_model_path(model_name)
7070
args= ["--backend",f"{backend}"]
71+
args.extend(["--kv_cache_free_gpu_memory_fraction",
72+
"0.2"])# for co-existence with other servers
7173
ifbackend=="trt":
7274
args.extend(["--max_beam_width","4"])
7375
ifextra_llm_api_options:
@@ -78,11 +80,34 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
7880
yieldremote_server
7981

8082

83+
@pytest.fixture(scope="module")
84+
defserver_with_beam_search(model_name:str,backend:str,
85+
extra_llm_api_options:bool,
86+
temp_extra_llm_api_options_file:str,
87+
num_postprocess_workers:int):
88+
model_path=get_model_path(model_name)
89+
args= ["--backend",f"{backend}"]
90+
args.extend(["--kv_cache_free_gpu_memory_fraction",
91+
"0.2"])# for co-existence with other servers
92+
args.extend(["--max_beam_width","2"])
93+
ifextra_llm_api_options:
94+
args.extend(
95+
["--extra_llm_api_options",temp_extra_llm_api_options_file])
96+
args.extend(["--num_postprocess_workers",f"{num_postprocess_workers}"])
97+
withRemoteOpenAIServer(model_path,args)asremote_server:
98+
yieldremote_server
99+
100+
81101
@pytest.fixture(scope="module")
82102
defclient(server:RemoteOpenAIServer):
83103
returnserver.get_client()
84104

85105

106+
@pytest.fixture(scope="module")
107+
defclient_with_beam_search(server_with_beam_search:RemoteOpenAIServer):
108+
returnserver_with_beam_search.get_client()
109+
110+
86111
@pytest.fixture(scope="module")
87112
defasync_client(server:RemoteOpenAIServer):
88113
returnserver.get_async_client()
@@ -180,7 +205,33 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
180205
backend:str):
181206
ifbackend=="pytorch":
182207
pytest.skip(
183-
"Multiple responses are not supported in PyTorch backend yet")
208+
"'n' not allowed with temperature=0 unless TLLM_ALLOW_N_GREEDY_DECODING=1"
209+
)
210+
messages= [{
211+
"role":"system",
212+
"content":"you are a helpful assistant"
213+
}, {
214+
"role":"user",
215+
"content":"what is 1+1?"
216+
}]
217+
# test n and best_of
218+
chat_completion=client.chat.completions.create(
219+
model=model_name,
220+
messages=messages,
221+
max_completion_tokens=10,
222+
n=2,
223+
temperature=0.0,
224+
extra_body=dict(best_of=4),
225+
)
226+
assertlen(chat_completion.choices)==2
227+
228+
229+
deftest_multiple_responses_and_beam_search(client:openai.OpenAI,
230+
model_name:str,backend:str):
231+
ifbackend=="pytorch":
232+
pytest.skip(
233+
"Mixing beam search and regular requests is not supported in PyTorch backend"
234+
)
184235

185236
messages= [{
186237
"role":"system",
@@ -202,6 +253,7 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
202253
assertchat_completion.choices[
203254
0].message.content!=chat_completion.choices[
204255
1].message.content,"beam search should be different"
256+
205257
# test n and best_of
206258
chat_completion=client.chat.completions.create(
207259
model=model_name,
@@ -214,6 +266,30 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
214266
assertlen(chat_completion.choices)==2
215267

216268

269+
deftest_multiple_responses_with_beam_search(
270+
client_with_beam_search:openai.OpenAI,model_name:str):
271+
messages= [{
272+
"role":"system",
273+
"content":"you are a helpful assistant"
274+
}, {
275+
"role":"user",
276+
"content":"what is 1+1?"
277+
}]
278+
# test beam search
279+
chat_completion=client_with_beam_search.chat.completions.create(
280+
model=model_name,
281+
messages=messages,
282+
max_completion_tokens=10,
283+
n=2,
284+
temperature=0.0,
285+
extra_body=dict(use_beam_search=True),
286+
)
287+
assertlen(chat_completion.choices)==2
288+
assertchat_completion.choices[
289+
0].message.content!=chat_completion.choices[
290+
1].message.content,"beam search should be different"
291+
292+
217293
@pytest.mark.asyncio(loop_scope="module")
218294
asyncdeftest_chat_streaming(async_client:openai.AsyncOpenAI,
219295
model_name:str):

‎tests/unittest/llmapi/apps/_test_openai_completions.py‎

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,21 @@ def num_postprocess_workers(request):
3333
defserver(model_name:str,backend:str,num_postprocess_workers:int):
3434
model_path=get_model_path(model_name)
3535
args= ["--backend",f"{backend}"]
36-
ifbackend=="trt":
37-
args.extend(["--max_beam_width","4"])
36+
args.extend(["--kv_cache_free_gpu_memory_fraction",
37+
"0.2"])# for co-existence with other servers
38+
args.extend(["--num_postprocess_workers",f"{num_postprocess_workers}"])
39+
withRemoteOpenAIServer(model_path,args)asremote_server:
40+
yieldremote_server
41+
42+
43+
@pytest.fixture(scope="module")
44+
defserver_with_beam_search(model_name:str,backend:str,
45+
num_postprocess_workers:int):
46+
model_path=get_model_path(model_name)
47+
args= ["--backend",f"{backend}"]
48+
args.extend(["--kv_cache_free_gpu_memory_fraction",
49+
"0.2"])# for co-existence with other servers
50+
args.extend(["--max_beam_width","2"])
3851
args.extend(["--num_postprocess_workers",f"{num_postprocess_workers}"])
3952
withRemoteOpenAIServer(model_path,args)asremote_server:
4053
yieldremote_server
@@ -50,6 +63,11 @@ def async_client(server: RemoteOpenAIServer):
5063
returnserver.get_async_client()
5164

5265

66+
@pytest.fixture(scope="module")
67+
defasync_client_with_beam_search(server_with_beam_search:RemoteOpenAIServer):
68+
returnserver_with_beam_search.get_async_client()
69+
70+
5371
deftest_single_completion(client:openai.OpenAI,model_name):
5472
completion=client.completions.create(
5573
model=model_name,
@@ -145,12 +163,10 @@ async def test_batch_completions(async_client: openai.AsyncOpenAI, model_name,
145163
@pytest.mark.asyncio(loop_scope="module")
146164
@pytest.mark.parametrize("prompts",
147165
[["Hello, my name is"]*2, [[0,0,0,0,0]]*2])
148-
asyncdeftest_batch_completions_beam_search(async_client:openai.AsyncOpenAI,
149-
model_name,prompts,backend):
166+
asyncdeftest_batch_completions_beam_search(
167+
async_client_with_beam_search:openai.AsyncOpenAI,model_name,prompts):
150168
# test beam search
151-
ifbackend=='pytorch':
152-
pytest.skip("Beam search is not supported in PyTorch backend yet")
153-
batch=awaitasync_client.completions.create(
169+
batch=awaitasync_client_with_beam_search.completions.create(
154170
model=model_name,
155171
prompt=prompts,
156172
n=2,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp