Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitb9f00ae

Browse files
crazydemoWong4j
authored andcommitted
[None][test] Add accuracy benchmark in stress test (NVIDIA#7561)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
1 parenta57dcf8 commitb9f00ae

File tree

2 files changed

+201
-8
lines changed

2 files changed

+201
-8
lines changed

‎tests/integration/defs/stress_test/stress_test.py‎

Lines changed: 198 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@
1515
"""
1616
Stress test script for inference of model using TensorRT-LLM with PyTorch/TRT backend.
1717
This script is used for stress testing inference performance using trtllm-serve and genai-perf.
18+
19+
The script supports three test modes:
20+
1. "stress-test": Runs performance test followed by stress test
21+
2. "stress-stage-alone": Runs only stress test with customized parameters
22+
3. "stress-test-with-accuracy": Runs performance test, stress test, and accuracy tests (GSM8K)
23+
24+
Accuracy testing is performed using lm_eval with GSM8K dataset:
25+
- Baseline accuracy test: Run before stress test to establish baseline
26+
- Post-stress accuracy test: Run after stress test to verify accuracy stability
27+
28+
Usage example for accuracy testing:
29+
pytest tests/integration/defs/stress_test/stress_test.py::test_run_stress_test[stress-test-with-accuracy]
1830
"""
1931
importcontextlib
2032
importjson
@@ -126,6 +138,14 @@ class StressTestConfig:
126138
customized_stress_concurrency:int=128
127139
customized_stress_request_rate:int=20
128140

141+
# Accuracy test parameters
142+
enable_accuracy_test:bool=False# Enable accuracy testing with GSM8K
143+
accuracy_test_timeout:int=1200# 20 minutes timeout for accuracy tests
144+
accuracy_test_concurrency:int=512# Concurrency for accuracy tests
145+
accuracy_test_max_retries:int=3# Max retries for accuracy tests
146+
accuracy_test_max_gen_toks:int=256# Max generation tokens for accuracy tests
147+
accuracy_test_max_length:int=4096# Max input length for accuracy tests
148+
129149
@property
130150
defrequest_count_stress_test(self)->int:
131151
"""Calculate request count for stress test"""
@@ -320,8 +340,10 @@ def check_server_health(server_url: str,
320340
returnFalse,f"Unexpected error during health check:{str(e)}"
321341

322342

323-
@pytest.mark.parametrize("test_mode", ["stress-test","stress-stage-alone"],
324-
ids=lambdax:x)
343+
@pytest.mark.parametrize(
344+
"test_mode",
345+
["stress-test","stress-stage-alone","stress-test-with-accuracy"],
346+
ids=lambdax:x)
325347
@pytest.mark.parametrize("backend", ["trt","pytorch"],ids=lambdax:x)
326348
@pytest.mark.parametrize("capacity_scheduler_policy",
327349
["GUARANTEED_NO_EVICT","MAX_UTILIZATION"],
@@ -416,9 +438,14 @@ def stress_test(config,
416438
eliftest_mode=="stress-stage-alone":
417439
run_performance=False
418440
run_stress=True
441+
eliftest_mode=="stress-test-with-accuracy":
442+
run_performance=True
443+
run_stress=True
419444
else:
420-
pytest.skip(f"Skipping test for unsupported mode:{test_mode}. "
421-
f"Supported modes: stress-test, stress-stage-alone")
445+
pytest.skip(
446+
f"Skipping test for unsupported mode:{test_mode}. "
447+
f"Supported modes: stress-test, stress-stage-alone, stress-test-with-accuracy"
448+
)
422449
return
423450

424451
# Skip if not enough GPU memory
@@ -458,9 +485,9 @@ def stress_test(config,
458485
pp_size=test_server_config.pp_size,
459486
ep_size=8,# DeepSeek-V3 or DeepSeek-R1 specific ep_size
460487
max_batch_size=
461-
161,# DeepSeek-V3 or DeepSeek-R1 specific max_batch_size
488+
2048,# DeepSeek-V3 or DeepSeek-R1 specific max_batch_size
462489
max_num_tokens=
463-
1160,# DeepSeek-V3 or DeepSeek-R1 specific max_num_tokens
490+
2048,# DeepSeek-V3 or DeepSeek-R1 specific max_num_tokens
464491
kv_cache_free_gpu_memory_fraction=
465492
0.7,# DeepSeek-V3 or DeepSeek-R1 specific kv_cache fraction
466493
capacity_scheduler_policy=test_server_config.
@@ -472,8 +499,12 @@ def stress_test(config,
472499

473500
# Create a StressTestConfig with customized time parameters if provided
474501
ifrun_stress:
502+
# Enable accuracy test for stress-test-with-accuracy mode
503+
enable_accuracy= (test_mode=="stress-test-with-accuracy")
504+
475505
stress_config=StressTestConfig(model_config=config,
476-
server_config=test_server_config)
506+
server_config=test_server_config,
507+
enable_accuracy_test=enable_accuracy)
477508

478509
# Override stress_time and stress_timeout if provided
479510
ifstress_timeisnotNone:
@@ -482,7 +513,8 @@ def stress_test(config,
482513
server_config=test_server_config,
483514
stress_time=stress_time,
484515
stress_timeout=stress_timeout
485-
ifstress_timeoutisnotNoneelsestress_time*2)
516+
ifstress_timeoutisnotNoneelsestress_time*2,
517+
enable_accuracy_test=enable_accuracy)
486518
else:
487519
stress_config=None
488520

@@ -632,6 +664,12 @@ def stress_test(config,
632664
print_info(
633665
f"Server is running with model{model_name}. Starting tests...")
634666

667+
# Run baseline accuracy test first if enabled
668+
baseline_accuracy_success=True
669+
ifstress_configandstress_config.enable_accuracy_test:
670+
baseline_accuracy_success,baseline_accuracy_value=run_accuracy_test(
671+
model_path,test_server_config,stress_config,"baseline")
672+
635673
# Run performance test first if enabled
636674
stage2_output=None# Initialize stage2_output to None
637675
ifrun_performance:
@@ -664,6 +702,52 @@ def stress_test(config,
664702
stress_config,
665703
None,
666704
request_counter=request_counter)
705+
706+
# Run post-stress accuracy test if enabled
707+
post_stress_accuracy_success=True
708+
ifstress_configandstress_config.enable_accuracy_test:
709+
post_stress_accuracy_success,post_stress_accuracy_value=run_accuracy_test(
710+
model_path,test_server_config,stress_config,
711+
"post_stress")
712+
713+
# Report accuracy test results
714+
ifbaseline_accuracy_successandpost_stress_accuracy_success:
715+
print_info("=== ACCURACY TEST SUMMARY ===")
716+
print_info("✓ Baseline accuracy test: PASSED")
717+
print_info("✓ Post-stress accuracy test: PASSED")
718+
719+
# Compare accuracy values if both are available
720+
ifbaseline_accuracy_valueisnotNoneandpost_stress_accuracy_valueisnotNone:
721+
accuracy_drop=baseline_accuracy_value-post_stress_accuracy_value
722+
accuracy_drop_percentage= (
723+
accuracy_drop/baseline_accuracy_value)*100
724+
725+
print_info(
726+
f"Baseline accuracy:{baseline_accuracy_value:.4f}")
727+
print_info(
728+
f"Post-stress accuracy:{post_stress_accuracy_value:.4f}"
729+
)
730+
print_info(
731+
f"Accuracy drop:{accuracy_drop:.4f} ({accuracy_drop_percentage:.2f}%)"
732+
)
733+
734+
# Define threshold for significant accuracy drop (e.g., 5%)
735+
accuracy_drop_threshold=0.05# 5%
736+
# Assert that accuracy drop is within acceptable threshold
737+
assertaccuracy_drop_percentage<= (
738+
accuracy_drop_threshold*100
739+
),f"Accuracy drop{accuracy_drop_percentage:.2f}% exceeds threshold{accuracy_drop_threshold*100}%"
740+
print_info(
741+
"✓ Model accuracy appears stable under stress conditions"
742+
)
743+
else:
744+
print_warning("=== ACCURACY TEST SUMMARY ===")
745+
ifnotbaseline_accuracy_success:
746+
print_warning("✗ Baseline accuracy test: FAILED")
747+
ifnotpost_stress_accuracy_success:
748+
print_warning("✗ Post-stress accuracy test: FAILED")
749+
print_warning(
750+
"Model accuracy may be affected by stress conditions")
667751
finally:
668752
# Clean up temp yaml file
669753
ifos.path.exists(extra_llm_options_path):
@@ -984,6 +1068,112 @@ def format_time(seconds: int) -> str:
9841068
returnf"{seconds}s"
9851069

9861070

1071+
defparse_accuracy_from_lm_eval_output(output_text:str)->float:
1072+
"""
1073+
Parse accuracy value from lm_eval output for GSM8K flexible-extract exact_match
1074+
1075+
Args:
1076+
output_text: The output text from lm_eval command
1077+
1078+
Returns:
1079+
float: The accuracy value (0.7582 in the example)
1080+
"""
1081+
importre
1082+
1083+
# Look for the specific pattern: |gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7559|± |0.0118|
1084+
patterns= [
1085+
r'flexible-extract\|\s+\d+\|exact_match\|\↑\s+\|(\d+\.\d+)',
1086+
]
1087+
1088+
forpatterninpatterns:
1089+
match=re.search(pattern,output_text)
1090+
ifmatch:
1091+
accuracy_value=float(match.group(1))
1092+
print_info(f"Extracted accuracy value:{accuracy_value}")
1093+
returnaccuracy_value
1094+
1095+
print_warning("Could not find accuracy value in lm_eval output")
1096+
print_warning(f"Output text:{output_text}")
1097+
returnNone
1098+
1099+
1100+
defrun_accuracy_test(model_path:str,
1101+
server_config:ServerConfig,
1102+
stress_config:StressTestConfig,
1103+
test_phase:str="baseline")->tuple[bool,float]:
1104+
"""
1105+
Run accuracy test using lm_eval with GSM8K dataset
1106+
1107+
Args:
1108+
model_path: Path of the model being tested
1109+
server_config: Server configuration containing URL and port
1110+
stress_config: Stress test configuration containing accuracy test parameters
1111+
test_phase: Phase of the test ("baseline" or "post_stress")
1112+
1113+
Returns:
1114+
tuple: (Boolean indicating whether the accuracy test completed successfully, accuracy value)
1115+
"""
1116+
ifnotstress_config.enable_accuracy_test:
1117+
print_info(f"Skipping accuracy test for{test_phase} phase (disabled)")
1118+
returnTrue,None
1119+
1120+
print_info(f"=== Running{test_phase.upper()} ACCURACY TEST (GSM8K) ===")
1121+
1122+
# Create lm_eval command
1123+
lm_eval_cmd= [
1124+
"lm_eval","--model","local-completions","--tasks","gsm8k",
1125+
"--model_args",
1126+
f"model={model_path},base_url={server_config.url}/v1/completions,"
1127+
f"num_concurrent={stress_config.accuracy_test_concurrency},"
1128+
f"max_retries={stress_config.accuracy_test_max_retries},"
1129+
f"tokenized_requests=False,"
1130+
f"timeout={stress_config.accuracy_test_timeout},"
1131+
f"max_gen_toks={stress_config.accuracy_test_max_gen_toks},"
1132+
f"max_length={stress_config.accuracy_test_max_length}",
1133+
"--trust_remote_code"
1134+
]
1135+
1136+
test_start_time=time.time()
1137+
accuracy_value=None
1138+
1139+
try:
1140+
# Run lm_eval process with timeout monitoring
1141+
print_info(f"Running lm_eval command:{' '.join(lm_eval_cmd)}")
1142+
1143+
# Use subprocess.run to capture output directly
1144+
result=subprocess.run(lm_eval_cmd,
1145+
capture_output=True,
1146+
text=True,
1147+
timeout=stress_config.accuracy_test_timeout)
1148+
1149+
# Check if process completed successfully
1150+
ifresult.returncode==0:
1151+
test_end_time=time.time()
1152+
duration=int(test_end_time-test_start_time)
1153+
print_info(
1154+
f"{test_phase.capitalize()} accuracy test completed successfully in{format_time(duration)}"
1155+
)
1156+
1157+
# Parse accuracy value from output
1158+
output_text=result.stdout
1159+
accuracy_value=parse_accuracy_from_lm_eval_output(output_text)
1160+
returnTrue,accuracy_value
1161+
else:
1162+
print_warning(
1163+
f"lm_eval exited with non-zero code:{result.returncode}")
1164+
print_warning(f"stderr:{result.stderr}")
1165+
returnFalse,None
1166+
1167+
exceptsubprocess.TimeoutExpired:
1168+
print_warning(
1169+
f"Accuracy test timed out after{stress_config.accuracy_test_timeout} seconds"
1170+
)
1171+
returnFalse,None
1172+
exceptExceptionase:
1173+
print_warning(f"Error during{test_phase} accuracy test:{str(e)}")
1174+
returnFalse,None
1175+
1176+
9871177
defextract_stress_test_metrics(artifacts_dir="./artifacts",
9881178
current_model=None):
9891179
"""
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-GUARANTEED_NO_EVICT-pytorch-stress-test-with-accuracy]
2+
stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]
3+
stress_test/stress_test.py::test_run_stress_test[DeepSeek-R1_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp