
add rin to all torch models #1969
Merged · 12 commits · Sep 2, 2023
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -13,6 +13,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**
- `TimeSeries` with a `RangeIndex` starting in the negative start are now supported by `historical_forecasts`. [#1866](https://github.com/unit8co/darts/pull/1866) by [Antoine Madrona](https://github.com/madtoinou).
- Added a new argument `start_format` to `historical_forecasts()`, `backtest()` and `gridsearch` that allows using an integer `start` either as the index position or index value/label for `series` indexed with a `pd.RangeIndex`. [#1866](https://github.com/unit8co/darts/pull/1866) by [Antoine Madrona](https://github.com/madtoinou).
- Added `RINorm` (Reversible Instance Norm) as an input normalization option for all `TorchForecastingModel` except `RNNModel`. Activate it with model creation parameter `use_reversible_instance_norm`. [#1969](https://github.com/unit8co/darts/pull/1969) by [Dennis Bader](https://github.com/dennisbader).
- Reduced the size of the Darts docker image `unit8/darts:latest`, and included all optional models as well as dev requirements. [#1878](https://github.com/unit8co/darts/pull/1878) by [Alex Colpitts](https://github.com/alexcolpitts96).

**Fixed**
@@ -60,7 +61,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Other improvements:
- Improved static covariates column naming when using `StaticCovariatesTransformer` with a `sklearn.preprocessing.OneHotEncoder`. [#1863](https://github.com/unit8co/darts/pull/1863) by [Anne de Vries](https://github.com/anne-devries).
- Added `MSTL` (Season-Trend decomposition using LOESS for multiple seasonalities) as a `method` option for `extract_trend_and_seasonality()`. [#1879](https://github.com/unit8co/darts/pull/1879) by [Alex Colpitts](https://github.com/alexcolpitts96).
- Added `RINorm` (Reversible Instance Norm) as a new input normalization option for `TorchForecastingModel`. So far only `TiDEModel` supports it with model creation parameter `user_reversible_instance_norm`. [#1865](https://github.com/unit8co/darts/issues/1856) by [Alex Colpitts](https://github.com/alexcolpitts96).
- Added `RINorm` (Reversible Instance Norm) as a new input normalization option for `TorchForecastingModel`. So far only `TiDEModel` supports it with model creation parameter `use_reversible_instance_norm`. [#1865](https://github.com/unit8co/darts/issues/1856) by [Alex Colpitts](https://github.com/alexcolpitts96).
- Improvements to `TimeSeries.plot()`: custom axes are now properly supported with parameter `ax`. Axis is now returned for downstream tasks. [#1916](https://github.com/unit8co/darts/pull/1916) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**
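The changelog entries above describe reversible instance normalization only briefly. As a minimal, framework-free sketch of the idea (illustrative only, not Darts' actual `RINorm` implementation): each input window is normalized with its own per-instance mean and standard deviation, and the model output is de-normalized with the same stored statistics, which counters distribution shift between training and inference windows.

```python
import math

def rin_normalize(window, eps=1e-8):
    # Per-instance statistics: each input window is normalized independently.
    mean = sum(window) / len(window)
    var = sum((x - mean) ** 2 for x in window) / len(window)
    std = math.sqrt(var + eps)
    normalized = [(x - mean) / std for x in window]
    return normalized, (mean, std)

def rin_denormalize(output, stats):
    # Reverse the transform on the model output using the stored statistics,
    # so predictions come back on the original scale of the series.
    mean, std = stats
    return [y * std + mean for y in output]

window = [10.0, 12.0, 14.0, 16.0]
normalized, stats = rin_normalize(window)
restored = rin_denormalize(normalized, stats)  # round-trips back to the input
```

In the real model the de-normalization is applied to the forecast horizon, not the input window, but the stored statistics are reused the same way.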
4 changes: 2 additions & 2 deletions README.md
@@ -229,8 +229,8 @@ on bringing more models and features.
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | [TBATS paper](https://robjhyndman.com/papers/ComplexSeasonality.pdf) | 🟩 🟥 | 🟥 🟥 🟥 | 🟩 🟥 | 🟥 |
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | [Theta](https://robjhyndman.com/papers/Theta.pdf) & [4 Theta](https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R) | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 |
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | 🟩 🟥 | 🟥 🟥 🟥 | 🟩 🟥 | 🟥 |
| [Prophet](file:///Users/dennisbader/projects/unit8/darts/docs/build/html/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) (see [install notes](https://github.com/unit8co/darts/blob/master/INSTALL.md#enabling-support-for-facebook-prophet)) | [Prophet repo](https://github.com/facebook/prophet) | 🟩 🟥 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 |
| [FFT](file:///Users/dennisbader/projects/unit8/darts/docs/build/html/generated_api/darts.models.forecasting.fft.html#darts.models.forecasting.fft.FFT) (Fast Fourier Transform) | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 |
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) (see [install notes](https://github.com/unit8co/darts/blob/master/INSTALL.md#enabling-support-for-facebook-prophet)) | [Prophet repo](https://github.com/facebook/prophet) | 🟩 🟥 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 |
| [FFT](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.fft.html#darts.models.forecasting.fft.FFT) (Fast Fourier Transform) | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 |
| [KalmanForecaster](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.kalman_forecaster.html#darts.models.forecasting.kalman_forecaster.KalmanForecaster) using the Kalman filter and N4SID for system identification | [N4SID paper](https://people.duke.edu/~hpgavin/SystemID/References/VanOverschee-Automatica-1994.pdf) | 🟩 🟩 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 |
| [Croston](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.croston.html#darts.models.forecasting.croston.Croston) method | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 |
| **Regression Models**<br/>([GlobalForecastingModel](https://unit8co.github.io/darts/userguide/covariates.html#global-forecasting-models-gfms)) | | | | | |
5 changes: 0 additions & 5 deletions darts/models/components/statsforecast_utils.py
@@ -1,8 +1,3 @@
"""
StatsForecast utils
-----------
"""

import numpy as np

# In a normal distribution, 68.27 percent of values lie within one standard deviation of the mean
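The comment kept in this file cites the normal-distribution fact that roughly 68.27% of values lie within one standard deviation of the mean. That constant follows directly from the standard normal CDF, since Φ(1) − Φ(−1) = erf(1/√2):

```python
import math

# Probability that a standard normal value lies within one standard
# deviation of the mean: Phi(1) - Phi(-1) = erf(1 / sqrt(2)).
one_sigma_coverage = math.erf(1 / math.sqrt(2))
print(round(one_sigma_coverage, 4))  # ≈ 0.6827
```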
14 changes: 13 additions & 1 deletion darts/models/forecasting/block_rnn_model.py
@@ -9,7 +9,10 @@
import torch.nn as nn

from darts.logging import get_logger, raise_if_not
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
PLPastCovariatesModule,
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel

logger = get_logger(__name__)
@@ -101,6 +104,7 @@ def __init__(
last = feature
self.fc = nn.Sequential(*feats)

@io_processor
def forward(self, x_in: Tuple):
x, _ = x_in
# data is of size (batch_size, input_chunk_length, input_size)
@@ -194,6 +198,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [1]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -299,6 +306,11 @@ def encode_year(idx):
show_warnings
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
your forecasting use case. Default: ``False``.

References
----------
.. [1] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))

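The `@io_processor` decorator added to `forward` in this diff is what wires the input/output normalization into each module. A hypothetical, framework-free sketch of such a decorator (assuming the first element of `x_in` is the target window, and that the statistics computed on the way in are reused to de-normalize the output):

```python
import math
from functools import wraps

def io_processor(forward):
    # Hypothetical sketch, not Darts' actual decorator: normalize the target
    # input per instance before forward(), and invert the transform on the
    # prediction afterwards.
    @wraps(forward)
    def wrapper(self, x_in):
        x, *rest = x_in
        mean = sum(x) / len(x)
        std = math.sqrt(sum((v - mean) ** 2 for v in x) / len(x)) + 1e-8
        x_norm = [(v - mean) / std for v in x]
        out = forward(self, (x_norm, *rest))
        return [v * std + mean for v in out]  # back to the original scale
    return wrapper

class NaiveRepeatModel:
    # Toy "model" that just repeats the last (normalized) input value twice.
    @io_processor
    def forward(self, x_in):
        x, _ = x_in
        return [x[-1]] * 2

model = NaiveRepeatModel()
pred = model.forward(([10.0, 12.0, 14.0], None))  # predictions on the original scale
```

Decorating `forward` keeps the normalization logic in one place, so each model class in the diff only needs the one-line `@io_processor` addition rather than repeating the transform.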
11 changes: 10 additions & 1 deletion darts/models/forecasting/dlinear.py
@@ -9,7 +9,10 @@
import torch.nn as nn

from darts.logging import raise_if
from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
PLMixedCovariatesModule,
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel

MixedCovariatesTrainTensorType = Tuple[
@@ -155,6 +158,7 @@ def _create_linear_layer(in_dim, out_dim):
layer_in_dim_static_cov, layer_out_dim
)

@io_processor
def forward(
self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
):
@@ -295,6 +299,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -405,6 +412,8 @@ def encode_year(idx):
----------
.. [1] Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2022).
Are Transformers Effective for Time Series Forecasting?. arXiv preprint arXiv:2205.13504.
.. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))

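The `dlinear.py` changes sit on top of the model described in reference [1] of its docstring (Zeng et al.), which splits the input into a trend and a remainder via a moving average before applying linear layers. A rough sketch of that decomposition step, under the assumption of simple edge padding (illustrative only, not the library's exact implementation):

```python
def decompose(series, kernel=3):
    # DLinear-style decomposition sketch: a moving average gives the trend;
    # subtracting it leaves the seasonal/residual component.
    half = kernel // 2
    padded = [series[0]] * half + list(series) + [series[-1]] * half
    trend = [sum(padded[i:i + kernel]) / kernel for i in range(len(series))]
    seasonal = [x - t for x, t in zip(series, trend)]
    return trend, seasonal

trend, seasonal = decompose([1.0, 2.0, 3.0, 4.0, 5.0])
```

Each component is then forecast separately by its own linear layer, and the two forecasts are summed.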
11 changes: 10 additions & 1 deletion darts/models/forecasting/nbeats.py
@@ -11,7 +11,10 @@
import torch.nn as nn

from darts.logging import get_logger, raise_if_not, raise_log
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
PLPastCovariatesModule,
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

@@ -490,6 +493,7 @@ def __init__(
self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)
self.stacks_list[-1].blocks[-1].backcast_g.requires_grad_(False)

@io_processor
def forward(self, x_in: Tuple):
x, _ = x_in

@@ -616,6 +620,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -725,6 +732,8 @@ def encode_year(idx):
References
----------
.. [1] https://openreview.net/forum?id=r1ecqn4YwB
.. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))

11 changes: 10 additions & 1 deletion darts/models/forecasting/nhits.py
@@ -11,7 +11,10 @@
import torch.nn.functional as F

from darts.logging import get_logger, raise_if_not
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
PLPastCovariatesModule,
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

@@ -417,6 +420,7 @@ def __init__(
# on this params (the last block backcast is not part of the final output of the net).
self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)

@io_processor
def forward(self, x_in: Tuple):
x, _ = x_in

@@ -552,6 +556,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -662,6 +669,8 @@ def encode_year(idx):
----------
.. [1] C. Challu et al. "N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting",
https://arxiv.org/abs/2201.12886
.. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))

11 changes: 10 additions & 1 deletion darts/models/forecasting/nlinear.py
@@ -9,7 +9,10 @@
import torch.nn as nn

from darts.logging import raise_if
from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
PLMixedCovariatesModule,
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel


@@ -106,6 +109,7 @@ def _create_linear_layer(in_dim, out_dim):
layer_in_dim_static_cov, layer_out_dim
)

@io_processor
def forward(
self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
):
@@ -246,6 +250,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -354,6 +361,8 @@ def encode_year(idx):
----------
.. [1] Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2022).
Are Transformers Effective for Time Series Forecasting?. arXiv preprint arXiv:2205.13504.
.. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))
