Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Refactors the loss calculation to pull it out into a free function #1137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
finbarrtimbers wants to merge 13 commits into main
base:main
Choose a base branch
Loading
from refactor-loss
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
PrevPrevious commit
NextNext commit
Cleaned up code
  • Loading branch information
@finbarrtimbers
finbarrtimbers committed Nov 3, 2025
commit 2d450c26c460e63726bc654357ed3905afc34804
88 changes: 6 additions & 82 deletions open_instruct/grpo_fast.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -92,7 +92,7 @@
cleanup_all_llm_judge_clients,
soft_format_reward_func,
)
from open_instruct.metrics import MetricsTracker
from open_instruct.metrics import LossStatistics, MetricsTracker
from open_instruct.model_utils import (
Batch,
ModelConfig,
Expand DownExpand Up@@ -136,81 +136,6 @@
INVALID_LOGPROB = 1.0


class LossStatistics:
    """Accumulates per-minibatch training statistics for a GRPO/PPO-style loss.

    One slot per minibatch is preallocated for each tracked quantity (KL
    estimators kl1..kl4, KL loss, policy-gradient clip fraction, policy-gradient
    loss, total loss, importance ratio, and optionally entropy). The caller is
    expected to invoke ``update_kl_estimates`` and then ``update_stats`` once
    per minibatch index ``i``.
    """

    # Names of the supported KL estimator variants; ``args.kl_estimator`` must
    # be one of these.
    _KL_ESTIMATORS = ("kl1", "kl2", "kl3", "kl4")

    def __init__(self, num_batches: int, record_entropy: bool = False):
        # Per-minibatch scalar statistics, one entry per minibatch.
        self.kl1_stats = torch.zeros(num_batches)
        self.kl2_stats = torch.zeros(num_batches)
        self.kl3_stats = torch.zeros(num_batches)
        self.kl4_stats = torch.zeros(num_batches)
        self.kl_loss_stats = torch.zeros(num_batches)
        self.pg_clipfrac_stats = torch.zeros(num_batches)
        self.pg_loss_stats = torch.zeros(num_batches)
        self.loss_stats = torch.zeros(num_batches)
        self.ratio_stats = torch.zeros(num_batches)
        # Entropy is only tracked when explicitly requested (it costs an extra
        # masked_mean per minibatch).
        self.entropy_stats = torch.zeros(num_batches) if record_entropy else None
        # Token-level KL estimates for the current minibatch; populated by
        # ``update_kl_estimates`` and consumed by ``kl``/``update_stats``.
        self.kl1 = None
        self.kl2 = None
        self.kl3 = None
        self.kl4 = None

    def update_kl_estimates(self, ref_logprobs_diff, ratio, mb_response_masks_bool, args):
        """Compute the four standard KL estimators for the current minibatch.

        ``ref_logprobs_diff`` is ``new_logprobs - ref_logprobs`` (already
        clamped by the caller); ``ratio`` is the importance-sampling ratio.
        ``mb_response_masks_bool`` and ``args`` are accepted for signature
        compatibility with the call site but are not used here.
        """
        self.kl1 = ref_logprobs_diff
        self.kl2 = ref_logprobs_diff**2 / 2
        # expm1 keeps the k3 estimator numerically stable near zero.
        self.kl3 = torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff
        self.kl4 = ratio * ref_logprobs_diff

    def _kl_by_name(self):
        """Return a name -> token-level-estimate mapping for the current minibatch."""
        return {"kl1": self.kl1, "kl2": self.kl2, "kl3": self.kl3, "kl4": self.kl4}

    def kl(self, args):
        """Return the token-level KL estimate selected by ``args.kl_estimator``.

        Raises:
            ValueError: if ``args.kl_estimator`` is not one of kl1..kl4
                (the original silently returned ``None`` here, which only
                surfaced later as an opaque arithmetic error).
        """
        estimates = self._kl_by_name()
        try:
            return estimates[args.kl_estimator]
        except KeyError:
            raise ValueError(
                f"Unknown kl_estimator {args.kl_estimator!r}; expected one of {self._KL_ESTIMATORS}"
            ) from None

    def update_stats(
        self, i, mb_response_masks_bool, pg_losses, pg_losses2, pg_loss_max, ratio, loss, mb_entropy, args
    ):
        """Record scalar statistics for minibatch ``i``.

        Must be called after ``update_kl_estimates`` for the same minibatch.
        """

        # All statistics share the same mask/axis/denominator reduction;
        # hoist it so the reduction parameters appear exactly once.
        def _mean(x):
            return masked_mean(x, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator)

        self.kl1_stats[i] = _mean(self.kl1).float()
        self.kl2_stats[i] = _mean(self.kl2).float()
        self.kl3_stats[i] = _mean(self.kl3).float()
        self.kl4_stats[i] = _mean(self.kl4).float()
        # Select the already-reduced statistic matching the configured
        # estimator; a bad estimator name now fails loudly instead of leaving
        # kl_loss_stats[i] silently at zero.
        stats_by_name = {
            "kl1": self.kl1_stats,
            "kl2": self.kl2_stats,
            "kl3": self.kl3_stats,
            "kl4": self.kl4_stats,
        }
        try:
            self.kl_loss_stats[i] = stats_by_name[args.kl_estimator][i] * args.beta
        except KeyError:
            raise ValueError(
                f"Unknown kl_estimator {args.kl_estimator!r}; expected one of {self._KL_ESTIMATORS}"
            ) from None
        # Fraction of tokens where the clipped objective was active.
        self.pg_clipfrac_stats[i] = _mean((pg_losses2 > pg_losses).float())
        self.pg_loss_stats[i] = _mean(pg_loss_max)
        self.loss_stats[i] = loss
        self.ratio_stats[i] = _mean(ratio)
        if args.record_entropy and self.entropy_stats is not None:
            self.entropy_stats[i] = _mean(mb_entropy).float()


class ShutdownSentinel:
"""Sentinel value to signal thread shutdown via queue."""

Expand DownExpand Up@@ -1035,8 +960,7 @@ def calculate_loss(
pg_loss_max = torch.max(pg_losses, pg_losses2)

ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0)
loss_statistics.update_kl_estimates(ref_logprobs_diff, ratio, mb_response_masks_bool, args)
kl = loss_statistics.kl(args)
kl = loss_statistics.update_kl_estimates(i, ref_logprobs_diff, ratio, mb_response_masks_bool, args)

loss = masked_mean(
pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
Expand DownExpand Up@@ -1228,10 +1152,10 @@ def train(
)

with torch.no_grad():
self.local_metrics.add("objective/kl_avg", loss_statistics.kl1_stats.mean())
self.local_metrics.add("objective/kl2_avg", loss_statistics.kl2_stats.mean())
self.local_metrics.add("objective/kl3_avg", loss_statistics.kl3_stats.mean())
self.local_metrics.add("objective/kl4_avg", loss_statistics.kl4_stats.mean())
self.local_metrics.add("objective/kl_avg", loss_statistics.kl_stats[0].mean())
self.local_metrics.add("objective/kl2_avg", loss_statistics.kl_stats[1].mean())
self.local_metrics.add("objective/kl3_avg", loss_statistics.kl_stats[2].mean())
self.local_metrics.add("objective/kl4_avg", loss_statistics.kl_stats[3].mean())
self.local_metrics.add("loss/policy_avg", loss_statistics.pg_loss_stats.mean())
self.local_metrics.add("loss/kl_avg", loss_statistics.kl_loss_stats.mean())
self.local_metrics.add("loss/total_avg", loss_statistics.loss_stats.mean())
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp