From a2a6ddfd5645f25ed71a51d3edc2bec2d3613cb5 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 25 Oct 2023 10:18:54 -0700 Subject: [PATCH] address https://github.com/lucidrains/iTransformer/issues/6, last commit for the repo given official code is out --- README.md | 25 ++++++++++++++------ iTransformer/__init__.py | 5 +++- iTransformer/iTransformer.py | 45 +++++++++++++++++++++++++++++++++++- setup.py | 2 +- 4 files changed, 67 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index e22d445..885517a 100644 --- a/README.md +++ b/README.md @@ -30,13 +30,14 @@ from iTransformer import iTransformer model = iTransformer( num_variates = 137, - lookback_len = 96, # or the lookback length in the paper - dim = 256, # model dimensions - depth = 6, # depth - heads = 8, # attention heads - dim_head = 64, # head dimension - pred_length = (12, 24, 36, 48), # can be one prediction, or many - num_tokens_per_variate = 1 # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine + lookback_len = 96, # or the lookback length in the paper + dim = 256, # model dimensions + depth = 6, # depth + heads = 8, # attention heads + dim_head = 64, # head dimension + pred_length = (12, 24, 36, 48), # can be one prediction, or many + num_tokens_per_variate = 1, # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine + use_reversible_instance_norm = True # use reversible instance normalization, proposed here https://openreview.net/forum?id=cGDAkQo1C0p . may be redundant given the layernorms within iTransformer (and whatever else attention learns emergently on the first layer, prior to the first layernorm). if i come across some time, i'll gather up all the statistics across variates, project them, and condition the transformer a bit further. that makes more sense ) time_series = torch.randn(2, 96, 137) # (batch, lookback len, variates) @@ -114,3 +115,13 @@ preds = model(time_series) note = {(Accelerated article preview)}, } ``` + +```bibtex +@inproceedings{kim2022reversible, + title = {Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift}, + author = {Taesung Kim and Jinhee Kim and Yunwon Tae and Cheonbok Park and Jang-Ho Choi and Jaegul Choo}, + booktitle = {International Conference on Learning Representations}, + year = {2022}, + url = {https://openreview.net/forum?id=cGDAkQo1C0p} +} +``` diff --git a/iTransformer/__init__.py b/iTransformer/__init__.py index d72669a..1abc21d 100644 --- a/iTransformer/__init__.py +++ b/iTransformer/__init__.py @@ -1 +1,4 @@ -from iTransformer.iTransformer import iTransformer +from iTransformer.iTransformer import ( + iTransformer, + RevIN +) diff --git a/iTransformer/iTransformer.py b/iTransformer/iTransformer.py index 8e901aa..a78b419 100644 --- a/iTransformer/iTransformer.py +++ b/iTransformer/iTransformer.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from beartype import beartype -from beartype.typing import Optional, Union, Tuple +from beartype.typing import Optional, Union, Tuple, Callable from einops import rearrange, reduce, repeat, pack, unpack from einops.layers.torch import Rearrange @@ -19,9 +19,40 @@ def exists(v): def default(v, d): return v if exists(v) else d +def identity(t, *args, **kwargs): + return t + def cast_tuple(t): return (t,) if not isinstance(t, tuple) else t +# reversible instance normalization +# proposed in https://openreview.net/forum?id=cGDAkQo1C0p + +class RevIN(Module): + def __init__(self, num_variates, eps = 1e-5): + super().__init__() + self.eps = eps + self.num_variates = num_variates + self.gamma = nn.Parameter(torch.ones(num_variates, 1)) + self.beta = nn.Parameter(torch.zeros(num_variates, 1)) + + @beartype + def forward(self, x) -> Tuple[Tensor, Callable[Tensor, Tensor]]: + assert x.shape[1] == self.num_variates + + var = torch.var(x, dim = -1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = -1, keepdim = True) + var_rsqrt = var.clamp(min = self.eps).rsqrt() + instance_normalized = (x - mean) * var_rsqrt + rescaled = instance_normalized * self.gamma + self.beta + + def reverse_fn(scaled_output): + clamped_gamma = torch.sign(self.gamma) * self.gamma.abs().clamp(min = self.eps) + unscaled_output = (scaled_output - self.beta) / clamped_gamma + return unscaled_output * var.sqrt() + mean + + return rescaled, reverse_fn + # attention class Attention(Module): @@ -99,6 +130,7 @@ def __init__( ff_mult = 4, ff_dropout = 0., num_mem_tokens = 4, + use_reversible_instance_norm = False, flash_attn = True ): super().__init__() @@ -110,6 +142,8 @@ def __init__( pred_length = cast_tuple(pred_length) self.pred_length = pred_length + self.reversible_instance_norm = RevIN(num_variates) if use_reversible_instance_norm else None + self.layers = ModuleList([]) for _ in range(depth): self.layers.append(ModuleList([ @@ -156,6 +190,10 @@ def forward( # there is a lot of opportunity to improve on this, if the paper is successfully replicated x = rearrange(x, 'b n v -> b v n') + + if exists(self.reversible_instance_norm): + x, reverse_fn = self.reversible_instance_norm(x) + x = self.mlp_in(x) # memory tokens @@ -177,6 +215,11 @@ def forward( if has_mem: _, x = unpack(x, mem_ps, 'b * d') + # reversible instance normaization, if needed + + if exists(self.reversible_instance_norm): + x = reverse_fn(x) + # predicting multiple times pred_list = [fn(x) for fn in self.pred_heads] diff --git a/setup.py b/setup.py index 7ba04f6..0b8096b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'iTransformer', packages = find_packages(exclude=[]), - version = '0.1.0', + version = '0.2.0', license='MIT', description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting', author = 'Phil Wang',