Commit

finish last thought with the repository. will revisit attention for time series at the next sota

lucidrains committed Nov 19, 2023
1 parent c2aa7f0 commit 899c3d9
Showing 4 changed files with 290 additions and 4 deletions.
32 changes: 30 additions & 2 deletions README.md
@@ -112,13 +112,41 @@ preds = model(time_series)
# -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
```

### iTransformer with Normalization Statistics Conditioning

Reversible instance normalization, but the normalization statistics across all variates are concatenated and projected into a conditioning vector, which is then used for FiLM conditioning after each layernorm in the transformer.

```python
import torch
from iTransformer import iTransformerNormConditioned

# using solar energy settings

model = iTransformerNormConditioned(
num_variates = 137,
lookback_len = 96, # or the lookback length in the paper
dim = 256, # model dimensions
depth = 6, # depth
heads = 8, # attention heads
dim_head = 64, # head dimension
pred_length = (12, 24, 36, 48), # can be one prediction, or many
num_tokens_per_variate = 1, # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine
)

time_series = torch.randn(2, 96, 137) # (batch, lookback len, variates)

preds = model(time_series)

# preds -> Dict[int, Tensor[batch, pred_length, variate]]
# -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
```
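
Under the hood, the mechanism is roughly as follows. This is a minimal illustrative sketch (module names like `to_cond` and the standalone shapes are for exposition only, not the exact code of `iTransformerNormConditioned`):

```python
import torch
from torch import nn

batch, lookback_len, num_variates, dim = 2, 96, 137, 256
dim_cond = dim * 4

x = torch.randn(batch, lookback_len, num_variates)

# per-variate statistics over the lookback window (reversible instance normalization)
mean = x.mean(dim = 1, keepdim = True)
var = x.var(dim = 1, unbiased = False, keepdim = True)
normed = (x - mean) / var.clamp(min = 1e-5).sqrt()

# all statistics are concatenated and projected into a single conditioning vector
stats = torch.cat((mean, var), dim = -1).flatten(1)   # (batch, num_variates * 2)
to_cond = nn.Sequential(nn.Linear(num_variates * 2, dim_cond), nn.SiLU())
cond = to_cond(stats)

# FiLM: after each non-affine layernorm in the transformer, the conditioning
# vector supplies a per-feature scale (gamma) and shift (beta)
tokens = torch.randn(batch, num_variates, dim)        # variate tokens inside the transformer
gamma, beta = nn.Linear(dim_cond, dim * 2)(cond).chunk(2, dim = -1)
tokens = nn.LayerNorm(dim, elementwise_affine = False)(tokens)
tokens = tokens * gamma.unsqueeze(1) + beta.unsqueeze(1)
```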

## Todo

- [x] beef up the transformer with latest findings
- [x] improvise a 2d version across both variates and time
- [x] improvise a version that includes fft tokens

- [ ] improvise a variant that uses adaptive normalization conditioned on statistics across all variates
- [x] improvise a variant that uses adaptive normalization conditioned on statistics across all variates

## Citation

4 changes: 3 additions & 1 deletion iTransformer/__init__.py
@@ -5,4 +5,6 @@

from iTransformer.iTransformer2D import iTransformer2D

from iTransformer.iTransformerFFT import iTransformerFFT

from iTransformer.iTransformerNormConditioned import iTransformerNormConditioned
256 changes: 256 additions & 0 deletions iTransformer/iTransformerNormConditioned.py
@@ -0,0 +1,256 @@
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from beartype import beartype
from beartype.typing import Optional, Union, Tuple

from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

from iTransformer.attend import Attend

# helper functions

def exists(v):
return v is not None

def default(v, d):
return v if exists(v) else d

def identity(t, *args, **kwargs):
return t

def cast_tuple(t):
return (t,) if not isinstance(t, tuple) else t

# attention

class Attention(Module):
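    # multi-head attention over the variate (and memory) tokens; the aggregated values are gated elementwise by a SiLU projection of the input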
def __init__(
self,
dim,
dim_head = 32,
heads = 4,
dropout = 0.,
flash = True
):
super().__init__()
self.scale = dim_head ** -0.5
dim_inner = dim_head * heads

self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias = False),
Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
)

self.to_v_gates = nn.Sequential(
nn.Linear(dim, dim_inner, bias = False),
nn.SiLU(),
Rearrange('b n (h d) -> b h n d', h = heads)
)

self.attend = Attend(flash = flash, dropout = dropout)

self.to_out = nn.Sequential(
Rearrange('b h n d -> b n (h d)'),
nn.Linear(dim_inner, dim, bias = False),
nn.Dropout(dropout)
)

def forward(self, x):
q, k, v = self.to_qkv(x)

out = self.attend(q, k, v)

out = out * self.to_v_gates(x)
return self.to_out(out)

# feedforward

class GEGLU(Module):
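    # gated GELU activation - the projection is split in two, and one half gates the other through a GELU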
def forward(self, x):
x, gate = rearrange(x, '... (r d) -> r ... d', r = 2)
return x * F.gelu(gate)

def FeedForward(dim, mult = 4, dropout = 0.):
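    # inner dimension is scaled by 2 / 3 so the gated feedforward keeps roughly the same parameter count as a standard one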
dim_inner = int(dim * mult * 2 / 3)
return nn.Sequential(
nn.Linear(dim, dim_inner * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim)
)

# film conditioning

class FiLM(Module):
    def __init__(self, dim_cond, dim_out):
        super().__init__()
        # project the conditioning vector to a per-feature scale (gamma) and shift (beta)
        self.to_film_gamma_beta = nn.Sequential(
            nn.Linear(dim_cond, dim_out * 2),
            Rearrange('... (r d) -> r ... d', r = 2)
        )

    def forward(self, x, cond):
        gamma, beta = self.to_film_gamma_beta(cond)
        return x * gamma + beta

# main class

class iTransformerNormConditioned(Module):
@beartype
def __init__(
self,
*,
num_variates: int,
lookback_len: int,
depth: int,
dim: int,
num_tokens_per_variate = 1,
pred_length: Union[int, Tuple[int, ...]],
dim_head = 32,
heads = 4,
attn_dropout = 0.,
ff_mult = 4,
ff_dropout = 0.,
num_mem_tokens = 4,
flash_attn = True
):
super().__init__()
self.num_variates = num_variates
self.lookback_len = lookback_len

        # optional memory tokens - learned tokens prepended to the sequence so that all variate tokens can attend to them
        self.mem_tokens = nn.Parameter(torch.randn(num_mem_tokens, dim)) if num_mem_tokens > 0 else None

pred_length = cast_tuple(pred_length)
self.pred_length = pred_length

dim_cond = dim * 4

        # the per-variate means and variances (num_variates * 2 values) are projected into the conditioning vector
        self.to_norm_condition = nn.Sequential(
nn.Linear(num_variates * 2, dim_cond),
nn.SiLU()
)

self.layers = ModuleList([])
for _ in range(depth):
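            # each block: attention -> non-affine layernorm -> FiLM, then feedforward -> non-affine layernorm -> FiLM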
self.layers.append(ModuleList([
Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
nn.LayerNorm(dim, elementwise_affine = False),
FiLM(dim_cond, dim),
FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
nn.LayerNorm(dim, elementwise_affine = False),
FiLM(dim_cond, dim),
]))

self.mlp_in = nn.Sequential(
nn.Linear(lookback_len, dim * num_tokens_per_variate),
Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variate),
nn.LayerNorm(dim)
)

self.pred_heads = ModuleList([])

for one_pred_length in pred_length:
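            # one linear head per requested prediction length, projecting variate tokens back onto the time axis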
head = nn.Sequential(
Rearrange('b (v n) d -> b v (n d)', n = num_tokens_per_variate),
nn.Linear(dim * num_tokens_per_variate, one_pred_length),
Rearrange('b v n -> b n v')
)

self.pred_heads.append(head)

@beartype
def forward(
self,
x: Tensor,
targets: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None,
        eps = 1e-5
    ):
"""
einstein notation
b - batch
n - time
v - variate
s - norm statistics
"""
has_mem = exists(self.mem_tokens)
assert x.shape[1:] == (self.lookback_len, self.num_variates)

# the crux of the paper is basically treating variates as the spatial dimension in attention
# there is a lot of opportunity to improve on this, if the paper is successfully replicated

x = rearrange(x, 'b n v -> b v n')

        # normalize each variate across the lookback window (reversible instance normalization)

mean = x.mean(dim = -1, keepdim = True)
var = x.var(dim = -1, unbiased = False, keepdim = True)

x = (x - mean) * var.clamp(min = eps).rsqrt()

# concat statistics for adaptive layernorm

norm_stats = torch.cat((mean, var), dim = -1)
norm_stats = rearrange(norm_stats, 'b v s -> b 1 (v s)')

cond = self.to_norm_condition(norm_stats)

# mlp to tokens

x = self.mlp_in(x)

# memory tokens

if has_mem:
m = repeat(self.mem_tokens, 'm d -> b m d', b = x.shape[0])
x, mem_ps = pack([m, x], 'b * d')

# attention and feedforward layers

for attn, attn_post_norm, attn_film, ff, ff_post_norm, ff_film in self.layers:
x = attn(x) + x
x = attn_post_norm(x)
x = attn_film(x, cond)

x = ff(x) + x
x = ff_post_norm(x)
x = ff_film(x, cond)

# splice out memory tokens

if has_mem:
_, x = unpack(x, mem_ps, 'b * d')

        # denormalize with the saved per-variate statistics

x = (x * var.sqrt()) + mean

# predicting multiple times

pred_list = [fn(x) for fn in self.pred_heads]

# calculate loss if targets is passed in

if exists(targets):
targets = cast_tuple(targets)
assert len(targets) == len(pred_list)

assert self.training
mse_loss = 0.
for target, pred in zip(targets, pred_list):
                assert target.shape == pred.shape

mse_loss = mse_loss + F.mse_loss(target, pred)

return mse_loss

        if len(pred_list) == 1:
return pred_list[0]

pred_dict = dict(zip(self.pred_length, pred_list))
return pred_dict
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'iTransformer',
packages = find_packages(exclude=[]),
version = '0.4.0',
version = '0.4.1',
license='MIT',
    description = 'iTransformer - Inverted Transformers Are Effective for Time Series Forecasting',
author = 'Phil Wang',
