
Commit

address #15
lucidrains committed Nov 16, 2023
1 parent 2d65cc8 commit 5c4f9ac
Showing 4 changed files with 13 additions and 6 deletions.
3 changes: 2 additions & 1 deletion iTransformer/iTransformer.py
@@ -104,6 +104,7 @@ def __init__(
         ff_dropout = 0.,
         num_mem_tokens = 4,
         use_reversible_instance_norm = False,
+        reversible_instance_norm_affine = False,
         flash_attn = True
     ):
         super().__init__()
@@ -115,7 +116,7 @@ def __init__(
         pred_length = cast_tuple(pred_length)
         self.pred_length = pred_length
 
-        self.reversible_instance_norm = RevIN(num_variates) if use_reversible_instance_norm else None
+        self.reversible_instance_norm = RevIN(num_variates, affine = reversible_instance_norm_affine) if use_reversible_instance_norm else None
 
         self.layers = ModuleList([])
         for _ in range(depth):
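
The new reversible_instance_norm_affine keyword is simply threaded through to RevIN. A minimal usage sketch of the flag on the 1D model; constructor arguments other than the two reversible-instance-norm flags are taken from the repository README and are assumptions here, not part of this commit:

import torch
from iTransformer import iTransformer

model = iTransformer(
    num_variates = 137,
    lookback_len = 96,
    dim = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    pred_length = (12, 24, 36, 48),
    use_reversible_instance_norm = True,       # wrap the variates in RevIN
    reversible_instance_norm_affine = False    # freeze gamma / beta at 1 and 0
)

time_series = torch.randn(2, 96, 137)          # (batch, lookback_len, num_variates)
preds = model(time_series)                     # dict of predictions keyed by horizon
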
3 changes: 2 additions & 1 deletion iTransformer/iTransformer2D.py
@@ -160,6 +160,7 @@ def __init__(
         ff_dropout = 0.,
         num_mem_tokens = 4,
         use_reversible_instance_norm = False,
+        reversible_instance_norm_affine = True,
         flash_attn = True
     ):
         super().__init__()
@@ -175,7 +176,7 @@ def __init__(
         pred_length = cast_tuple(pred_length)
         self.pred_length = pred_length
 
-        self.reversible_instance_norm = RevIN(num_variates) if use_reversible_instance_norm else None
+        self.reversible_instance_norm = RevIN(num_variates, affine = reversible_instance_norm_affine) if use_reversible_instance_norm else None
 
         rotary_emb = RotaryEmbedding(dim_head)
 
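iTransformer2D threads the same flag through to RevIN; note that it defaults to True here, so gamma and beta stay learnable unless explicitly disabled. A minimal sketch; arguments other than the two reversible-instance-norm flags are taken from the repository README and are assumptions here:

from iTransformer import iTransformer2D

model = iTransformer2D(
    num_variates = 137,
    num_time_tokens = 16,
    lookback_len = 96,
    dim = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    pred_length = (12, 24, 36, 48),
    use_reversible_instance_norm = True,
    reversible_instance_norm_affine = False    # override the True default to freeze gamma / beta
)
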
11 changes: 8 additions & 3 deletions iTransformer/revin.py
@@ -18,12 +18,17 @@
 # proposed in https://openreview.net/forum?id=cGDAkQo1C0p
 
 class RevIN(Module):
-    def __init__(self, num_variates, eps = 1e-5):
+    def __init__(
+        self,
+        num_variates,
+        affine = True,
+        eps = 1e-5
+    ):
         super().__init__()
         self.eps = eps
         self.num_variates = num_variates
-        self.gamma = nn.Parameter(torch.ones(num_variates, 1))
-        self.beta = nn.Parameter(torch.zeros(num_variates, 1))
+        self.gamma = nn.Parameter(torch.ones(num_variates, 1), requires_grad = affine)
+        self.beta = nn.Parameter(torch.zeros(num_variates, 1), requires_grad = affine)
 
     def forward(self, x, return_statistics = False):
         assert x.shape[1] == self.num_variates
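
With affine = False, gamma and beta are still created but registered with requires_grad = False, so they stay at their initial values of 1 and 0 and the module reduces to plain per-variate instance normalization plus denormalization. A small sketch of the effect; the (normalized, reverse_fn) return signature of RevIN.forward is an assumption based on how the module is used elsewhere, not something shown in this diff:

import torch
from iTransformer.revin import RevIN

revin = RevIN(num_variates = 4, affine = False)

x = torch.randn(2, 4, 96)                # (batch, num_variates, time steps)
normalized, reverse_fn = revin(x)        # assumed return signature

# gamma / beta receive no gradient updates when affine = False
assert not revin.gamma.requires_grad
assert not revin.beta.requires_grad

reconstructed = reverse_fn(normalized)   # should approximately recover the input
print(torch.allclose(reconstructed, x, atol = 1e-3))
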
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'iTransformer',
   packages = find_packages(exclude=[]),
-  version = '0.3.4',
+  version = '0.3.5',
   license='MIT',
   description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
   author = 'Phil Wang',
