@@ -22,6 +22,8 @@ def __init__(self,
2222* args ,
2323** kwargs ):
2424self .infer_engine = infer_engine
25+ # Tokenizer can be passed explicitly (e.g., in colocate mode where infer_engine may be None)
26+ self ._tokenizer = kwargs .get ('tokenizer' ,None )
2527self .max_turns = max_turns
2628
2729async def async_infer (self ,
@@ -131,6 +133,15 @@ def __getattr__(self, key: str):
131133def engine (self ):
132134return self .infer_engine
133135
@property
def tokenizer(self):
    """Resolve the tokenizer: an explicitly supplied one wins over the engine's.

    Returns the tokenizer passed at construction time when present, otherwise
    falls back to ``infer_engine.tokenizer``; ``None`` when neither exists
    (e.g., colocate mode with no engine attached).
    """
    explicit = self._tokenizer
    if explicit is not None:
        return explicit
    engine = self.infer_engine
    return engine.tokenizer if engine is not None else None
134145
135146class MultiTurnScheduler (RolloutScheduler ,ABC ):
136147"""
@@ -219,6 +230,7 @@ async def run(self, infer_request, request_config, **kwargs):
219230rollout_infos = {}
220231total_response_ids = []
221232total_response_loss_mask = []
233+ total_rollout_logprobs = []
222234while True :
223235messages = current_request .messages
224236if current_turn == 1 or not messages [- 1 ]['content' ]:
@@ -247,12 +259,45 @@ async def run(self, infer_request, request_config, **kwargs):
247259should_stop = should_stop or (current_turn >= self .max_turns )
248260
249261if should_stop :
262+ # Collect final turn's data
263+ current_logprobs = self ._extract_logprobs_from_choice (response_choice )
264+ final_token_ids = response_choice .token_ids
265+
250266if is_continuation and total_response_ids :
251- # for continuation and total_response_ids is not empty
252- # we need to extend the last turn's response_token_ids and response_loss_mask
253- total_response_ids [- 1 ].extend (response_choice .token_ids )
267+ # For continuation, extend the last turn's data
268+ total_response_ids [- 1 ].extend (final_token_ids )
254269if total_response_loss_mask :
255- total_response_loss_mask [- 1 ].extend ([1 ]* len (response_choice .token_ids ))
270+ total_response_loss_mask [- 1 ].extend ([1 ]* len (final_token_ids ))
271+ if total_rollout_logprobs and current_logprobs :
272+ total_rollout_logprobs [- 1 ].extend (current_logprobs )
273+ elif not total_response_ids :
274+ # First turn stopped immediately - need to initialize with final response data
275+ if final_token_ids :
276+ total_response_ids = [list (final_token_ids )]
277+ total_response_loss_mask = [[1 ]* len (final_token_ids )]
278+ if current_logprobs :
279+ total_rollout_logprobs = [current_logprobs ]
280+
281+ # Validate rollout_logprobs completeness: if logprobs are incomplete (missing for some turns),
282+ # clear them to disable rollout importance sampling correction (which requires complete logprobs)
283+ # Note: rollout_logprobs should match the number of loss_mask=1 tokens, not total response tokens
284+ # because completion_mask in grpo_trainer is based on labels != -100, which corresponds to loss_mask=1
285+ final_rollout_logprobs = total_rollout_logprobs
286+ if total_rollout_logprobs :
287+ total_logprob_count = sum (len (turn_lps )for turn_lps in total_rollout_logprobs )
288+ if total_response_loss_mask :
289+ # Check if the number of logprobs matches the number of loss_mask=1 tokens
290+ total_loss_mask_1_count = sum (sum (mask )for mask in total_response_loss_mask )
291+ if total_loss_mask_1_count != total_logprob_count :
292+ # Incomplete logprobs, clear them
293+ final_rollout_logprobs = []
294+ else :
295+ if total_response_ids :
296+ total_response_id_count = sum (len (turn_ids )for turn_ids in total_response_ids )
297+ if total_response_id_count != total_logprob_count :
298+ final_rollout_logprobs = []
299+ else :
300+ final_rollout_logprobs = []
256301
257302return RolloutOutput (
258303response = response ,
@@ -262,6 +307,7 @@ async def run(self, infer_request, request_config, **kwargs):
262307rollout_infos = {
263308** rollout_infos ,'num_turns' :current_turn
264309 },
310+ rollout_logprobs = final_rollout_logprobs ,
265311 )
266312
267313# Prepare next turn
@@ -291,6 +337,18 @@ async def run(self, infer_request, request_config, **kwargs):
291337# If you need to keep all step-wise details, switch to append or merge instead.
292338rollout_infos .update (ret ['rollout_infos' ])
293339
340+ # Track rollout_logprobs for rollout importance sampling correction
341+ # Prefer step's returned logprobs (which may be modified/truncated) over raw response_choice logprobs
342+ if 'rollout_logprobs' in ret and ret ['rollout_logprobs' ]:
343+ current_logprobs = ret ['rollout_logprobs' ]
344+ else :
345+ current_logprobs = self ._extract_logprobs_from_choice (response_choice )
346+ if current_logprobs :
347+ if is_continuation and total_rollout_logprobs :
348+ total_rollout_logprobs [- 1 ].extend (current_logprobs )
349+ else :
350+ total_rollout_logprobs .append (current_logprobs )
351+
294352if current_request .messages [- 1 ]['role' ]== 'assistant' :
295353# Add a dummy response to allow engine to continue generating
296354current_request .messages .append ({'role' :'assistant' ,'content' :None })
@@ -310,8 +368,11 @@ def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompl
310368 Returns:
311369 Dict[str, Any]: A dictionary containing inference results with the following structure:
312370 - infer_request (required): Main inference request object
313- - response_token_ids (Optional[List[List[int]]]): Token IDs of responses for each rollout turn
314- - response_loss_scale (Optional[List[List[int]]]): Loss scaling factors for responses in each rollout turn # noqa
371+ - response_token_ids (Optional[List[int]]): Token IDs of response for current rollout turn
372+ - response_loss_mask (Optional[List[int]]): Loss mask for response tokens (same length as response_token_ids) # noqa
373+ - rollout_logprobs (Optional[List[float]]): Log probabilities for response tokens.
374+ If not provided, will be extracted from response_choice.logprobs as fallback.
375+ Useful when modifying response content (e.g., adding prompts) to avoid logprob misalignment.
315376 - rollout_infos (Optional[Dict[str, Any]]): Additional metadata (must be serializable)
316377
317378 """
@@ -347,6 +408,22 @@ def check_finished(self, infer_request: 'RolloutInferRequest', response_choice:
347408return True
348409return False
349410
411+ @staticmethod
412+ def _extract_logprobs_from_choice (response_choice :'ChatCompletionResponseChoice' )-> List [float ]:
413+ """Extract logprobs list from response choice for rollout importance sampling.
414+
415+ Args:
416+ response_choice: The response choice containing logprobs
417+
418+ Returns:
419+ List of logprob values, or empty list if not available
420+ """
421+ if response_choice .logprobs is None :
422+ return []
423+ if 'content' in response_choice .logprobs :
424+ return [item ['logprob' ]for item in response_choice .logprobs ['content' ]]
425+ return []
426+
350427
351428class ThinkingModelTipsScheduler (MultiTurnScheduler ):
352429"""
@@ -513,6 +590,15 @@ def __init__(self, *args, **kwargs):
513590from .orm import MathAccuracy
514591super ().__init__ (* args ,** kwargs )
515592self .acc_func = kwargs .get ('acc_function' ,MathAccuracy ())
593+ # Cache the tokenized tips_prompt length for loss mask computation
594+ self ._tips_token_ids = None
595+
596+ def _get_tips_token_ids (self ,tokenizer )-> List [int ]:
597+ """Get tokenized tips_prompt (cached for efficiency)."""
598+ if self ._tips_token_ids is None :
599+ # Tokenize without special tokens to get the raw token ids
600+ self ._tips_token_ids = tokenizer .encode (self .tips_prompt ,add_special_tokens = False )
601+ return self ._tips_token_ids
516602
517603def check_finished (self ,infer_request :'RolloutInferRequest' ,response_choice :'ChatCompletionResponseChoice' ,
518604current_turn :int )-> bool :
@@ -531,11 +617,55 @@ def check_finished(self, infer_request: 'RolloutInferRequest', response_choice:
531617def step (self ,infer_request :'RolloutInferRequest' ,response_choice :'ChatCompletionResponseChoice' ,
532618current_turn :int )-> Dict :
533619completion = response_choice .message .content
620+ response_token_ids = list (response_choice .token_ids )if response_choice .token_ids else []
621+
622+ # Extract logprobs from response_choice before any truncation
623+ rollout_logprobs = self ._extract_logprobs_from_choice (response_choice )
624+
625+ # Truncate completion at <answer> or </think> tags
626+ truncate_idx = len (completion )
534627if '<answer>' in completion :
535- completion = completion [: completion .index ('<answer>' )]
628+ truncate_idx = min ( truncate_idx , completion .index ('<answer>' ))
536629if '</think>' in completion :
537- completion = completion [:completion .index ('</think>' )]
630+ truncate_idx = min (truncate_idx ,completion .index ('</think>' ))
631+
632+ if truncate_idx < len (completion ):
633+ # Need to truncate token_ids and logprobs as well
634+ truncated_completion = completion [:truncate_idx ]
635+ if response_token_ids and self .tokenizer is not None :
636+ # Find the token index corresponding to the truncation point
637+ # by decoding progressively until we reach or exceed the truncation point
638+ token_truncate_idx = len (response_token_ids )
639+ for i in range (1 ,len (response_token_ids )+ 1 ):
640+ decoded = self .tokenizer .decode (response_token_ids [:i ],skip_special_tokens = False )
641+ if len (decoded )>= truncate_idx :
642+ token_truncate_idx = i
643+ break
644+ response_token_ids = response_token_ids [:token_truncate_idx ]
645+ # Truncate logprobs to match
646+ if rollout_logprobs :
647+ rollout_logprobs = rollout_logprobs [:token_truncate_idx ]
648+ completion = truncated_completion
649+
650+ # Add tips_prompt
538651completion += self .tips_prompt
652+
653+ # Compute loss_mask for tips tokens
654+ # Note: rollout_logprobs should NOT include tips tokens because:
655+ # 1. Tips tokens have loss_mask=0, so their labels will be -100
656+ # 2. completion_mask = (labels != -100), so tips tokens won't be in completion_mask
657+ # 3. rollout_logprobs must align with completion_mask, not response_token_ids
658+ if response_token_ids and self .tokenizer is not None :
659+ tips_token_ids = self ._get_tips_token_ids (self .tokenizer )
660+ # Loss mask: original tokens = 1, tips tokens = 0
661+ response_loss_mask = [1 ]* len (response_token_ids )+ [0 ]* len (tips_token_ids )
662+ # Append tips token ids to response
663+ response_token_ids = response_token_ids + tips_token_ids
664+ # Do NOT extend rollout_logprobs for tips tokens - they are masked out in completion_mask
665+ else :
666+ response_loss_mask = []
667+
668+ # Update messages
539669if infer_request .messages [- 1 ]['role' ]== 'assistant' :
540670if not infer_request .messages [- 1 ]['content' ]:
541671# Multi-turn continuation: pop the dummy input we add in last turn
@@ -544,7 +674,13 @@ def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompl
544674else :
545675infer_request .messages .append ({'role' :'assistant' ,'content' :completion })
546676
547- return {'infer_request' :infer_request }
677+ result = {'infer_request' :infer_request }
678+ if response_token_ids :
679+ result ['response_token_ids' ]= response_token_ids
680+ result ['response_loss_mask' ]= response_loss_mask
681+ if rollout_logprobs :
682+ result ['rollout_logprobs' ]= rollout_logprobs
683+ return result
548684
549685
550686class GYMScheduler (RolloutScheduler ):