-
Notifications
You must be signed in to change notification settings - Fork 32
/
dain.py
94 lines (72 loc) · 3.13 KB
/
dain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
class DAIN_Layer(nn.Module):
    """Input-normalization layer with optional learned shift, scale and gate.

    Normalization statistics are taken over axis 2 of a 3-D input laid out as
    ``(n_samples, dim, n_feature_vectors)``; ``dim`` must equal ``input_dim``.

    Modes:
        None             -- identity; the input is returned unchanged.
        'avg'            -- subtract the mean over axis 2.
        'adaptive_avg'   -- subtract a learned linear map of that mean.
        'adaptive_scale' -- 'adaptive_avg', then divide by a learned linear
                            map of the root-mean-square over axis 2.
        'full'           -- 'adaptive_scale', then multiply by a sigmoid gate
                            computed from the mean over axis 2.

    Args:
        mode: one of the modes above.
        mean_lr, gate_lr, scale_lr: stored as attributes only; nothing in this
            class reads them. Presumably consumed by training code to build
            per-sub-layer learning rates -- confirm against the caller.
        input_dim: size of axis 1 of the input.
    """

    def __init__(self, mode='adaptive_avg', mean_lr=0.00001, gate_lr=0.001, scale_lr=0.00001, input_dim=144):
        super().__init__()
        print("Mode = ", mode)
        self.mode = mode
        self.mean_lr = mean_lr
        self.gate_lr = gate_lr
        self.scale_lr = scale_lr

        # Step 1 (adaptive shift): bias-free linear map, identity-initialised so
        # that before training it reproduces plain mean subtraction.
        self.mean_layer = nn.Linear(input_dim, input_dim, bias=False)
        nn.init.eye_(self.mean_layer.weight)

        # Step 2 (adaptive scale): bias-free linear map, also identity-initialised.
        self.scaling_layer = nn.Linear(input_dim, input_dim, bias=False)
        nn.init.eye_(self.scaling_layer.weight)

        # Step 3 (adaptive gating): linear map with bias, default PyTorch init.
        self.gating_layer = nn.Linear(input_dim, input_dim)

        # Added under the square root in step 2, and also the threshold at or
        # below which a learned scale is considered unusable as a divisor.
        self.eps = 1e-8

    def forward(self, x):
        """Normalize ``x`` according to ``self.mode``.

        Args:
            x: tensor of shape ``(n_samples, dim, n_feature_vectors)`` with
                ``dim == input_dim``.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
        """
        if self.mode is None:
            return x
        if self.mode == 'avg':
            return x - torch.mean(x, 2).unsqueeze(2)
        if self.mode == 'adaptive_avg':
            return self._adaptive_shift(x)
        if self.mode == 'adaptive_scale':
            return self._adaptive_scale(self._adaptive_shift(x))
        if self.mode == 'full':
            return self._adaptive_gate(self._adaptive_scale(self._adaptive_shift(x)))
        # Explicit raise (not `assert`) so an unknown mode still fails under `python -O`.
        raise ValueError(f"Unknown DAIN_Layer mode: {self.mode!r}")

    def _adaptive_shift(self, x):
        """Step 1: subtract a learned linear map of the mean over axis 2."""
        avg = torch.mean(x, 2)
        adaptive_avg = self.mean_layer(avg)
        return x - adaptive_avg.unsqueeze(2)

    def _adaptive_scale(self, x):
        """Step 2: divide by a learned linear map of the RMS over axis 2.

        Learned scales that are <= ``eps`` (including negative ones) are
        replaced by 1 before dividing.
        """
        std = torch.sqrt(torch.mean(x ** 2, 2) + self.eps)
        adaptive_std = self.scaling_layer(std)
        # Out-of-place so the Linear output is not mutated; replaced entries
        # receive zero gradient.
        adaptive_std = adaptive_std.masked_fill(adaptive_std <= self.eps, 1.0)
        return x / adaptive_std.unsqueeze(2)

    def _adaptive_gate(self, x):
        """Step 3: multiply by a sigmoid gate computed from the mean over axis 2."""
        avg = torch.mean(x, 2)
        gate = torch.sigmoid(self.gating_layer(avg))
        return x * gate.unsqueeze(2)