
DPOTrainer log metrics are not gathered and meaned across ranks #2468

Open
zhc7 opened this issue Dec 13, 2024 · 3 comments · May be fixed by #2474
Labels
🐛 bug (Something isn't working) · 🏋 DPO (Related to DPO)

Comments


zhc7 commented Dec 13, 2024

Feature request

Synchronize and average metrics across ranks.

Motivation

Currently, the reported metrics are only the numbers on rank 0.

        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/chosen"] = model_output["chosen_logps"].detach().mean().cpu()
        metrics[f"{prefix}logps/rejected"] = model_output["rejected_logps"].detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = model_output["mean_chosen_logits"].detach().cpu()
        metrics[f"{prefix}logits/rejected"] = model_output["mean_rejected_logits"].detach().cpu()

None of these are synced across ranks.
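
To make the gap concrete, here is a minimal standalone sketch (not TRL code) of the difference between the per-rank mean that currently gets logged and the mean taken after gathering across ranks. The script name, the random accuracy values, and the two-process launch are made up for illustration; it only assumes Accelerate's `Accelerator.gather`.

    # gather_demo.py -- illustrative only; run with: accelerate launch --num_processes 2 gather_demo.py
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Pretend these are the per-example reward accuracies on this rank (per_device_batch_size=2).
    local_acc = torch.rand(2, device=accelerator.device).round()

    # What gets logged today: each rank's own mean, and only rank 0's value reaches the logs.
    print(f"rank {accelerator.process_index}: local mean = {local_acc.mean().item():.2f}")

    # What the feature request asks for: gather across ranks first, then average.
    global_acc = accelerator.gather(local_acc).mean()
    if accelerator.is_main_process:
        print(f"mean across all ranks = {global_acc.item():.2f}")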

Your contribution

The current `log` function looks like:

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs)

It would have this feature if it looked like this:

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            if isinstance(metrics[0], torch.Tensor):
                # Move the stored CPU metrics back to the device and gather them across all ranks
                gathered = self._nested_gather([m.cuda() for m in metrics])
                # Average each gathered tensor so every rank's value contributes
                metrics = [g.mean() for g in gathered]
            meaned = torch.tensor(metrics).mean()
            logs[key] = meaned.item()
        del self._stored_metrics[train_eval]
        return super().log(logs)
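
Two notes on the sketch above: the `.cuda()` call moves the stored CPU metrics back to the device, since the metrics were stored with `.cpu()` and NCCL-based gathering needs device tensors; and because each stored entry is already a per-rank mean, averaging the gathered values matches the true global mean only when every rank contributes the same number of samples per step (the usual case with a fixed per-device batch size).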

I'm happy to submit a PR.

qgallouedec (Member) commented:

That's a good point! Feel free to open a PR to fix this. I don't think adding a unit test for this is relevant. If possible, add plots (e.g., with wandb) before/after to ensure that we aren't introducing a regression.

qgallouedec added the 🐛 bug and 🏋 DPO labels Dec 13, 2024
zhc7 (Author) commented Dec 13, 2024

Of course!
[Image: wandb curves for the same training run with and without the modification]
Here's a graph of the same training run with and without the modification. You can see the pink line is a lot smoother, especially in the accuracy graph. My per_device_batch_size is 2, so the per-device accuracy can only be 0, 0.5, or 1.

zhc7 added a commit to zhc7/trl that referenced this issue Dec 13, 2024
zhc7 linked a pull request Dec 13, 2024 that will close this issue
qgallouedec (Member) commented:

Perfect!
