diff --git a/CHANGELOG.md b/CHANGELOG.md index 10812c7abd..b5df8be67a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Fixed a bug when using encoders with `RegressionModel` and series with a non-evenly spaced frequency (e.g. Month Begin). This raised an error during lagged data creation when trying to divide a pd.Timedelta by the ambiguous frequency. [#2034](https://github.com/unit8co/darts/pull/2034) by [Antoine Madrona](https://github.com/madtoinou). - Fixed a bug when loading a `TorchForecastingModel` that was trained with a precision other than `float64`. [#2046](https://github.com/unit8co/darts/pull/2046) by [Freddie Hsin-Fu Huang](https://github.com/Hsinfu). - Fixed broken links in the `Transfer learning` example notebook with publicly hosted version of the three datasets. [#2067](https://github.com/unit8co/darts/pull/2067) by [Antoine Madrona](https://github.com/madtoinou). +- Fixed a bug when using `NLinearModel` on multivariate series with covariates and `normalize=True`. [#2072](https://github.com/unit8co/darts/pull/2072) by [Antoine Madrona](https://github.com/madtoinou). ### For developers of the library: diff --git a/darts/models/forecasting/nlinear.py b/darts/models/forecasting/nlinear.py index 2120e32de9..438a5faf5e 100644 --- a/darts/models/forecasting/nlinear.py +++ b/darts/models/forecasting/nlinear.py @@ -144,7 +144,8 @@ def forward( if self.normalize: # get last values only for target features seq_last = x[:, -1:, : self.output_dim].detach() - x = x - seq_last + # normalize the target features only (ignore the covariates) + x[:, :, : self.output_dim] = x[:, :, : self.output_dim] - seq_last x = self.layer(x.view(batch, -1)) # (batch, out_len * out_dim * nr_params) x = x.view( @@ -174,6 +175,7 @@ def forward( x = x.view(batch, self.output_chunk_length, self.output_dim, self.nr_params) if self.normalize: + # model only forecasts target components, no need to slice x = x + seq_last.view(seq_last.shape + (1,)) return x