From d5f8d68e794108b822ec2ce5721f89c4c12abf18 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Sat, 26 Oct 2024 15:24:31 +0200 Subject: [PATCH] fix failing torch test for torchmetrics when multi output --- darts/models/forecasting/pl_forecasting_module.py | 11 +++++++++-- .../forecasting/test_torch_forecasting_model.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 6318afb4cd..b8c96623a9 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -400,11 +400,18 @@ def _update_metrics(self, output, target, metrics): return if self.likelihood: - metrics.update(self.likelihood.sample(output), target) + pred = self.likelihood.sample(output) else: # If there's no likelihood, nr_params=1, and we need to squeeze out the # last dimension of model output, for properly computing the metric. - metrics.update(output.squeeze(dim=-1), target) + pred = output.squeeze(dim=-1) + + # torch metrics require 2D targets of shape (batch size * ocl, num targets) + if self.n_targets > 1: + target = target.reshape(-1, self.n_targets) + pred = pred.reshape(-1, self.n_targets) + + metrics.update(pred, target) def _compute_metrics(self, metrics): if not len(metrics): diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 2668d8d767..9ae63e80f7 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1310,7 +1310,7 @@ def test_metrics(self): 10, 10, n_epochs=1, - torch_metrics=metric, + torch_metrics=metric_collection, pl_trainer_kwargs=model_kwargs, ) model.fit(self.multivariate_series)