@@ -1278,7 +1278,8 @@ def _pad_or_trim_last_dim(tensor: Optional[torch.Tensor], target_len: int) -> Op
12781278reporting_metric = {** avg_metric ,** custom_metrics }
12791279
12801280# log_completions
1281- if self .log_completions and self .is_main_process and (self ._step - 1 )% self .steps_per_generation == 0 :
1281+ if (self .log_completions and self .is_main_process and (self ._step - 1 )% self .steps_per_generation == 0
1282+ and self ._step != self ._last_logged_step ):
12821283table = {
12831284'gen_step' : [self ._step - 1 ]* len (self ._logs ['prompt' ]),
12841285'prompt' :list (self ._logs ['prompt' ]),
@@ -1297,6 +1298,7 @@ def _pad_or_trim_last_dim(tensor: Optional[torch.Tensor], target_len: int) -> Op
12971298# wandb_writer.define_metric('completions', step_metric='gen_step')
12981299# self.init_custom_metric = True
12991300wandb_writer .log ({'completions' :wandb .Table (dataframe = df )})
1301+ self ._last_logged_step = self ._step
13001302
13011303return loss ,reporting_metric
13021304
@@ -1486,6 +1488,7 @@ def _prepare_metrics(self):
14861488self .wandb_log_unique_prompts = args .wandb_log_unique_prompts
14871489self .jsonl_writer = JsonlWriter (os .path .join (args .save ,'completions.jsonl' ),write_on_rank = 'last' )
14881490self .init_custom_metric = False
1491+ self ._last_logged_step = - 1
14891492self ._logs = {
14901493'prompt' :deque (maxlen = args .generation_batch_size ),
14911494'completion' :deque (maxlen = args .generation_batch_size ),