1515"""
1616Stress test script for inference of model using TensorRT-LLM with PyTorch/TRT backend.
1717This script is used for stress testing inference performance using trtllm-serve and genai-perf.
18+
19+ The script supports three test modes:
20+ 1. "stress-test": Runs performance test followed by stress test
21+ 2. "stress-stage-alone": Runs only stress test with customized parameters
22+ 3. "stress-test-with-accuracy": Runs performance test, stress test, and accuracy tests (GSM8K)
23+
24+ Accuracy testing is performed using lm_eval with GSM8K dataset:
25+ - Baseline accuracy test: Run before stress test to establish baseline
26+ - Post-stress accuracy test: Run after stress test to verify accuracy stability
27+
28+ Usage example for accuracy testing:
29+ pytest tests/integration/defs/stress_test/stress_test.py::test_run_stress_test[stress-test-with-accuracy]
1830"""
1931import contextlib
2032import json
@@ -126,6 +138,14 @@ class StressTestConfig:
126138customized_stress_concurrency :int = 128
127139customized_stress_request_rate :int = 20
128140
141+ # Accuracy test parameters
142+ enable_accuracy_test :bool = False # Enable accuracy testing with GSM8K
143+ accuracy_test_timeout :int = 1200 # 20 minutes timeout for accuracy tests
144+ accuracy_test_concurrency :int = 512 # Concurrency for accuracy tests
145+ accuracy_test_max_retries :int = 3 # Max retries for accuracy tests
146+ accuracy_test_max_gen_toks :int = 256 # Max generation tokens for accuracy tests
147+ accuracy_test_max_length :int = 4096 # Max input length for accuracy tests
148+
129149@property
130150def request_count_stress_test (self )-> int :
131151"""Calculate request count for stress test"""
@@ -320,8 +340,10 @@ def check_server_health(server_url: str,
320340return False ,f"Unexpected error during health check:{ str (e )} "
321341
322342
323- @pytest .mark .parametrize ("test_mode" , ["stress-test" ,"stress-stage-alone" ],
324- ids = lambda x :x )
343+ @pytest .mark .parametrize (
344+ "test_mode" ,
345+ ["stress-test" ,"stress-stage-alone" ,"stress-test-with-accuracy" ],
346+ ids = lambda x :x )
325347@pytest .mark .parametrize ("backend" , ["trt" ,"pytorch" ],ids = lambda x :x )
326348@pytest .mark .parametrize ("capacity_scheduler_policy" ,
327349 ["GUARANTEED_NO_EVICT" ,"MAX_UTILIZATION" ],
@@ -416,9 +438,14 @@ def stress_test(config,
416438elif test_mode == "stress-stage-alone" :
417439run_performance = False
418440run_stress = True
441+ elif test_mode == "stress-test-with-accuracy" :
442+ run_performance = True
443+ run_stress = True
419444else :
420- pytest .skip (f"Skipping test for unsupported mode:{ test_mode } . "
421- f"Supported modes: stress-test, stress-stage-alone" )
445+ pytest .skip (
446+ f"Skipping test for unsupported mode:{ test_mode } . "
447+ f"Supported modes: stress-test, stress-stage-alone, stress-test-with-accuracy"
448+ )
422449return
423450
424451# Skip if not enough GPU memory
@@ -458,9 +485,9 @@ def stress_test(config,
458485pp_size = test_server_config .pp_size ,
459486ep_size = 8 ,# DeepSeek-V3 or DeepSeek-R1 specific ep_size
460487max_batch_size =
461- 161 ,# DeepSeek-V3 or DeepSeek-R1 specific max_batch_size
488+ 2048 ,# DeepSeek-V3 or DeepSeek-R1 specific max_batch_size
462489max_num_tokens =
463- 1160 ,# DeepSeek-V3 or DeepSeek-R1 specific max_num_tokens
490+ 2048 ,# DeepSeek-V3 or DeepSeek-R1 specific max_num_tokens
464491kv_cache_free_gpu_memory_fraction =
4654920.7 ,# DeepSeek-V3 or DeepSeek-R1 specific kv_cache fraction
466493capacity_scheduler_policy = test_server_config .
@@ -472,8 +499,12 @@ def stress_test(config,
472499
473500# Create a StressTestConfig with customized time parameters if provided
474501if run_stress :
502+ # Enable accuracy test for stress-test-with-accuracy mode
503+ enable_accuracy = (test_mode == "stress-test-with-accuracy" )
504+
475505stress_config = StressTestConfig (model_config = config ,
476- server_config = test_server_config )
506+ server_config = test_server_config ,
507+ enable_accuracy_test = enable_accuracy )
477508
478509# Override stress_time and stress_timeout if provided
479510if stress_time is not None :
@@ -482,7 +513,8 @@ def stress_test(config,
482513server_config = test_server_config ,
483514stress_time = stress_time ,
484515stress_timeout = stress_timeout
485- if stress_timeout is not None else stress_time * 2 )
516+ if stress_timeout is not None else stress_time * 2 ,
517+ enable_accuracy_test = enable_accuracy )
486518else :
487519stress_config = None
488520
@@ -632,6 +664,12 @@ def stress_test(config,
632664print_info (
633665f"Server is running with model{ model_name } . Starting tests..." )
634666
667+ # Run baseline accuracy test first if enabled
668+ baseline_accuracy_success = True
669+ if stress_config and stress_config .enable_accuracy_test :
670+ baseline_accuracy_success ,baseline_accuracy_value = run_accuracy_test (
671+ model_path ,test_server_config ,stress_config ,"baseline" )
672+
635673# Run performance test first if enabled
636674stage2_output = None # Initialize stage2_output to None
637675if run_performance :
@@ -664,6 +702,52 @@ def stress_test(config,
664702stress_config ,
665703None ,
666704request_counter = request_counter )
705+
706+ # Run post-stress accuracy test if enabled
707+ post_stress_accuracy_success = True
708+ if stress_config and stress_config .enable_accuracy_test :
709+ post_stress_accuracy_success ,post_stress_accuracy_value = run_accuracy_test (
710+ model_path ,test_server_config ,stress_config ,
711+ "post_stress" )
712+
713+ # Report accuracy test results
714+ if baseline_accuracy_success and post_stress_accuracy_success :
715+ print_info ("=== ACCURACY TEST SUMMARY ===" )
716+ print_info ("✓ Baseline accuracy test: PASSED" )
717+ print_info ("✓ Post-stress accuracy test: PASSED" )
718+
719+ # Compare accuracy values if both are available
720+ if baseline_accuracy_value is not None and post_stress_accuracy_value is not None :
721+ accuracy_drop = baseline_accuracy_value - post_stress_accuracy_value
722+ accuracy_drop_percentage = (
723+ accuracy_drop / baseline_accuracy_value )* 100
724+
725+ print_info (
726+ f"Baseline accuracy:{ baseline_accuracy_value :.4f} " )
727+ print_info (
728+ f"Post-stress accuracy:{ post_stress_accuracy_value :.4f} "
729+ )
730+ print_info (
731+ f"Accuracy drop:{ accuracy_drop :.4f} ({ accuracy_drop_percentage :.2f} %)"
732+ )
733+
734+ # Define threshold for significant accuracy drop (e.g., 5%)
735+ accuracy_drop_threshold = 0.05 # 5%
736+ # Assert that accuracy drop is within acceptable threshold
737+ assert accuracy_drop_percentage <= (
738+ accuracy_drop_threshold * 100
739+ ),f"Accuracy drop{ accuracy_drop_percentage :.2f} % exceeds threshold{ accuracy_drop_threshold * 100 } %"
740+ print_info (
741+ "✓ Model accuracy appears stable under stress conditions"
742+ )
743+ else :
744+ print_warning ("=== ACCURACY TEST SUMMARY ===" )
745+ if not baseline_accuracy_success :
746+ print_warning ("✗ Baseline accuracy test: FAILED" )
747+ if not post_stress_accuracy_success :
748+ print_warning ("✗ Post-stress accuracy test: FAILED" )
749+ print_warning (
750+ "Model accuracy may be affected by stress conditions" )
667751finally :
668752# Clean up temp yaml file
669753if os .path .exists (extra_llm_options_path ):
@@ -984,6 +1068,112 @@ def format_time(seconds: int) -> str:
9841068return f"{ seconds } s"
9851069
9861070
1071+ def parse_accuracy_from_lm_eval_output (output_text :str )-> float :
1072+ """
1073+ Parse accuracy value from lm_eval output for GSM8K flexible-extract exact_match
1074+
1075+ Args:
1076+ output_text: The output text from lm_eval command
1077+
1078+ Returns:
1079+ float: The accuracy value (0.7582 in the example)
1080+ """
1081+ import re
1082+
1083+ # Look for the specific pattern: |gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7559|± |0.0118|
1084+ patterns = [
1085+ r'flexible-extract\|\s+\d+\|exact_match\|\↑\s+\|(\d+\.\d+)' ,
1086+ ]
1087+
1088+ for pattern in patterns :
1089+ match = re .search (pattern ,output_text )
1090+ if match :
1091+ accuracy_value = float (match .group (1 ))
1092+ print_info (f"Extracted accuracy value:{ accuracy_value } " )
1093+ return accuracy_value
1094+
1095+ print_warning ("Could not find accuracy value in lm_eval output" )
1096+ print_warning (f"Output text:{ output_text } " )
1097+ return None
1098+
1099+
1100+ def run_accuracy_test (model_path :str ,
1101+ server_config :ServerConfig ,
1102+ stress_config :StressTestConfig ,
1103+ test_phase :str = "baseline" )-> tuple [bool ,float ]:
1104+ """
1105+ Run accuracy test using lm_eval with GSM8K dataset
1106+
1107+ Args:
1108+ model_path: Path of the model being tested
1109+ server_config: Server configuration containing URL and port
1110+ stress_config: Stress test configuration containing accuracy test parameters
1111+ test_phase: Phase of the test ("baseline" or "post_stress")
1112+
1113+ Returns:
1114+ tuple: (Boolean indicating whether the accuracy test completed successfully, accuracy value)
1115+ """
1116+ if not stress_config .enable_accuracy_test :
1117+ print_info (f"Skipping accuracy test for{ test_phase } phase (disabled)" )
1118+ return True ,None
1119+
1120+ print_info (f"=== Running{ test_phase .upper ()} ACCURACY TEST (GSM8K) ===" )
1121+
1122+ # Create lm_eval command
1123+ lm_eval_cmd = [
1124+ "lm_eval" ,"--model" ,"local-completions" ,"--tasks" ,"gsm8k" ,
1125+ "--model_args" ,
1126+ f"model={ model_path } ,base_url={ server_config .url } /v1/completions,"
1127+ f"num_concurrent={ stress_config .accuracy_test_concurrency } ,"
1128+ f"max_retries={ stress_config .accuracy_test_max_retries } ,"
1129+ f"tokenized_requests=False,"
1130+ f"timeout={ stress_config .accuracy_test_timeout } ,"
1131+ f"max_gen_toks={ stress_config .accuracy_test_max_gen_toks } ,"
1132+ f"max_length={ stress_config .accuracy_test_max_length } " ,
1133+ "--trust_remote_code"
1134+ ]
1135+
1136+ test_start_time = time .time ()
1137+ accuracy_value = None
1138+
1139+ try :
1140+ # Run lm_eval process with timeout monitoring
1141+ print_info (f"Running lm_eval command:{ ' ' .join (lm_eval_cmd )} " )
1142+
1143+ # Use subprocess.run to capture output directly
1144+ result = subprocess .run (lm_eval_cmd ,
1145+ capture_output = True ,
1146+ text = True ,
1147+ timeout = stress_config .accuracy_test_timeout )
1148+
1149+ # Check if process completed successfully
1150+ if result .returncode == 0 :
1151+ test_end_time = time .time ()
1152+ duration = int (test_end_time - test_start_time )
1153+ print_info (
1154+ f"{ test_phase .capitalize ()} accuracy test completed successfully in{ format_time (duration )} "
1155+ )
1156+
1157+ # Parse accuracy value from output
1158+ output_text = result .stdout
1159+ accuracy_value = parse_accuracy_from_lm_eval_output (output_text )
1160+ return True ,accuracy_value
1161+ else :
1162+ print_warning (
1163+ f"lm_eval exited with non-zero code:{ result .returncode } " )
1164+ print_warning (f"stderr:{ result .stderr } " )
1165+ return False ,None
1166+
1167+ except subprocess .TimeoutExpired :
1168+ print_warning (
1169+ f"Accuracy test timed out after{ stress_config .accuracy_test_timeout } seconds"
1170+ )
1171+ return False ,None
1172+ except Exception as e :
1173+ print_warning (f"Error during{ test_phase } accuracy test:{ str (e )} " )
1174+ return False ,None
1175+
1176+
9871177def extract_stress_test_metrics (artifacts_dir = "./artifacts" ,
9881178current_model = None ):
9891179"""