Commit ba0e1e7

Support Soft Adaptive Policy Optimization (#6766)

1 parent: 8a2c221

16 files changed: +376 −13 lines

docs/resources/sapo_tau.png (binary file, 664 KB)

docs/source/Instruction/Command-line-parameters.md

Lines changed: 3 additions & 1 deletion

```diff
@@ -563,7 +563,7 @@
 - dataset_shuffle: Whether to shuffle the dataset. Default is True.
 - truncation_strategy: How to handle inputs exceeding `max_length`; supports `delete` and `left` (delete / left-truncate), default `left`. Note that for multimodal models, left truncation may cut off multimodal tokens and cause a shape-mismatch error in the forward pass. With `delete`, over-long or encoding-failed samples are replaced by resampling other data from the original dataset.
-- loss_type: Loss normalization type. Options are ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo'], default 'grpo'. See the [doc](./GRPO/DeveloperGuide/loss_types.md)
+- loss_type: Loss normalization type. Options are ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo', 'sapo'], default 'grpo'. See the [doc](./GRPO/DeveloperGuide/loss_types.md)
 - log_completions: Whether to log model completions during training, used together with `--report_to wandb/swanlab`. Default is False.
   - Note: if `--report_to wandb/swanlab` is not set, a `completions.jsonl` is created in the checkpoint to store completions.
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation. Default is False.
@@ -592,6 +592,8 @@
 - num_iterations: Number of updates per sample, the $\mu$ value in the [GRPO paper](https://arxiv.org/abs/2402.03300). Default is 1.
 - epsilon: Clip coefficient. Default is 0.2.
 - epsilon_high: Upper clip coefficient, default None; when set, it forms the clipping range [epsilon, epsilon_high] together with epsilon.
+- tau_pos: Temperature for positive advantages in the [SAPO](https://arxiv.org/abs/2511.20347) algorithm, controlling the sharpness of the soft gate. Larger values make the gate sharper (closer to hard clipping); smaller values make it smoother. Default is 1.0.
+- tau_neg: Temperature for negative advantages in SAPO, controlling the sharpness of the soft gate. Usually set `tau_neg > tau_pos` to constrain negative advantages more strongly. Default is 1.05.
 - dynamic_sample: Filter out data whose in-group reward standard deviation is 0 and sample extra data. Default is False.
 - max_resample_times: Cap on resampling attempts under dynamic_sample. Default is 3.
 - overlong_filter: Skip over-long truncated samples, excluding them from the loss. Default is False.
```
docs/source/Instruction/GRPO/AdvancedResearch/SAPO.md

Lines changed: 93 additions & 0 deletions

# Soft Adaptive Policy Optimization

**Version requirement**: ms-swift>=3.11

[Soft Adaptive Policy Optimization (SAPO)](https://arxiv.org/abs/2511.20347) targets the problems caused by hard clipping in GRPO, proposing a temperature-controlled soft gate mechanism that smoothly attenuates off-policy updates while preserving useful learning signal.

## Background and Motivation

When training LLMs with reinforcement learning, GRPO handles off-policy training by computing a token-level importance sampling ratio:

$$
r_t = \frac{\pi_\theta(y_t|x, y_{<t})}{\pi_{\theta_{\mathrm{old}}}(y_t|x, y_{<t})}
$$

However, token-level importance sampling ratios often exhibit high variance, which can be worse in the following cases:

- **Long text generation**
- **Routing heterogeneity in MoE models**: the old-policy model used for sampling and the model being trained may route to different experts, significantly amplifying the difference in logps

GRPO therefore limits the size of policy updates with hard clipping:

$$
L^{\mathrm{GRPO}} = -\min\left( r_t \cdot A, \mathrm{clip}(r_t, 1-\epsilon, 1+\epsilon) \cdot A \right)
$$

**The dilemma of hard clipping**: hard clipping struggles to balance stability and learning efficiency. A clipping range that is too strict limits the number of effective samples, while one that is too loose introduces noisy gradients from off-policy samples and destabilizes training.

## The SAPO Method

SAPO replaces hard clipping with a temperature-controlled sigmoid soft gate, giving smooth gradient attenuation.

### Soft Gate Function

The core of SAPO is a sigmoid soft gate applied to the importance sampling ratio.

For positive advantages ($A > 0$), the positive gate is used:

$$
g^{+}_t = \sigma\left( \tau_{\mathrm{pos}} \cdot (r_t - 1) \right)
$$

For negative advantages ($A < 0$), the negative gate is used:

$$
g^{-}_t = \sigma\left( \tau_{\mathrm{neg}} \cdot (r_t - 1) \right)
$$

where:

- $\sigma(\cdot)$ is the sigmoid function
- $\tau_{\mathrm{pos}}$ and $\tau_{\mathrm{neg}}$ are temperature parameters controlling the slope of the gate
- $r_t$ is the importance sampling ratio

### SAPO Loss Function

$$
L^{\mathrm{SAPO}} = -g_t \cdot A
$$

where $g_t = g^{+}_t$ when $A > 0$ and $g_t = g^{-}_t$ when $A < 0$.

### Temperature Parameters

The temperature $\tau$ controls how quickly the soft gate decays: the larger the value, the faster the decay.

![tau curve](../../../../resources/sapo_tau.png)

The paper observes that a positive advantage raises the logit of the sampled token and lowers the logits of all unsampled tokens, while a negative advantage does the opposite, raising the logits of many unsampled tokens; this can spill over onto a large number of irrelevant tokens and cause instability. The paper therefore recommends $\tau_\text{neg} > \tau_\text{pos}$, so that gradients for negatively rewarded tokens decay faster, improving training stability and performance.

The paper's recommended defaults are $\tau_{\mathrm{pos}} = 1.0$ and $\tau_{\mathrm{neg}} = 1.05$.

## Parameter Settings

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--loss_type` | `str` | - | Set to `sapo` |
| `--tau_pos` | `float` | `1.0` | Temperature for positive advantages, controls gate slope |
| `--tau_neg` | `float` | `1.05` | Temperature for negative advantages, controls gate slope |

```bash
swift rlhf \
    --rlhf_type grpo \
    --loss_type sapo \
    --tau_pos 1.0 \
    --tau_neg 1.05 \
    # ... other parameters
```

Example training scripts:

- [swift](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/sapo.sh)
- [megatron swift](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/grpo/sapo.sh)

> SAPO's soft gate only takes effect in off-policy training.
> The importance sampling granularity in SAPO is token level (i.e., importance_sampling_level defaults to token), which conflicts with GSPO.
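The effect of the temperature on the gate can be checked numerically. The sketch below is plain Python; `soft_gate` is an illustrative helper evaluating the formula above, not an ms-swift API:

```python
import math

def soft_gate(r, tau):
    """Sigmoid soft gate g = sigma(tau * (r - 1)) from the SAPO formulas above."""
    return 1.0 / (1.0 + math.exp(-tau * (r - 1.0)))

# At r = 1 (exactly on-policy) the gate is 0.5 regardless of tau.
print(soft_gate(1.0, 1.0))  # 0.5

# For an off-policy ratio, a larger tau attenuates the gate faster,
# which is why the paper recommends tau_neg > tau_pos.
print(soft_gate(0.5, 1.05) < soft_gate(0.5, 1.0))  # True
```

Because the gate is smooth in `r`, every token keeps a nonzero (if small) gradient weight, unlike the hard clip which zeroes it outside the clipping range.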

docs/source/Instruction/GRPO/AdvancedResearch/index.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -12,3 +12,4 @@ Advanced Research
    CHORD.md
    CISPO.md
    training_inference_mismatch.md
+   SAPO.md
```

docs/source/Instruction/GRPO/DeveloperGuide/loss_types.md

Lines changed: 16 additions & 1 deletion

```diff
@@ -1,6 +1,6 @@
 # Loss Types
 
-GRPO training supports five loss types, differing mainly in the normalization dimension
+GRPO training supports multiple loss types, differing mainly in the normalization dimension and in how gradients are handled
 
 ## Loss Function
 
@@ -12,11 +12,18 @@
 $$\mathcal{L}_{i,t}^{\text{CISPO}} = -\text{detach}\left(\min(\rho_{i,t}, \epsilon_{\text{high}})\right) \cdot A_{i,t} \cdot \log \pi_\theta(y_{i,t}|y_{i,<t})$$
 
+When `loss_type sapo` is set, soft gating replaces hard clipping; see [SAPO](../AdvancedResearch/SAPO.md)
+
+$$\mathcal{L}_{i,t}^{\text{SAPO}} = -g_{i,t} \cdot A_{i,t}$$
+
+where $g_{i,t} = \sigma(\tau \cdot (\rho_{i,t} - 1))$ is the temperature-controlled soft gate function.
+
 where:
 - $\rho_{i,t} = \frac{\pi_\theta(y_{i,t}|y_{i,<t})}{\pi_{\theta_{\text{old}}}(y_{i,t}|y_{i,<t})}$ is the importance sampling weight
 - $A_{i,t}$ is the advantage function
 - $\epsilon$ and $\epsilon_{\text{high}}$ are the clipping parameters
 - $\text{detach}(\cdot)$ means the term does not receive gradients
+- $\sigma(\cdot)$ is the sigmoid function and $\tau$ is the temperature parameter
 
 ## GRPO
@@ -100,3 +107,11 @@
 - $N_p$ is the number of samples on the $p$-th process
 
 **Normalization dimension:** global token dimension (total completion tokens across all processes)
+
+## SAPO
+
+`--loss_type sapo`
+
+SAPO replaces hard clipping with temperature-controlled soft gating for smooth gradient attenuation. Normalization is the same as in GRPO.
+
+For details, see [SAPO](../AdvancedResearch/SAPO.md)
```

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 3 additions & 1 deletion

```diff
@@ -574,7 +574,7 @@
 - reward_model_plugin: The logic for the reward model, which defaults to ORM logic. For more information, please refer to [Customized Reward Models](./GRPO/DeveloperGuide/reward_model.md#custom-reward-model).
 - dataset_shuffle: Whether to shuffle the dataset randomly. Default is True.
 - truncation_strategy: The method to handle inputs exceeding `max_length`. Supported values are `delete` and `left`, representing deletion and left-side truncation respectively. The default is `left`. Note that for multi-modal models, left-side truncation may remove multi-modal tokens and cause a shape mismatch error during model forward. With the delete strategy, over-long or encoding-failed samples are discarded, and new samples are resampled from the original dataset to maintain the intended batch size.
-- loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo'], default is 'grpo'. For details, refer to this [doc](./GRPO/DeveloperGuide/loss_types.md)
+- loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo', 'sapo'], default is 'grpo'. For details, refer to this [doc](./GRPO/DeveloperGuide/loss_types.md)
 - log_completions: Whether to log the model-generated content during training, to be used in conjunction with `--report_to wandb/swanlab`. Default is False.
   - Note: If `--report_to wandb/swanlab` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content.
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation. Default is False.
@@ -605,6 +605,8 @@
 - num_iterations: The number of updates per data sample, corresponding to the $\mu$ value in the GRPO paper. Default is 1.
 - epsilon: epsilon value for clipping. Default is 0.2.
 - epsilon_high: Upper clip coefficient, default is None. When set, it forms a clipping range of [epsilon, epsilon_high] together with epsilon.
+- tau_pos: Temperature parameter for positive advantages in the [SAPO](https://arxiv.org/abs/2511.20347) algorithm, controlling the sharpness of the soft gating function. Larger values make the gate sharper (closer to hard clipping); smaller values make it smoother. Default is 1.0.
+- tau_neg: Temperature parameter for negative advantages in the SAPO algorithm, controlling the sharpness of the soft gating function. Typically set `tau_neg > tau_pos` to apply stronger constraints on negative advantages. Default is 1.05.
 - dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False.
 - max_resample_times: Under the dynamic_sample setting, limit the number of resampling attempts. Default is 3.
 - overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False.
```
docs/source_en/Instruction/GRPO/AdvancedResearch/SAPO.md

Lines changed: 93 additions & 0 deletions

# Soft Adaptive Policy Optimization

**Version Requirement**: ms-swift>=3.11

[Soft Adaptive Policy Optimization (SAPO)](https://arxiv.org/abs/2511.20347) addresses the issues caused by hard clipping in GRPO by proposing a temperature-controlled soft gate mechanism that smoothly attenuates off-policy updates while preserving useful learning signals.

## Background and Motivation

When training LLMs with reinforcement learning, GRPO handles off-policy training by computing token-level importance sampling ratios:

$$
r_t = \frac{\pi_\theta(y_t|x, y_{<t})}{\pi_{\theta_{\mathrm{old}}}(y_t|x, y_{<t})}
$$

However, token-level importance sampling ratios often exhibit high variance, which can be exacerbated in the following cases:

- **Long text generation**
- **MoE model routing heterogeneity**: the old-policy model used for sampling and the model being trained may use different expert routing, significantly amplifying differences in logps

To address this, GRPO uses hard clipping to limit the magnitude of policy updates:

$$
L^{\mathrm{GRPO}} = -\min\left( r_t \cdot A, \mathrm{clip}(r_t, 1-\epsilon, 1+\epsilon) \cdot A \right)
$$

**The dilemma of hard clipping**: hard clipping struggles to balance stability and learning efficiency. A clipping range that is too strict limits the number of effective samples, while one that is too loose introduces noisy gradients from off-policy samples, leading to training instability.

## The SAPO Method

SAPO replaces hard clipping with a temperature-controlled sigmoid soft gate function, achieving smooth gradient attenuation.

### Soft Gate Function

The core of SAPO is applying a sigmoid soft gate to the importance sampling ratio.

For positive advantages ($A > 0$), the positive gate is used:

$$
g^{+}_t = \sigma\left( \tau_{\mathrm{pos}} \cdot (r_t - 1) \right)
$$

For negative advantages ($A < 0$), the negative gate is used:

$$
g^{-}_t = \sigma\left( \tau_{\mathrm{neg}} \cdot (r_t - 1) \right)
$$

where:

- $\sigma(\cdot)$ is the sigmoid function
- $\tau_{\mathrm{pos}}$ and $\tau_{\mathrm{neg}}$ are temperature parameters that control the slope of the gate
- $r_t$ is the importance sampling ratio

### SAPO Loss Function

$$
L^{\mathrm{SAPO}} = -g_t \cdot A
$$

where $g_t = g^{+}_t$ when $A > 0$ and $g_t = g^{-}_t$ when $A < 0$.

### Temperature Parameters

The temperature parameter $\tau$ controls the decay rate of the soft gate: larger values give faster decay.

![tau curve](../../../../resources/sapo_tau.png)

The paper points out that a positive advantage increases the logit of the sampled token while decreasing the logits of all unsampled tokens; a negative advantage does the opposite, increasing the logits of many unsampled tokens, which may spread to a large number of irrelevant tokens and introduce instability. The paper therefore recommends setting $\tau_\text{neg} > \tau_\text{pos}$ so that gradients decay faster for negatively rewarded tokens, improving training stability and performance.

The paper recommends default values of $\tau_{\mathrm{pos}} = 1.0$ and $\tau_{\mathrm{neg}} = 1.05$.

## Parameter Settings

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--loss_type` | `str` | - | Set to `sapo` |
| `--tau_pos` | `float` | `1.0` | Temperature parameter for positive advantages, controls gate slope |
| `--tau_neg` | `float` | `1.05` | Temperature parameter for negative advantages, controls gate slope |

```bash
swift rlhf \
    --rlhf_type grpo \
    --loss_type sapo \
    --tau_pos 1.0 \
    --tau_neg 1.05 \
    # ... other parameters
```

Example training scripts:

- [swift](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/sapo.sh)
- [megatron swift](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/grpo/sapo.sh)

> The soft gate mechanism of SAPO only takes effect during off-policy training.
> The importance sampling granularity in SAPO is token level (i.e., importance_sampling_level defaults to token), which conflicts with GSPO.
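As a reading aid, the per-token loss above can be sketched in a few lines of plain Python. `sapo_token_loss` is a hypothetical helper mirroring the formulas; the actual ms-swift implementation operates on batched tensors inside its GRPO trainer:

```python
import math

def sapo_token_loss(logp, logp_old, advantage, tau_pos=1.0, tau_neg=1.05):
    """Per-token SAPO loss: L = -g_t * A, with a sigmoid soft gate on r_t."""
    r = math.exp(logp - logp_old)                 # importance sampling ratio r_t
    tau = tau_pos if advantage > 0 else tau_neg   # asymmetric temperatures
    g = 1.0 / (1.0 + math.exp(-tau * (r - 1.0)))  # soft gate g_t
    return -g * advantage

# Exactly on-policy (logp == logp_old, so r = 1) the gate is 0.5
# for either sign of the advantage:
print(sapo_token_loss(0.0, 0.0, 1.0))   # -0.5
print(sapo_token_loss(0.0, 0.0, -1.0))  # 0.5
```

Note how the choice between `tau_pos` and `tau_neg` depends only on the sign of the advantage, matching the piecewise definition of $g_t$.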

docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -12,3 +12,4 @@ Advanced Research
    CHORD.md
    CISPO.md
    training_inference_mismatch.md
+   SAPO.md
```

docs/source_en/Instruction/GRPO/DeveloperGuide/loss_types.md

Lines changed: 16 additions & 1 deletion

```diff
@@ -1,6 +1,6 @@
 # Loss Types
 
-GRPO training supports five different loss types, with the main difference being the normalization dimension.
+GRPO training supports multiple loss types, with the main differences being the normalization dimension and gradient handling.
 
 ## Loss Function
 
@@ -12,11 +12,18 @@ When setting `loss_type cispo`, the CISPO loss is used:
 $$\mathcal{L}_{i,t}^{\text{CISPO}} = -\text{detach}\left(\min(\rho_{i,t}, \epsilon_{\text{high}})\right) \cdot A_{i,t} \cdot \log \pi_\theta(y_{i,t}|y_{i,<t})$$
 
+When setting `loss_type sapo`, soft gating replaces hard clipping; see [SAPO](../AdvancedResearch/SAPO.md)
+
+$$\mathcal{L}_{i,t}^{\text{SAPO}} = -g_{i,t} \cdot A_{i,t}$$
+
+where $g_{i,t} = \sigma(\tau \cdot (\rho_{i,t} - 1))$ is the temperature-controlled soft gate function.
+
 where:
 - $\rho_{i,t} = \frac{\pi_\theta(y_{i,t}|y_{i,<t})}{\pi_{\theta_{\text{old}}}(y_{i,t}|y_{i,<t})}$ is the importance sampling weight
 - $A_{i,t}$ is the advantage function
 - $\epsilon$ and $\epsilon_{\text{high}}$ are the clipping parameters
 - $\text{detach}(\cdot)$ indicates that this term does not participate in gradient computation
+- $\sigma(\cdot)$ is the sigmoid function, $\tau$ is the temperature parameter
 
 ## GRPO
@@ -100,3 +107,11 @@ where:
 - $N_p$ is the number of samples for the $p$-th process
 
 **Normalization Dimension:** Global token dimension (total completion tokens across all processes)
+
+## SAPO
+
+`--loss_type sapo`
+
+SAPO uses temperature-controlled soft gating instead of hard clipping to achieve smooth gradient attenuation. The normalization method is the same as GRPO.
+
+For details, please refer to [SAPO](../AdvancedResearch/SAPO.md)
```
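To make the contrast between the clipped GRPO objective and the SAPO gate concrete, here is a single-token comparison in plain Python; `grpo_loss` and `sapo_loss` are illustrative helpers written from the formulas above, not library code:

```python
import math

def grpo_loss(r, a, eps=0.2):
    """Clipped GRPO surrogate for one token: -min(r*A, clip(r)*A)."""
    clipped = max(min(r, 1.0 + eps), 1.0 - eps)
    return -min(r * a, clipped * a)

def sapo_loss(r, a, tau_pos=1.0, tau_neg=1.05):
    """SAPO surrogate for one token: -g*A with a sigmoid soft gate."""
    tau = tau_pos if a > 0 else tau_neg
    g = 1.0 / (1.0 + math.exp(-tau * (r - 1.0)))
    return -g * a

# Once r leaves the clipping range, the hard clip is flat in r
# (zero gradient), while the soft gate still responds smoothly:
print(grpo_loss(2.0, 1.0) == grpo_loss(3.0, 1.0))  # True
print(sapo_loss(2.0, 1.0) == sapo_loss(3.0, 1.0))  # False
```

This flatness outside the clip range is the "limits the number of effective samples" problem that SAPO's smooth gate is designed to avoid.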

examples/megatron/grpo/sapo.sh

Lines changed: 56 additions & 0 deletions

```bash
# SAPO https://arxiv.org/abs/2511.20347
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
MAX_PIXELS=602112 \
megatron rlhf \
    --rlhf_type grpo \
    --loss_type sapo \
    --tau_pos 1 \
    --tau_neg 1.05 \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --context_parallel_size 1 \
    --tensor_model_parallel_size 1 \
    --pipeline_model_parallel_size 1 \
    --dataset AI-ModelScope/clevr_cogen_a_train \
    --load_safetensors true \
    --save_safetensors true \
    --external_plugins examples/train/grpo/plugin/plugin.py \
    --reward_funcs external_r1v_acc format \
    --dynamic_sample false \
    --steps_per_generation 4 \
    --micro_batch_size 2 \
    --global_batch_size 128 \
    --num_generations 8 \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.7 \
    --vllm_max_model_len 8192 \
    --max_length 4096 \
    --max_completion_length 4096 \
    --train_type full \
    --bf16 true \
    --importance_sampling_level token \
    --epsilon 0.2 \
    --epsilon_high 0.2 \
    --overlong_filter true \
    --max_epochs 1 \
    --eval_interval 1000 \
    --save_interval 1000 \
    --sleep_level 2 \
    --offload_model true \
    --offload_optimizer true \
    --log_interval 1 \
    --recompute_granularity selective \
    --finetune \
    --lr 1e-6 \
    --num_workers 8 \
    --dataset_num_proc 8 \
    --no_save_optim \
    --no_save_rng \
    --attention_backend flash \
    --temperature 1.0 \
    --system examples/train/grpo/prompt.txt \
    --beta 0.001 \
    --padding_free true \
    --wandb_project swift-megatron \
    --wandb_exp_name xxx
```
