-
Notifications
You must be signed in to change notification settings - Fork 14
/
hnn.py
106 lines (84 loc) · 3.7 KB
/
hnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from torch import nn
class HNN(nn.Module):
    '''Learn arbitrary vector fields that are sums of conservative and solenoidal fields.

    Wraps `differentiable_model`, whose two-column output [F1, F2] is read as a
    pair of scalar potentials: the gradient of F1 gives the conservative
    component and the (symplectically rotated) gradient of F2 gives the
    solenoidal component of the learned vector field.

    Args:
        input_dim: dimensionality of the state vector x (must be even when
            assume_canonical_coords is True, since it is split into halves).
        differentiable_model: torch module mapping [batch, input_dim] ->
            [batch, 2].
        field_type: 'solenoidal', 'conservative', or anything else (e.g.
            'both') to keep both components.
        baseline: if True, bypass the Hamiltonian machinery and use the
            wrapped model's output directly as the vector field.
        assume_canonical_coords: if True, use the canonical symplectic matrix
            for M; otherwise build a generic Levi-Civita-style sign matrix.
    '''
    def __init__(self, input_dim, differentiable_model, field_type='solenoidal',
                 baseline=False, assume_canonical_coords=True):
        super().__init__()
        self.baseline = baseline
        self.differentiable_model = differentiable_model
        self.assume_canonical_coords = assume_canonical_coords
        self.M = self.permutation_tensor(input_dim)  # Levi-Civita permutation tensor
        self.field_type = field_type

    def forward(self, x):
        '''Return the two scalar potentials (F1, F2), each [batch, 1].

        In baseline mode the wrapped model's raw output is returned instead.
        '''
        # traditional forward pass
        if self.baseline:
            return self.differentiable_model(x)

        y = self.differentiable_model(x)
        assert y.dim() == 2 and y.shape[1] == 2, "Output tensor should have shape [batch_size, 2]"
        return y.split(1, 1)  # -> (F1, F2)

    def time_derivative(self, x, t=None, separate_fields=False):
        '''NEURAL ODE-STYLE VECTOR FIELD

        Args:
            x: input states, [batch, input_dim]; must have requires_grad=True
                so the potentials can be differentiated w.r.t. it.
            t: unused; present for ODE-solver call signatures.
            separate_fields: if True, return [conservative, solenoidal]
                instead of their sum.
        '''
        if self.baseline:
            return self.differentiable_model(x)

        '''NEURAL HAMILTONIAN-STYLE VECTOR FIELD'''
        F1, F2 = self.forward(x)  # traditional forward pass

        conservative_field = torch.zeros_like(x)  # start out with both components set to 0
        solenoidal_field = torch.zeros_like(x)

        if self.field_type != 'solenoidal':
            dF1 = torch.autograd.grad(F1.sum(), x, create_graph=True)[0]  # gradients for conservative field
            # The gradient of a scalar potential IS the conservative field.
            # (The original multiplied by an identity matrix here — a no-op
            # that allocated an eye matrix and ran a matmul every call.)
            conservative_field = dF1

        if self.field_type != 'conservative':
            dF2 = torch.autograd.grad(F2.sum(), x, create_graph=True)[0]  # gradients for solenoidal field
            solenoidal_field = dF2 @ self.M.t()  # rotate gradient with the symplectic matrix

        if separate_fields:
            return [conservative_field, solenoidal_field]

        return conservative_field + solenoidal_field

    def permutation_tensor(self, n):
        '''Build the n-by-n matrix M used to rotate gradients into the solenoidal field.'''
        M = None
        if self.assume_canonical_coords:
            # canonical symplectic form: [[0, I], [-I, 0]] row blocks
            M = torch.eye(n)
            M = torch.cat([M[n // 2:], -M[:n // 2]])
        else:
            '''Constructs the Levi-Civita permutation tensor'''
            M = torch.ones(n, n)   # matrix of ones
            M *= 1 - torch.eye(n)  # clear diagonals
            M[::2] *= -1           # pattern of signs
            M[:, ::2] *= -1
            for i in range(n):     # make asymmetric
                for j in range(i + 1, n):
                    M[i, j] *= -1
        return M
class MLP(nn.Module):
    '''Just a salt-of-the-earth MLP: three linear layers with a chosen nonlinearity.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of the two hidden layers.
        output_dim: size of the output (no bias on the final layer).
        nonlinearity: activation name understood by choose_nonlinearity.
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, nonlinearity='tanh'):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        # bias=False (not None): nn.Linear documents a boolean flag here.
        self.linear3 = torch.nn.Linear(hidden_dim, output_dim, bias=False)

        for l in [self.linear1, self.linear2, self.linear3]:
            torch.nn.init.orthogonal_(l.weight)  # use a principled initialization

        self.nonlinearity = choose_nonlinearity(nonlinearity)

    def forward(self, x, separate_fields=False):
        # `separate_fields` is accepted but unused — kept so the signature
        # matches HNN.time_derivative when this model is used as a baseline.
        h = self.nonlinearity(self.linear1(x))
        h = self.nonlinearity(self.linear2(h))
        return self.linear3(h)
def choose_nonlinearity(name):
    '''Map an activation name to the corresponding torch callable.

    Args:
        name: one of 'tanh', 'relu', 'sigmoid', 'softplus', 'selu', 'elu',
            or 'swish'.

    Returns:
        A callable tensor -> tensor activation function.

    Raises:
        ValueError: if `name` is not a recognized activation.
    '''
    # Dict dispatch replaces the original if/elif chain — same names, same
    # callables, same error message.
    nonlinearities = {
        'tanh': torch.tanh,
        'relu': torch.relu,
        'sigmoid': torch.sigmoid,
        'softplus': torch.nn.functional.softplus,
        'selu': torch.nn.functional.selu,
        'elu': torch.nn.functional.elu,
        'swish': lambda x: x * torch.sigmoid(x),  # a.k.a. SiLU
    }
    try:
        return nonlinearities[name]
    except KeyError:
        raise ValueError("nonlinearity not recognized") from None