Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 80c5491

Browse files
Specdec Bench: vLLM reqid, SGL path, conc > 1 metric fix (#541)
## What does this PR do?

**SGLang** Fix for actually passing the draft model path to the engine

**vLLM** Fix for multiturn to not overlap request_id strings

**Acceptance Rate** Fix for potential race condition on multiturn datasets in writing back AR

**Overview:** ?

## Usage

<!-- You can potentially add a usage example below. -->

```python
# Add a code snippet demonstrating how to use this
```

## Testing

<!-- Mention how have you tested your change if applicable. -->

## Before your PR is "*Ready for review*"

<!-- If you haven't finished some of the above items you can still open a `Draft` PR. -->

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!-- If No, explain why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!-- Only for new features, API changes, critical bug fixes or bw breaking changes. -->

## Additional Information

<!-- E.g. related issue. -->

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
1 parent 592a499 commit 80c5491

File tree

12 files changed

+48
-39
lines changed

12 files changed

+48
-39
lines changed

‎examples/specdec_bench/run.py‎

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@ async def process_single_request(request, i):
4949
ifrequest.system_promptisnotNone:
5050
messages.append({"role":"system","content":request.system_prompt})
5151

52-
forquestioninrequest.turns:
52+
forturn_id,questioninenumerate(request.turns):
5353
messages.append({"role":"user","content":question})
5454
entry_encoded=encode_chat(tokenizer,messages)
5555

5656
# Run the async runner.run directly
57-
output_tokens=awaitrunner.run(entry_encoded,max_length,end_id,i)
57+
output_tokens=awaitrunner.run(
58+
entry_encoded,max_length,end_id,request_id=i,turn_id=turn_id
59+
)
5860
output_text=decode_chat(tokenizer,output_tokens["output_ids"][0])
5961
output_text=postprocess(output_text)
6062
messages.append({"role":"assistant","content":output_text})

‎examples/specdec_bench/specdec_bench/metrics/aa_timing.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, base_tokenizer):
3434
self.base_tokenizer=base_tokenizer
3535
self.total_tokens= []
3636

37-
defprocess_step(self,step_outputs,new_turn=True):
37+
defprocess_step(self,step_outputs,request_id,turn_id):
3838
self.timing.append(step_outputs["token_times"])
3939
target_tokens= [
4040
tfortok_listinstep_outputs["output_ids"]fortokintok_listfortintok

‎examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py‎

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@
2222
classAcceptanceRate(Metric):
2323
def__init__(self):
2424
super().__init__()
25-
self.prompt_ar=[]
25+
self.prompt_ar={}
2626
self.name="acceptance_rate"
2727

28-
defprocess_step(self,step_outputs,new_turn=True):
29-
ifnew_turn:
30-
self.prompt_ar.append([])
28+
defprocess_step(self,step_outputs,request_id,turn_id):
29+
ifrequest_idnotinself.prompt_ar:
30+
self.prompt_ar[request_id]= {}
31+
ifturn_idnotinself.prompt_ar[request_id]:
32+
self.prompt_ar[request_id][turn_id]= []
3133
fori,beam_outputinenumerate(step_outputs["output_ids"]):
3234
foroutput_id_iterinbeam_output:
33-
self.prompt_ar[-1].append(len(output_id_iter))
35+
self.prompt_ar[request_id][turn_id].append(len(output_id_iter))
3436

3537
def_get_lengths(self,turn,lengths):
3638
forjinturn:
@@ -55,16 +57,19 @@ def _process_lengths(self, lengths):
5557
running_len-=v
5658

5759
defprocess_final(self,text_outputs):
58-
i=0
60+
all_ar=[]
5961
lengths= {}
6062
self.out["Request_AR"]= {}
61-
whilei<len(self.prompt_ar):
62-
turn_1=self.prompt_ar[i]
63-
self.out["Request_AR"][i]=sum(turn_1)/len(turn_1)
64-
self._get_lengths(turn_1,lengths)
65-
print(i,self.out["Request_AR"][i])
66-
i+=1
67-
average_ar=sum(self.out["Request_AR"].values())/len(self.out["Request_AR"])
63+
self.prompt_ar=dict(sorted(self.prompt_ar.items(),key=lambdax:x[0]))
64+
forrequest_id,turnsinself.prompt_ar.items():
65+
self.out["Request_AR"][request_id]= {}
66+
forturn_id,turninturns.items():
67+
ar=sum(turn)/len(turn)
68+
self.out["Request_AR"][request_id][turn_id]=ar
69+
all_ar.append(ar)
70+
self._get_lengths(turn,lengths)
71+
print(request_id,turn_id,self.out["Request_AR"][request_id][turn_id])
72+
average_ar=sum(all_ar)/len(all_ar)
6873
print("Average AR:",average_ar)
6974
self.out["Average_AR"]=average_ar
7075
self._process_lengths(lengths)

‎examples/specdec_bench/specdec_bench/metrics/base.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self):
2424
self.out= {}
2525
self.name="metric"
2626

27-
defprocess_step(self,step_outputs,new_turn=True):
27+
defprocess_step(self,step_outputs,request_id,turn_id):
2828
raiseNotImplementedError
2929

3030
defprocess_final(self,text_outputs):

‎examples/specdec_bench/specdec_bench/metrics/mtbench.py‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,16 @@ def process_final(self, text_outputs):
3535
i=0
3636
lengths= {}
3737
self.out["Request_AR"]= {}
38-
whilei<len(self.prompt_ar):
39-
turn_1=self.prompt_ar[i]
40-
turn_2=self.prompt_ar[i+1]
41-
q_id=i//2
38+
self.prompt_ar=dict(sorted(self.prompt_ar.items(),key=lambdax:x[0]))
39+
forrequest_id,turnsinself.prompt_ar.items():
40+
turn_1=turns[0]
41+
turn_2=turns[1]
42+
q_id=request_id
4243
mtbench_topic=MTBENCH_TOPICS[q_id//10]
43-
self.out["Request_AR"][q_id]=sum(turn_1+turn_2)/len(turn_1+turn_2)
44+
self.out["Request_AR"][request_id]=sum(turn_1+turn_2)/len(turn_1+turn_2)
4445
self._get_lengths(turn_1,lengths)
4546
self._get_lengths(turn_2,lengths)
4647
print(mtbench_topic,sum(turn_1+turn_2)/len(turn_1+turn_2))
47-
i+=2
4848
per_category= [[]for_inrange(len(MTBENCH_TOPICS))]
4949
forq_id,arinself.out["Request_AR"].items():
5050
per_category[q_id//10].append(ar)

‎examples/specdec_bench/specdec_bench/metrics/timing.py‎

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, tp_size):
2626
self.total_tokens= []
2727
self.tp_size=tp_size
2828

29-
defprocess_step(self,step_outputs,new_turn=True):
29+
defprocess_step(self,step_outputs,request_id,turn_id):
3030
self.timing.append(step_outputs["token_times"])
3131
self.total_tokens.append(
3232
sum([sum([len(j)forjini])foriinstep_outputs["output_ids"]])
@@ -42,8 +42,9 @@ def process_final(self, text_outputs):
4242
self.out["Output TPS"]=sum(self.total_tokens)/ (end_time-start_time)
4343
self.out["Output TPS/gpu"]=self.out["Output TPS"]/self.tp_size
4444
fortokens,timesinzip(self.total_tokens,self.timing):
45-
e2e_time.append(times[-1]-times[0])
46-
ttft_time.append(times[1]-times[0])
45+
iflen(times)>1:
46+
e2e_time.append(times[-1]-times[0])
47+
ttft_time.append(times[1]-times[0])
4748
iflen(times)>2:
4849
gen_tp_time.append((tokens-1)/ (times[-1]-times[1]))
4950
tpot_time.extend([a-bfora,binzip(times[1:],times[:-1])])

‎examples/specdec_bench/specdec_bench/models/base.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Model:
1818
def__init__(self,model_dir,tokenizer,max_draft_length):
1919
raiseNotImplementedError
2020

21-
asyncdefrun(self,prompt_ids,max_length,end_id,request_id):
21+
asyncdefrun(self,prompt_ids,max_length,end_id,request_id,turn_id):
2222
"""
2323
prompt_ids is list of tokens
2424
output is list of list of tokens

‎examples/specdec_bench/specdec_bench/models/sglang.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
speculative_num_steps=kwargs.get("speculative_num_steps",3),
5151
speculative_eagle_topk=kwargs.get("speculative_eagle_topk",1),
5252
speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens",4),
53+
speculative_draft_model_path=kwargs.get("draft_model_dir"),
5354
torch_compile_max_bs=max_concurrent_requests,
5455
attention_backend=kwargs.get("attention_backend"),
5556
enable_torch_compile=kwargs.get("enable_torch_compile",False),
@@ -70,7 +71,7 @@ def __init__(
7071

7172
self.sampling_config=sampling_kwargs
7273

73-
asyncdefrun(self,prompt_ids,max_length,end_id,request_id):
74+
asyncdefrun(self,prompt_ids,max_length,end_id,request_id,turn_id):
7475
timing= []
7576
output_dict= {}
7677
self.sampling_config["max_new_tokens"]=max_length

‎examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
self.model=create_executor(model_path,max_concurrent_requests,kwargs)
4444
self.sampling_kwargs=sampling_kwargs
4545

46-
asyncdefrun(self,prompt_ids,max_length,end_id,request_id):
46+
asyncdefrun(self,prompt_ids,max_length,end_id,request_id,turn_id):
4747
output_dict= {}
4848
sampling_config=check_sampling_config(self.sampling_kwargs,max_length,end_id)
4949
outputs= []

‎examples/specdec_bench/specdec_bench/models/vllm.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,12 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
8484
self.loop=asyncio.new_event_loop()
8585
asyncio.set_event_loop(self.loop)
8686

87-
asyncdefrun(self,prompt_ids,max_length,end_id,request_id):
87+
asyncdefrun(self,prompt_ids,max_length,end_id,request_id,turn_id):
8888
output_dict= {}
8989
self.sampling_config.max_tokens=max_length
9090
self.sampling_config.stop_token_ids= [end_id]
9191

92-
outputs,timing,full_tokens=awaitself.generate(prompt_ids,request_id)
92+
outputs,timing,full_tokens=awaitself.generate(prompt_ids,request_id,turn_id)
9393

9494
reformatted_output_ids= [[]for_inrange(self.sampling_kwargs.get("beam_width",1))]
9595
start=0
@@ -114,13 +114,13 @@ async def run(self, prompt_ids, max_length, end_id, request_id):
114114
]
115115
returnoutput_dict
116116

117-
asyncdefgenerate(self,prompt_ids,request_id):
117+
asyncdefgenerate(self,prompt_ids,request_id,turn_id):
118118
timing= []
119119
timing.append(time.perf_counter())
120120
outputs= []
121121
full_tokens= []
122122
asyncforoutputinself.model.generate(
123-
request_id=str(request_id),
123+
request_id=f"{request_id}.{turn_id}",
124124
prompt=TokensPrompt(prompt_token_ids=prompt_ids),
125125
sampling_params=self.sampling_config,
126126
):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp