@@ -68,6 +68,8 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
6868temp_extra_llm_api_options_file :str ,num_postprocess_workers :int ):
6969model_path = get_model_path (model_name )
7070args = ["--backend" ,f"{ backend } " ]
71+ args .extend (["--kv_cache_free_gpu_memory_fraction" ,
72+ "0.2" ])# for co-existence with other servers
7173if backend == "trt" :
7274args .extend (["--max_beam_width" ,"4" ])
7375if extra_llm_api_options :
@@ -78,11 +80,34 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
7880yield remote_server
7981
8082
@pytest.fixture(scope="module")
def server_with_beam_search(model_name: str, backend: str,
                            extra_llm_api_options: bool,
                            temp_extra_llm_api_options_file: str,
                            num_postprocess_workers: int):
    """Launch a RemoteOpenAIServer configured for beam search (width 2).

    Mirrors the plain ``server`` fixture, but always passes
    ``--max_beam_width 2`` so beam-search requests are accepted.
    Yields the running server for the whole test module.
    """
    model_path = get_model_path(model_name)
    # Keep the KV-cache footprint small so this server can coexist with
    # other servers started during the same test session.
    launch_args = [
        "--backend", f"{backend}",
        "--kv_cache_free_gpu_memory_fraction", "0.2",
        "--max_beam_width", "2",
    ]
    if extra_llm_api_options:
        launch_args += [
            "--extra_llm_api_options", temp_extra_llm_api_options_file
        ]
    launch_args += ["--num_postprocess_workers", f"{num_postprocess_workers}"]
    with RemoteOpenAIServer(model_path, launch_args) as remote_server:
        yield remote_server
100+
@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
    """Synchronous OpenAI client bound to the default test server."""
    sync_client = server.get_client()
    return sync_client
85105
@pytest.fixture(scope="module")
def client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
    """Synchronous OpenAI client bound to the beam-search-enabled server."""
    sync_client = server_with_beam_search.get_client()
    return sync_client
110+
@pytest.fixture(scope="module")
def async_client(server: RemoteOpenAIServer):
    """Asynchronous OpenAI client bound to the default test server."""
    a_client = server.get_async_client()
    return a_client
def test_multiple_responses(client: openai.OpenAI, model_name: str,
                            backend: str):
    """Ask for n=2 greedy completions (best_of=4) and expect two choices.

    Skipped on the PyTorch backend, where ``n`` with temperature=0 is
    rejected unless an env override is set.
    """
    if backend == "pytorch":
        pytest.skip(
            "'n' not allowed with temperature=0 unless TLLM_ALLOW_N_GREEDY_DECODING=1"
        )
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]
    # test n and best_of
    response = client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_completion_tokens=10,
        n=2,
        temperature=0.0,
        extra_body=dict(best_of=4),
    )
    assert len(response.choices) == 2
228+
229+ def test_multiple_responses_and_beam_search (client :openai .OpenAI ,
230+ model_name :str ,backend :str ):
231+ if backend == "pytorch" :
232+ pytest .skip (
233+ "Mixing beam search and regular requests is not supported in PyTorch backend"
234+ )
184235
185236messages = [{
186237"role" :"system" ,
@@ -202,6 +253,7 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
202253assert chat_completion .choices [
2032540 ].message .content != chat_completion .choices [
2042551 ].message .content ,"beam search should be different"
256+
205257# test n and best_of
206258chat_completion = client .chat .completions .create (
207259model = model_name ,
@@ -214,6 +266,30 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
214266assert len (chat_completion .choices )== 2
215267
216268
def test_multiple_responses_with_beam_search(
        client_with_beam_search: openai.OpenAI, model_name: str):
    """Beam search (width 2) must return two distinct completions."""
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]
    # test beam search
    response = client_with_beam_search.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_completion_tokens=10,
        n=2,
        temperature=0.0,
        extra_body=dict(use_beam_search=True),
    )
    assert len(response.choices) == 2
    first, second = response.choices[0], response.choices[1]
    assert first.message.content != second.message.content, \
        "beam search should be different"
292+
217293@pytest .mark .asyncio (loop_scope = "module" )
218294async def test_chat_streaming (async_client :openai .AsyncOpenAI ,
219295model_name :str ):