
Commit

address #6, last commit for the repo given official code is out
lucidrains committed Oct 25, 2023
1 parent 800ce3f commit a2a6ddf
Showing 4 changed files with 67 additions and 10 deletions.
25 changes: 18 additions & 7 deletions README.md
@@ -30,13 +30,14 @@ from iTransformer import iTransformer

model = iTransformer(
num_variates = 137,
lookback_len = 96, # or the lookback length in the paper
dim = 256, # model dimensions
depth = 6, # depth
heads = 8, # attention heads
dim_head = 64, # head dimension
pred_length = (12, 24, 36, 48), # can be one prediction, or many
num_tokens_per_variate = 1 # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine
lookback_len = 96, # or the lookback length in the paper
dim = 256, # model dimensions
depth = 6, # depth
heads = 8, # attention heads
dim_head = 64, # head dimension
pred_length = (12, 24, 36, 48), # can be one prediction, or many
num_tokens_per_variate = 1, # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine
use_reversible_instance_norm = True # use reversible instance normalization, proposed here https://openreview.net/forum?id=cGDAkQo1C0p . may be redundant given the layernorms within iTransformer (and whatever else attention learns emergently on the first layer, prior to the first layernorm). if i come across some time, i'll gather up all the statistics across variates, project them, and condition the transformer a bit further. that makes more sense
)

time_series = torch.randn(2, 96, 137) # (batch, lookback len, variates)
@@ -114,3 +115,13 @@ preds = model(time_series)
note = {(Accelerated article preview)},
}
```

```bibtex
@inproceedings{kim2022reversible,
title = {Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift},
author = {Taesung Kim and Jinhee Kim and Yunwon Tae and Cheonbok Park and Jang-Ho Choi and Jaegul Choo},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=cGDAkQo1C0p}
}
```
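Putting the README diff together, a minimal sketch of the new flag in use (same shapes as the snippet above; the only change from the existing example is `use_reversible_instance_norm = True`):

```python
import torch
from iTransformer import iTransformer

model = iTransformer(
    num_variates = 137,
    lookback_len = 96,
    dim = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    pred_length = (12, 24, 36, 48),
    num_tokens_per_variate = 1,
    use_reversible_instance_norm = True
)

time_series = torch.randn(2, 96, 137)   # (batch, lookback len, variates)

preds = model(time_series)              # one prediction per horizon in pred_length
```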
5 changes: 4 additions & 1 deletion iTransformer/__init__.py
@@ -1 +1,4 @@
from iTransformer.iTransformer import iTransformer
from iTransformer.iTransformer import (
iTransformer,
RevIN
)
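With this export, the normalization module can be pulled in alongside the model:

```python
from iTransformer import iTransformer, RevIN
```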
45 changes: 44 additions & 1 deletion iTransformer/iTransformer.py
@@ -4,7 +4,7 @@
import torch.nn.functional as F

from beartype import beartype
from beartype.typing import Optional, Union, Tuple
from beartype.typing import Optional, Union, Tuple, Callable

from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange
@@ -19,9 +19,40 @@ def exists(v):
def default(v, d):
    return v if exists(v) else d

def identity(t, *args, **kwargs):
    return t

def cast_tuple(t):
    return (t,) if not isinstance(t, tuple) else t

# reversible instance normalization
# proposed in https://openreview.net/forum?id=cGDAkQo1C0p

class RevIN(Module):
    def __init__(self, num_variates, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.num_variates = num_variates
        self.gamma = nn.Parameter(torch.ones(num_variates, 1))
        self.beta = nn.Parameter(torch.zeros(num_variates, 1))

    @beartype
    def forward(self, x) -> Tuple[Tensor, Callable[[Tensor], Tensor]]:
        assert x.shape[1] == self.num_variates

        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        var_rsqrt = var.clamp(min = self.eps).rsqrt()
        instance_normalized = (x - mean) * var_rsqrt
        rescaled = instance_normalized * self.gamma + self.beta

        def reverse_fn(scaled_output):
            clamped_gamma = torch.sign(self.gamma) * self.gamma.abs().clamp(min = self.eps)
            unscaled_output = (scaled_output - self.beta) / clamped_gamma
            return unscaled_output * var.sqrt() + mean

        return rescaled, reverse_fn
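A standalone sketch of the class above, assuming the `(batch, variates, time)` layout used by the forward pass further down; `reverse_fn` recovers the input up to the `eps` clamp and floating point error:

```python
import torch
from iTransformer import RevIN

revin = RevIN(num_variates = 137)

x = torch.randn(2, 137, 96)              # (batch, variates, time)
normalized, reverse_fn = revin(x)        # instance norm over time, then learned per-variate affine
reconstructed = reverse_fn(normalized)   # undo the affine, restore the original mean and variance

assert torch.allclose(reconstructed, x, atol = 1e-5)
```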

# attention

class Attention(Module):
@@ -99,6 +130,7 @@ def __init__(
        ff_mult = 4,
        ff_dropout = 0.,
        num_mem_tokens = 4,
        use_reversible_instance_norm = False,
        flash_attn = True
    ):
        super().__init__()
@@ -110,6 +142,8 @@ def __init__(
        pred_length = cast_tuple(pred_length)
        self.pred_length = pred_length

        self.reversible_instance_norm = RevIN(num_variates) if use_reversible_instance_norm else None

        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
@@ -156,6 +190,10 @@ def forward(
        # there is a lot of opportunity to improve on this, if the paper is successfully replicated

        x = rearrange(x, 'b n v -> b v n')

        if exists(self.reversible_instance_norm):
            x, reverse_fn = self.reversible_instance_norm(x)

        x = self.mlp_in(x)

        # memory tokens
@@ -177,6 +215,11 @@
        if has_mem:
            _, x = unpack(x, mem_ps, 'b * d')

        # reversible instance normalization, if needed

        if exists(self.reversible_instance_norm):
            x = reverse_fn(x)

        # predicting multiple times

        pred_list = [fn(x) for fn in self.pred_heads]
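The memory tokens stripped out above rely on einops `pack` / `unpack`; an isolated sketch of that mechanism with toy shapes (not the model's actual tensors):

```python
import torch
from einops import pack, unpack, repeat

batch, num_mem, num_variates, dim = 2, 4, 137, 256

mem = repeat(torch.randn(num_mem, dim), 'm d -> b m d', b = batch)   # memory tokens, broadcast over the batch
x = torch.randn(batch, num_variates, dim)                            # variate tokens

x, mem_ps = pack([mem, x], 'b * d')      # concatenate along the token axis, remembering the split
assert x.shape == (batch, num_mem + num_variates, dim)

# ... attention / feedforward layers would operate on the packed sequence here ...

_, x = unpack(x, mem_ps, 'b * d')        # discard the memory tokens, keep the variate tokens
assert x.shape == (batch, num_variates, dim)
```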
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'iTransformer',
packages = find_packages(exclude=[]),
version = '0.1.0',
version = '0.2.0',
license='MIT',
description = 'iTransformer - Inverted Transformers Are Effective for Time Series Forecasting',
author = 'Phil Wang',
