Commit 7dea7d2

[grpo] multi turn rollout tis/mis (#6803)

1 parent 6501402 · commit 7dea7d2

6 files changed: +402 -99 lines changed

docs/source/Instruction/GRPO/DeveloperGuide/multi_turn.md

Lines changed: 26 additions & 0 deletions

@@ -41,6 +41,7 @@ class MultiTurnScheduler(ABC):
         - infer_request (required): the inference request object for the next turn
         - response_token_ids (optional): response token IDs of each rollout turn
         - response_loss_mask (optional): loss mask of each rollout turn's response
+        - rollout_logprobs (optional): logps of each rollout turn's response
         - rollout_infos (optional): extra information
         """
         raise NotImplementedError
@@ -145,6 +146,7 @@ swift rollout \

 Use the `use_async_engine` argument in the `rollout` command to specify the engine type (the async engine is the default):

+> Note: the async engine and the custom multi-turn interaction logic below currently only support server mode. For multi-turn interaction logic in colocate mode, refer to the _colocate_multi_turn_infer method of RolloutTrainerMixin.

 ## Advanced settings

@@ -222,3 +224,27 @@ class RewardFunction():
 ### Getting extra dataset information in the Scheduler

 Set `--vllm_server_pass_dataset` on the training side to pass other dataset columns into the multi-turn scheduler; read them from `infer_request.data_dict`.
+
+### Training-inference consistency compatibility
+
+swift >= 3.11 supports returning rollout logps from the vLLM side to correct training-inference mismatch; see this [document](../AdvancedResearch/training_inference_mismatch.md) for details.
+
+In multi-turn training, if `rollout_importance_sampling_mode` is enabled, the framework automatically collects the log probabilities of each rollout turn to correct the off-policy problem caused by training-inference mismatch.
+
+**Default behavior**
+- With the default `run` method, the framework automatically extracts log probabilities from `response_choice.logprobs`
+- These logprobs are passed to the trainer together with `response_token_ids` and `response_loss_mask`
+
+**Notes for custom Schedulers**
+
+If you modify the response in the `step` method (e.g. truncation, appended content), you need to return the corresponding `rollout_logprobs` as well.
+
+**Key rules**
+- The length of `rollout_logprobs` should equal the number of 1s in `response_loss_mask`
+- Tokens with `loss_mask=0` (e.g. user-added prompts, tool return results) need no logprobs
+- If `step` does not return `rollout_logprobs`, the framework automatically extracts them from `response_choice.logprobs`
+
+**When overriding the `run` method**
+
+If you fully override the `run` method, you must collect and pass `rollout_logprobs` manually.
+
+For a concrete implementation, see the [built-in implementation](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/multi_turn.py).
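The alignment rule documented above (one logprob per `loss_mask == 1` token) can be sketched as a standalone check. This is a minimal illustration, not ms-swift code; the helper name is hypothetical:

```python
def logprobs_are_complete(loss_masks, logprobs):
    """True iff the per-turn logprobs cover exactly the loss_mask==1 tokens."""
    n_trained = sum(sum(mask) for mask in loss_masks)
    n_logprobs = sum(len(turn_lps) for turn_lps in logprobs)
    return n_trained == n_logprobs

# One model turn of 3 tokens, then 2 appended tool tokens with loss_mask=0;
# no logprobs are provided for the masked tool tokens.
masks = [[1, 1, 1], [0, 0]]
lps = [[-0.1, -0.2, -0.3], []]
print(logprobs_are_complete(masks, lps))  # True
```

When this check fails, the built-in scheduler clears the logprobs entirely rather than passing misaligned values to the trainer.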

docs/source_en/Instruction/GRPO/DeveloperGuide/multi_turn.md

Lines changed: 28 additions & 0 deletions

@@ -45,6 +45,7 @@ class MultiTurnScheduler(ABC):
         - infer_request (required): the inference request for the next turn
         - response_token_ids (optional): token IDs of each rollout response
         - response_loss_mask (optional): loss mask of each rollout response
+        - rollout_logprobs (optional): token logps of each rollout response
         - rollout_infos (optional): extra information
         """
         raise NotImplementedError
@@ -150,6 +151,8 @@ AsyncEngine reduces compute bubbles in multi-turn inference:

 Use the `use_async_engine` argument in the `rollout` command to specify the engine type (async is the default).

+> Note: The async engine and the custom multi-turn interaction logic below are currently only supported in server mode. For multi-turn interaction logic in colocate mode, please refer to the _colocate_multi_turn_infer method in RolloutTrainerMixin.
+
 ## Advanced topics

 ### Customising the interaction logic
@@ -237,3 +240,28 @@ class RewardFunction():

 Set `--vllm_server_pass_dataset` on the training side to pass other dataset columns to the scheduler.
 They can be read from `infer_request.data_dict`.
+
+### Training-inference mismatch
+
+Swift >= 3.11 supports returning rollout logprobs from the vLLM side to address training-inference mismatch issues. For details, please refer to this [document](../AdvancedResearch/training_inference_mismatch.md).
+
+In multi-turn training, if `rollout_importance_sampling_mode` is enabled, the framework automatically collects log probabilities from each rollout turn to correct off-policy issues.
+
+**Default Behavior**:
+- When using the default `run` method, the framework automatically extracts log probabilities from `response_choice.logprobs`
+- These logprobs are passed to the trainer along with `response_token_ids` and `response_loss_mask`
+
+**Notes for Custom Schedulers**:
+
+If you modify the response in your `step` method (e.g., truncation, adding content), you need to return the corresponding `rollout_logprobs`.
+
+**Key Rules**:
+- The length of `rollout_logprobs` should equal the count of 1s in `response_loss_mask`
+- For tokens with `loss_mask=0` (e.g., user-added prompts, tool return results), no logprobs are needed
+- If `step` does not return `rollout_logprobs`, the framework will automatically extract them from `response_choice.logprobs`
+
+**When Overriding the `run` Method**:
+
+If you completely override the `run` method, you need to manually collect and pass `rollout_logprobs`.
+
+For the implementation, please refer to the [built-in schedulers](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/multi_turn.py).
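The rules above can be illustrated with a self-contained sketch of what a custom `step` might return. The function and its arguments are hypothetical stand-ins for the scheduler API, using plain lists instead of real request/response objects:

```python
def build_step_result(response_token_ids, response_logprobs, appended_token_ids):
    """Sketch of a custom step() return dict that appends extra tokens.

    Model-generated tokens keep loss_mask=1 and their logprobs; tokens the
    scheduler appends itself (e.g. a tip prompt) get loss_mask=0 and NO
    logprobs, so len(rollout_logprobs) == number of loss_mask==1 tokens.
    """
    token_ids = list(response_token_ids) + list(appended_token_ids)
    loss_mask = [1] * len(response_token_ids) + [0] * len(appended_token_ids)
    return {
        'response_token_ids': token_ids,
        'response_loss_mask': loss_mask,
        'rollout_logprobs': list(response_logprobs),
    }

ret = build_step_result([11, 12], [-0.5, -0.7], [99, 98, 97])
# len(ret['rollout_logprobs']) == sum(ret['response_loss_mask']) == 2
```

If the appended tokens were instead counted in the logprobs, the trainer-side completion mask (built from `labels != -100`) would no longer line up with the logprob vector, and the correction would be silently disabled.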

swift/llm/infer/protocol.py

Lines changed: 4 additions & 2 deletions

@@ -358,11 +358,13 @@ class RolloutOutput(BaseModel):
     response_token_ids: List[List[int]] = Field(default_factory=list)
     response_loss_mask: List[List[int]] = Field(default_factory=list)
     rollout_infos: Dict[str, Any] = Field(default_factory=dict)
+    # rollout logprobs for each turn (used for rollout importance sampling correction in multi-turn scenarios)
+    rollout_logprobs: List[List[float]] = Field(default_factory=list)

-    @field_validator('response_token_ids', 'response_loss_mask', mode='before')
+    @field_validator('response_token_ids', 'response_loss_mask', 'rollout_logprobs', mode='before')
     @classmethod
     def _wrap_flat_list(cls, v):
-        if isinstance(v, list) and v and isinstance(v[0], int):
+        if isinstance(v, list) and v and isinstance(v[0], (int, float)):
             return [v]
         return v
swift/plugin/multi_turn.py

Lines changed: 145 additions & 9 deletions

@@ -22,6 +22,8 @@ def __init__(self,
                  *args,
                  **kwargs):
         self.infer_engine = infer_engine
+        # Tokenizer can be passed explicitly (e.g., in colocate mode where infer_engine may be None)
+        self._tokenizer = kwargs.get('tokenizer', None)
         self.max_turns = max_turns

     async def async_infer(self,
@@ -131,6 +133,15 @@ def __getattr__(self, key: str):
     def engine(self):
         return self.infer_engine

+    @property
+    def tokenizer(self):
+        """Get tokenizer, prioritizing explicitly passed tokenizer over infer_engine's tokenizer."""
+        if self._tokenizer is not None:
+            return self._tokenizer
+        if self.infer_engine is not None:
+            return self.infer_engine.tokenizer
+        return None
+

 class MultiTurnScheduler(RolloutScheduler, ABC):
     """
@@ -219,6 +230,7 @@ async def run(self, infer_request, request_config, **kwargs):
         rollout_infos = {}
         total_response_ids = []
         total_response_loss_mask = []
+        total_rollout_logprobs = []
         while True:
             messages = current_request.messages
             if current_turn == 1 or not messages[-1]['content']:
@@ -247,12 +259,45 @@ async def run(self, infer_request, request_config, **kwargs):
             should_stop = should_stop or (current_turn >= self.max_turns)

             if should_stop:
+                # Collect final turn's data
+                current_logprobs = self._extract_logprobs_from_choice(response_choice)
+                final_token_ids = response_choice.token_ids
+
                 if is_continuation and total_response_ids:
-                    # for continuation and total_response_ids is not empty
-                    # we need to extend the last turn's response_token_ids and response_loss_mask
-                    total_response_ids[-1].extend(response_choice.token_ids)
+                    # For continuation, extend the last turn's data
+                    total_response_ids[-1].extend(final_token_ids)
                     if total_response_loss_mask:
-                        total_response_loss_mask[-1].extend([1] * len(response_choice.token_ids))
+                        total_response_loss_mask[-1].extend([1] * len(final_token_ids))
+                    if total_rollout_logprobs and current_logprobs:
+                        total_rollout_logprobs[-1].extend(current_logprobs)
+                elif not total_response_ids:
+                    # First turn stopped immediately - need to initialize with final response data
+                    if final_token_ids:
+                        total_response_ids = [list(final_token_ids)]
+                        total_response_loss_mask = [[1] * len(final_token_ids)]
+                        if current_logprobs:
+                            total_rollout_logprobs = [current_logprobs]
+
+                # Validate rollout_logprobs completeness: if logprobs are incomplete (missing for some turns),
+                # clear them to disable rollout importance sampling correction (which requires complete logprobs)
+                # Note: rollout_logprobs should match the number of loss_mask=1 tokens, not total response tokens
+                # because completion_mask in grpo_trainer is based on labels != -100, which corresponds to loss_mask=1
+                final_rollout_logprobs = total_rollout_logprobs
+                if total_rollout_logprobs:
+                    total_logprob_count = sum(len(turn_lps) for turn_lps in total_rollout_logprobs)
+                    if total_response_loss_mask:
+                        # Check if the number of logprobs matches the number of loss_mask=1 tokens
+                        total_loss_mask_1_count = sum(sum(mask) for mask in total_response_loss_mask)
+                        if total_loss_mask_1_count != total_logprob_count:
+                            # Incomplete logprobs, clear them
+                            final_rollout_logprobs = []
+                    else:
+                        if total_response_ids:
+                            total_response_id_count = sum(len(turn_ids) for turn_ids in total_response_ids)
+                            if total_response_id_count != total_logprob_count:
+                                final_rollout_logprobs = []
+                else:
+                    final_rollout_logprobs = []

                 return RolloutOutput(
                     response=response,
@@ -262,6 +307,7 @@ async def run(self, infer_request, request_config, **kwargs):
                     rollout_infos={
                         **rollout_infos, 'num_turns': current_turn
                     },
+                    rollout_logprobs=final_rollout_logprobs,
                 )

             # Prepare next turn
@@ -291,6 +337,18 @@ async def run(self, infer_request, request_config, **kwargs):
                 # If you need to keep all step-wise details, switch to append or merge instead.
                 rollout_infos.update(ret['rollout_infos'])

+            # Track rollout_logprobs for rollout importance sampling correction
+            # Prefer step's returned logprobs (which may be modified/truncated) over raw response_choice logprobs
+            if 'rollout_logprobs' in ret and ret['rollout_logprobs']:
+                current_logprobs = ret['rollout_logprobs']
+            else:
+                current_logprobs = self._extract_logprobs_from_choice(response_choice)
+            if current_logprobs:
+                if is_continuation and total_rollout_logprobs:
+                    total_rollout_logprobs[-1].extend(current_logprobs)
+                else:
+                    total_rollout_logprobs.append(current_logprobs)
+
             if current_request.messages[-1]['role'] == 'assistant':
                 # Add a dummy response to allow engine to continue generating
                 current_request.messages.append({'role': 'assistant', 'content': None})
@@ -310,8 +368,11 @@ def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompl
         Returns:
             Dict[str, Any]: A dictionary containing inference results with the following structure:
                 - infer_request (required): Main inference request object
-                - response_token_ids (Optional[List[List[int]]]): Token IDs of responses for each rollout turn
-                - response_loss_scale (Optional[List[List[int]]]): Loss scaling factors for responses in each rollout turn # noqa
+                - response_token_ids (Optional[List[int]]): Token IDs of response for current rollout turn
+                - response_loss_mask (Optional[List[int]]): Loss mask for response tokens (same length as response_token_ids) # noqa
+                - rollout_logprobs (Optional[List[float]]): Log probabilities for response tokens.
+                    If not provided, will be extracted from response_choice.logprobs as fallback.
+                    Useful when modifying response content (e.g., adding prompts) to avoid logprob misalignment.
                 - rollout_infos (Optional[Dict[str, Any]]): Additional metadata (must be serializable)

         """
@@ -347,6 +408,22 @@ def check_finished(self, infer_request: 'RolloutInferRequest', response_choice:
             return True
         return False

+    @staticmethod
+    def _extract_logprobs_from_choice(response_choice: 'ChatCompletionResponseChoice') -> List[float]:
+        """Extract logprobs list from response choice for rollout importance sampling.
+
+        Args:
+            response_choice: The response choice containing logprobs
+
+        Returns:
+            List of logprob values, or empty list if not available
+        """
+        if response_choice.logprobs is None:
+            return []
+        if 'content' in response_choice.logprobs:
+            return [item['logprob'] for item in response_choice.logprobs['content']]
+        return []
+

 class ThinkingModelTipsScheduler(MultiTurnScheduler):
     """
@@ -513,6 +590,15 @@ def __init__(self, *args, **kwargs):
         from .orm import MathAccuracy
         super().__init__(*args, **kwargs)
         self.acc_func = kwargs.get('acc_function', MathAccuracy())
+        # Cache the tokenized tips_prompt length for loss mask computation
+        self._tips_token_ids = None
+
+    def _get_tips_token_ids(self, tokenizer) -> List[int]:
+        """Get tokenized tips_prompt (cached for efficiency)."""
+        if self._tips_token_ids is None:
+            # Tokenize without special tokens to get the raw token ids
+            self._tips_token_ids = tokenizer.encode(self.tips_prompt, add_special_tokens=False)
+        return self._tips_token_ids

     def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
                        current_turn: int) -> bool:
@@ -531,11 +617,55 @@ def check_finished(self, infer_request: 'RolloutInferRequest', response_choice:
     def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
              current_turn: int) -> Dict:
         completion = response_choice.message.content
+        response_token_ids = list(response_choice.token_ids) if response_choice.token_ids else []
+
+        # Extract logprobs from response_choice before any truncation
+        rollout_logprobs = self._extract_logprobs_from_choice(response_choice)
+
+        # Truncate completion at <answer> or </think> tags
+        truncate_idx = len(completion)
         if '<answer>' in completion:
-            completion = completion[:completion.index('<answer>')]
+            truncate_idx = min(truncate_idx, completion.index('<answer>'))
         if '</think>' in completion:
-            completion = completion[:completion.index('</think>')]
+            truncate_idx = min(truncate_idx, completion.index('</think>'))
+
+        if truncate_idx < len(completion):
+            # Need to truncate token_ids and logprobs as well
+            truncated_completion = completion[:truncate_idx]
+            if response_token_ids and self.tokenizer is not None:
+                # Find the token index corresponding to the truncation point
+                # by decoding progressively until we reach or exceed the truncation point
+                token_truncate_idx = len(response_token_ids)
+                for i in range(1, len(response_token_ids) + 1):
+                    decoded = self.tokenizer.decode(response_token_ids[:i], skip_special_tokens=False)
+                    if len(decoded) >= truncate_idx:
+                        token_truncate_idx = i
+                        break
+                response_token_ids = response_token_ids[:token_truncate_idx]
+                # Truncate logprobs to match
+                if rollout_logprobs:
+                    rollout_logprobs = rollout_logprobs[:token_truncate_idx]
+            completion = truncated_completion
+
+        # Add tips_prompt
         completion += self.tips_prompt
+
+        # Compute loss_mask for tips tokens
+        # Note: rollout_logprobs should NOT include tips tokens because:
+        # 1. Tips tokens have loss_mask=0, so their labels will be -100
+        # 2. completion_mask = (labels != -100), so tips tokens won't be in completion_mask
+        # 3. rollout_logprobs must align with completion_mask, not response_token_ids
+        if response_token_ids and self.tokenizer is not None:
+            tips_token_ids = self._get_tips_token_ids(self.tokenizer)
+            # Loss mask: original tokens = 1, tips tokens = 0
+            response_loss_mask = [1] * len(response_token_ids) + [0] * len(tips_token_ids)
+            # Append tips token ids to response
+            response_token_ids = response_token_ids + tips_token_ids
+            # Do NOT extend rollout_logprobs for tips tokens - they are masked out in completion_mask
+        else:
+            response_loss_mask = []
+
+        # Update messages
         if infer_request.messages[-1]['role'] == 'assistant':
             if not infer_request.messages[-1]['content']:
                 # Multi-turn continuation: pop the dummy input we add in last turn
@@ -544,7 +674,13 @@ def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompl
         else:
             infer_request.messages.append({'role': 'assistant', 'content': completion})

-        return {'infer_request': infer_request}
+        result = {'infer_request': infer_request}
+        if response_token_ids:
+            result['response_token_ids'] = response_token_ids
+            result['response_loss_mask'] = response_loss_mask
+        if rollout_logprobs:
+            result['rollout_logprobs'] = rollout_logprobs
+        return result


 class GYMScheduler(RolloutScheduler):
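The progressive-decode truncation in `ThinkingModelTipsScheduler.step` can be illustrated with a toy tokenizer. Everything below is a hypothetical sketch (the tokenizer and helper are invented for the example), showing how cutting the text, the token ids, and the logprobs at the same index keeps them aligned:

```python
class ToyTokenizer:
    """Hypothetical tokenizer: each id maps to a fixed string fragment."""
    vocab = {1: 'I ', 2: 'think ', 3: '</think>', 4: 'done'}

    def decode(self, ids, skip_special_tokens=False):
        return ''.join(self.vocab[i] for i in ids)


def truncate_at_tag(tokenizer, completion, token_ids, logprobs, tag='</think>'):
    """Cut completion at `tag`, then cut token_ids/logprobs at the matching token index."""
    if tag not in completion:
        return completion, token_ids, logprobs
    truncate_idx = completion.index(tag)
    token_truncate_idx = len(token_ids)
    # Decode progressively until the decoded prefix reaches the character cut point
    for i in range(1, len(token_ids) + 1):
        if len(tokenizer.decode(token_ids[:i])) >= truncate_idx:
            token_truncate_idx = i
            break
    return (completion[:truncate_idx],
            token_ids[:token_truncate_idx],
            logprobs[:token_truncate_idx])


text, ids, lps = truncate_at_tag(ToyTokenizer(), 'I think </think>done',
                                 [1, 2, 3, 4], [-0.1, -0.2, -0.3, -0.4])
# text == 'I think ', ids == [1, 2], lps == [-0.1, -0.2]
```

Because the logprobs are sliced at the same token index as the ids, the one-logprob-per-trained-token invariant survives the truncation; slicing only the text would leave stale logprob entries for tokens that no longer exist.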
