Official codebase for "Can 1B LLM Surpass 405B LLM? Rethinking Compute-Optimal Test-Time Scaling".
- [2025-02-14] ✨ Code is now available.
- [2025-02-12] 📢 Our work is reported by both QbitAI (量子位) and AI Era (新智元).
- [2025-02-12] 🏅 Our paper ranked #1 on HuggingFace Daily Papers.
- [2025-02-11] 📄 Our paper is released on arXiv.
Clone the repository:
```bash
git clone https://github.com/RyanLiu112/compute-optimal-tts.git
cd compute-optimal-tts/src
```
Create a new conda environment and install the dependencies:
```bash
conda create -n tts python=3.10
conda activate tts
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
pip install "ray[default]==2.38.0"
pip install "fschat[model_worker,webui]"
pip install sympy==1.12
cd envs/MATH/latex2sympy
pip install -e .
```
Install tmux for serving policy models and PRMs:
```bash
sudo apt-get update
sudo apt-get install tmux
```
Note
Our mathematical expression evaluation code is based on Qwen2.5-Math. For a more powerful evaluator, please refer to this repository: Math-Verify.
Llama series (Instruct):
Qwen series (Instruct):
DeepSeek-R1-Distill series:
- Math-Shepherd: Math-Shepherd-PRM-7B
- RLHFlow: RLHFlow-PRM-Mistral-8B, RLHFlow-PRM-Deepseek-8B
- Skywork: Skywork-PRM-1.5B, Skywork-PRM-7B
- Qwen2.5-Math: Qwen2.5-Math-PRM-7B, Qwen2.5-Math-PRM-72B
| Policy Model | PRM | GPUs |
| --- | --- | --- |
| 0.5B-14B | 1.5B-8B | 1x A100 80GB |
| 32B | 1.5B-8B | 2x A100 80GB |
| 72B | 1.5B-8B | 3x A100 80GB |
| 0.5B-32B | 72B | 3x A100 80GB |
| 72B | 72B | 4x A100 80GB |
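The table above can be encoded in a small helper to sanity-check a launch configuration before reserving GPUs. This is a hypothetical sketch of ours, not part of the repository; model sizes are given in billions of parameters:

```python
def gpus_needed(policy_b: float, prm_b: float) -> int:
    """Return the number of A100 80GB GPUs suggested by the serving table.

    policy_b: policy model size in billions (0.5 to 72)
    prm_b:    PRM size in billions (1.5 to 72)
    """
    if prm_b >= 72:
        # 72B PRM: 3 GPUs, or 4 when the policy model is also 72B
        return 4 if policy_b >= 72 else 3
    if policy_b >= 72:
        return 3
    if policy_b >= 32:
        return 2
    return 1
```

For example, `gpus_needed(32, 7)` returns 2, matching the second row of the table.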
Set the environment variables:
```bash
cd src
export VALUE_MODEL_PATH=path/to/RM   # dummy for CoT
export POLICY_MODEL_PATH=path/to/LM
export LOGDIR=path/to/logdir
export HOST_ADDR=0.0.0.0
export CONTROLLER_PORT=10014
export WORKER_BASE_PORT=10081
```
Run the corresponding script:
```bash
# 1 GPU
bash scripts/serve_gpu1.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT

# 2 GPUs (32B policy model + 1.5B-8B PRM)
bash scripts/serve_gpu2.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT

# 3 GPUs (72B policy model + 1.5B-8B PRM)
bash scripts/serve_gpu3_1-2.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT

# 3 GPUs (0.5B-32B policy model + 72B PRM)
bash scripts/serve_gpu3_2-1.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT

# 4 GPUs (72B policy model + 72B PRM)
bash scripts/serve_gpu4.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT
```
We provide the following commands for different TTS methods.
```bash
cd src
bash scripts/run.sh --method cot --LM $POLICY_MODEL_PATH --RM dummy --width 1 --num_seq 1
```
Note
Configuring the batch size for BoN and DVTS: for instance, when running BoN on MATH-500, each of the 500 problems is executed 256 times (determined by `num_q`). To improve compute efficiency, it is recommended to distribute the problems across multiple GPUs by adjusting the batch size (`bs`): for example, set `bs` to 500 for 256 GPUs or 16000 for 8 GPUs.
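The arithmetic behind those numbers can be sketched as follows: the total workload is `num_problems * num_q` executions, split evenly across GPUs. This helper is an illustration of ours, not a function in the repository:

```python
def batch_size(num_problems: int, num_q: int, num_gpus: int) -> int:
    """Per-GPU batch size when splitting num_problems * num_q executions
    evenly across num_gpus GPUs."""
    total = num_problems * num_q  # e.g. 500 * 256 = 128000 for BoN on MATH-500
    if total % num_gpus != 0:
        raise ValueError("num_gpus should divide the total workload evenly")
    return total // num_gpus
```

For MATH-500 with `num_q=256`, this gives `batch_size(500, 256, 256) == 500` and `batch_size(500, 256, 8) == 16000`, matching the examples above.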
```bash
cd src
bash scripts/run.sh --method best_of_n --LM $POLICY_MODEL_PATH --RM $VALUE_MODEL_PATH --width 1 --num_seq 1 --num_q 256 --bs batch_size
```
```bash
cd src
bash scripts/run.sh --method beam_search --LM $POLICY_MODEL_PATH --RM $VALUE_MODEL_PATH --width 4 --num_seq 1
```
```bash
cd src
bash scripts/run.sh --method beam_search --LM $POLICY_MODEL_PATH --RM $VALUE_MODEL_PATH --width 4 --num_seq 1 --num_q 64 --bs batch_size
```
For BoN and DVTS, no averaged result is computed by default. To compute the average, aggregate the `majority_vote` values from all JSONL files after processing all problems `num_q` times.
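The aggregation step can be sketched as below. This is our own minimal example, assuming each output JSONL line is a JSON object with a numeric `majority_vote` field; the glob pattern and field handling may need adjusting to the actual log layout:

```python
import glob
import json


def average_majority_vote(pattern: str) -> float:
    """Average the majority_vote field across all JSONL result files
    matching the given glob pattern (e.g. "path/to/logdir/*.jsonl")."""
    values = []
    for path in glob.glob(pattern):
        with open(path) as f:
            for line in f:
                record = json.loads(line)
                if "majority_vote" in record:
                    values.append(float(record["majority_vote"]))
    if not values:
        raise ValueError(f"no majority_vote entries found for {pattern!r}")
    return sum(values) / len(values)
```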
If you find this work helpful, please kindly cite our paper:
```bibtex
@article{liu2025can,
    title   = {Can 1B LLM Surpass 405B LLM? Rethinking Compute-Optimal Test-Time Scaling},
    author  = {Runze Liu and Junqi Gao and Jian Zhao and Kaiyan Zhang and Xiu Li and Biqing Qi and Wanli Ouyang and Bowen Zhou},
    journal = {arXiv preprint arXiv:2502.06703},
    year    = {2025}
}
```
Our code is largely based on OpenR, an awesome LLM reasoning repository, and their work has been instrumental in our study. Our mathematical expression evaluation code is based on Qwen2.5-Math. We also want to thank the community for providing high-quality open-source PRMs, including Qwen2.5-Math, Skywork-o1, RLHFlow, and Math-Shepherd.