-
Notifications
You must be signed in to change notification settings - Fork 3
/
models.py
295 lines (260 loc) · 10.4 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import functools
import math
from typing import Callable, Optional, Union
import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
from ..tools import safe_norm
from .blocks import (
EquivariantProductBasisBlock,
InteractionBlock,
LinearNodeEmbeddingBlock,
LinearReadoutBlock,
NonLinearReadoutBlock,
RadialEmbeddingBlock,
)
try:
    from profile_nn_jax import profile
except ImportError:  # profiling support is optional

    def profile(_, x):
        """Stand-in used when ``profile_nn_jax`` is not installed.

        The first argument (a label) is accepted only so call sites keep the
        same shape, and is ignored; ``x`` is handed back unchanged.
        """
        del _  # label is intentionally unused
        return x
class MACE(hk.Module):
    """MACE network: node/edge embeddings followed by a stack of MACE layers.

    ``__call__`` returns the per-node readout of every layer, stacked along
    axis 1.
    """

    def __init__(
        self,
        *,
        output_irreps: e3nn.Irreps,  # Irreps of each layer's readout, e.g. 1x0e
        r_max: float,
        num_interactions: int,  # Number of interactions (layers), e.g. 2
        hidden_irreps: e3nn.Irreps,  # e.g. 256x0e or 128x0e + 128x1o
        readout_mlp_irreps: e3nn.Irreps,  # Hidden irreps of the last readout MLP, e.g. 16x0e
        avg_num_neighbors: float,
        num_species: int,
        num_features: Optional[int] = None,  # None: gcd of hidden_irreps multiplicities
        avg_r_min: Optional[float] = None,
        radial_basis: Callable[[jnp.ndarray], jnp.ndarray],
        radial_envelope: Callable[[jnp.ndarray], jnp.ndarray],
        max_ell: int = 3,  # Max spherical harmonic degree
        epsilon: Optional[float] = None,
        correlation: int = 3,  # Correlation order at each layer (~ node_features^correlation)
        gate: Callable = jax.nn.silu,  # activation function
        symmetric_tensor_product_basis: bool = True,
        off_diagonal: bool = False,
        interaction_irreps: Union[str, e3nn.Irreps] = "o3_restricted",  # or "o3_full"
        node_embedding: Callable[..., hk.Module] = LinearNodeEmbeddingBlock,
    ):
        """Store hyperparameters and build the node and radial embeddings.

        Args:
            output_irreps: irreps of the per-node readout of every layer.
            r_max: cutoff radius, forwarded to ``RadialEmbeddingBlock``.
            num_interactions: number of ``MACELayer`` s applied in ``__call__``.
            hidden_irreps: irreps of the node features passed between layers.
                How it is interpreted depends on ``num_features``.
            readout_mlp_irreps: hidden irreps of the MLP used by the last
                layer's non-linear readout.
            avg_num_neighbors: forwarded to each layer's interaction block; each
                layer also divides its interaction output by
                ``sqrt(avg_num_neighbors)`` when ``epsilon`` is None.
            num_species: number of species; ``node_specie`` values are expected
                in ``[0, num_species)``.
            num_features: if None, it is set to the gcd of the multiplicities of
                ``hidden_irreps``, and those multiplicities are divided by it.
                If given, ``hidden_irreps`` is used as-is as the *per-feature*
                irreps, so node features have irreps
                ``num_features * hidden_irreps``.
            avg_r_min: forwarded to ``RadialEmbeddingBlock``.
            radial_basis: basis functions, forwarded to ``RadialEmbeddingBlock``.
            radial_envelope: envelope function, forwarded to
                ``RadialEmbeddingBlock``.
            max_ell: maximum spherical-harmonic degree, used for the edge
                attributes and for the ``interaction_irreps`` presets.
            epsilon: if given, each layer multiplies its interaction output by
                this value instead of dividing by ``sqrt(avg_num_neighbors)``.
            correlation: correlation order of each product-basis block.
            gate: activation passed to the interaction and readout blocks.
            symmetric_tensor_product_basis: forwarded to the product-basis block.
            off_diagonal: forwarded to the product-basis block.
            interaction_irreps: ``"o3_restricted"`` (spherical-harmonic irreps up
                to ``max_ell``), ``"o3_full"`` (every irrep yielded by
                ``e3nn.Irrep.iterator(max_ell)``), or anything accepted by
                ``e3nn.Irreps``.
            node_embedding: constructor called as
                ``node_embedding(num_species, num_features * hidden_irreps)``.
        """
        super().__init__()

        output_irreps = e3nn.Irreps(output_irreps)
        hidden_irreps = e3nn.Irreps(hidden_irreps)
        readout_mlp_irreps = e3nn.Irreps(readout_mlp_irreps)

        if num_features is None:
            # Factor hidden_irreps as num_features * (per-feature irreps), using
            # the gcd of all multiplicities as the shared feature count.
            self.num_features = functools.reduce(
                math.gcd, (mul for mul, _ in hidden_irreps)
            )
            self.hidden_irreps = e3nn.Irreps(
                [(mul // self.num_features, ir) for mul, ir in hidden_irreps]
            )
        else:
            # hidden_irreps is already expressed per feature.
            self.num_features = num_features
            self.hidden_irreps = hidden_irreps

        # Edge spherical harmonics without 0e; the radial embedding is
        # concatenated alongside them in __call__.
        self.sh_irreps = e3nn.Irreps.spherical_harmonics(max_ell)[1:]

        # Only strings are compared against the preset names, so an Irreps
        # argument goes straight to the e3nn.Irreps conversion.
        preset = interaction_irreps if isinstance(interaction_irreps, str) else None
        if preset == "o3_restricted":
            self.interaction_irreps = e3nn.Irreps.spherical_harmonics(max_ell)
        elif preset == "o3_full":
            self.interaction_irreps = e3nn.Irreps(e3nn.Irrep.iterator(max_ell))
        else:
            self.interaction_irreps = e3nn.Irreps(interaction_irreps)

        self.r_max = r_max
        self.correlation = correlation
        self.avg_num_neighbors = avg_num_neighbors
        self.epsilon = epsilon
        self.readout_mlp_irreps = readout_mlp_irreps
        self.activation = gate
        self.num_interactions = num_interactions
        self.output_irreps = output_irreps
        self.num_species = num_species
        self.symmetric_tensor_product_basis = symmetric_tensor_product_basis
        self.off_diagonal = off_diagonal

        # Embeddings (created here, in this order, so Haiku names stay stable)
        self.node_embedding = node_embedding(
            self.num_species, self.num_features * self.hidden_irreps
        )
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            avg_r_min=avg_r_min,
            basis_functions=radial_basis,
            envelope_function=radial_envelope,
        )

    def __call__(
        self,
        vectors: jnp.ndarray,  # [n_edges, 3]
        node_specie: jnp.ndarray,  # [n_nodes] int between 0 and num_species-1
        senders: jnp.ndarray,  # [n_edges]
        receivers: jnp.ndarray,  # [n_edges]
    ) -> e3nn.IrrepsArray:
        """Run the embeddings and all interaction layers.

        Args:
            vectors: one 3-vector per edge, shape ``[n_edges, 3]``.
            node_specie: integer species index per node, shape ``[n_nodes]``.
            senders: sender node index per edge, shape ``[n_edges]``.
            receivers: receiver node index per edge, shape ``[n_edges]``.

        Returns:
            The readouts of all layers stacked on axis 1:
            ``[n_nodes, num_interactions, output_irreps]``.

        Raises:
            AssertionError: if the input ranks/shapes are inconsistent.
        """
        assert vectors.ndim == 2 and vectors.shape[1] == 3
        assert node_specie.ndim == 1
        assert senders.ndim == 1 and receivers.ndim == 1
        assert vectors.shape[0] == senders.shape[0] == receivers.shape[0]

        # Embeddings
        node_feats = self.node_embedding(node_specie).astype(
            vectors.dtype
        )  # [n_nodes, feature * irreps]
        node_feats = profile("embedding: node_feats", node_feats)

        lengths = safe_norm(vectors, axis=-1)
        edge_attrs = e3nn.concatenate(
            [
                self.radial_embedding(lengths),
                e3nn.spherical_harmonics(
                    self.sh_irreps,
                    # Vectors are normalized here, hence normalize=False below.
                    # NOTE(review): a zero-length edge vector gives 0/0 unless
                    # safe_norm avoids returning 0 -- confirm for padded edges.
                    vectors / lengths[..., None],
                    normalize=False,
                    normalization="component",
                ),
            ]
        )  # [n_edges, irreps]
        edge_attrs = profile("embedding: edge_attrs", edge_attrs)

        # Interactions
        outputs = []
        for i in range(self.num_interactions):
            first = i == 0
            last = i == self.num_interactions - 1
            # The last layer's features only feed its readout, so keep just the
            # irreps that also appear in output_irreps.
            hidden_irreps = (
                self.hidden_irreps.filter(self.output_irreps)
                if last
                else self.hidden_irreps
            )
            node_outputs, node_feats = MACELayer(
                first=first,
                last=last,
                num_features=self.num_features,
                interaction_irreps=self.interaction_irreps,
                hidden_irreps=hidden_irreps,
                avg_num_neighbors=self.avg_num_neighbors,
                activation=self.activation,
                num_species=self.num_species,
                epsilon=self.epsilon,
                correlation=self.correlation,
                output_irreps=self.output_irreps,
                readout_mlp_irreps=self.readout_mlp_irreps,
                symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
                off_diagonal=self.off_diagonal,
                name=f"layer_{i}",
            )(
                node_feats,
                node_specie,
                edge_attrs,
                senders,
                receivers,
            )
            outputs.append(node_outputs)  # each is [n_nodes, output_irreps]

        return e3nn.stack(outputs, axis=1)  # [n_nodes, num_interactions, output_irreps]
class MACELayer(hk.Module):
    """A single MACE layer.

    The layer applies, in order:
      1. a species-indexed linear skip connection (all layers but the first),
      2. an interaction block, then a normalization,
      3. a species-indexed linear map (first layer only),
      4. an equivariant product-basis block,
      5. the skip connection is added back, then
      6. a linear readout (or a non-linear readout for the last layer).

    ``__call__`` returns ``(node_outputs, node_feats)``.
    """

    def __init__(
        self,
        *,
        first: bool,
        last: bool,
        num_features: int,
        interaction_irreps: e3nn.Irreps,
        hidden_irreps: e3nn.Irreps,
        activation: Callable,
        num_species: int,
        epsilon: Optional[float],
        name: Optional[str],
        # InteractionBlock:
        avg_num_neighbors: float,
        # EquivariantProductBasisBlock:
        correlation: int,
        symmetric_tensor_product_basis: bool,
        off_diagonal: bool,
        # ReadoutBlock:
        output_irreps: e3nn.Irreps,
        readout_mlp_irreps: e3nn.Irreps,
    ) -> None:
        """Store the layer configuration.

        Args:
            first: True for the first layer. It has no skip connection and
                applies an extra species-indexed linear map after the
                interaction.
            last: True for the last layer. It uses a non-linear readout.
            num_features: feature multiplicity. Node features have irreps
                ``num_features * hidden_irreps``.
            interaction_irreps: per-feature irreps produced by the interaction.
            hidden_irreps: per-feature irreps of the layer's output features.
            activation: passed to the interaction and non-linear readout blocks.
            num_species: number of species, used for species-indexed weights.
            epsilon: if given, the interaction output is multiplied by it.
                Otherwise it is divided by ``sqrt(avg_num_neighbors)``.
            name: Haiku module name.
            avg_num_neighbors: forwarded to ``InteractionBlock``.
            correlation: correlation order of the product basis.
            symmetric_tensor_product_basis: forwarded to the product basis.
            off_diagonal: forwarded to the product basis.
            output_irreps: irreps of the per-node readout.
            readout_mlp_irreps: hidden irreps of the non-linear readout MLP.
        """
        super().__init__(name=name)
        self.first = first
        self.last = last
        self.num_features = num_features
        self.interaction_irreps = interaction_irreps
        self.hidden_irreps = hidden_irreps
        self.avg_num_neighbors = avg_num_neighbors
        self.activation = activation
        self.num_species = num_species
        self.epsilon = epsilon
        self.correlation = correlation
        self.output_irreps = output_irreps
        self.readout_mlp_irreps = readout_mlp_irreps
        self.symmetric_tensor_product_basis = symmetric_tensor_product_basis
        self.off_diagonal = off_diagonal

    def __call__(
        self,
        node_feats: e3nn.IrrepsArray,  # [n_nodes, irreps]
        node_specie: jnp.ndarray,  # [n_nodes] int between 0 and num_species-1
        edge_attrs: e3nn.IrrepsArray,  # [n_edges, irreps]
        senders: jnp.ndarray,  # [n_edges]
        receivers: jnp.ndarray,  # [n_edges]
    ):
        """Apply the layer.

        Returns:
            ``(node_outputs, node_feats)``: ``node_outputs`` is the per-node
            readout (``[n_nodes, output_irreps]``). ``node_feats`` holds the
            updated features (``[n_nodes, num_features * hidden_irreps]``).
        """
        # Submodules are created in a fixed order; Haiku parameter names
        # depend on it.
        node_feats = profile(f"{self.name}: node_feats", node_feats)

        # Skip connection (species-indexed linear); the first layer has none.
        sc = None
        if not self.first:
            sc = e3nn.haiku.Linear(
                self.num_features * self.hidden_irreps,
                num_indexed_weights=self.num_species,
                name="skip_tp",
            )(
                node_specie, node_feats
            )  # [n_nodes, feature * hidden_irreps]
            sc = profile(f"{self.name}: self-connexion", sc)

        node_feats = InteractionBlock(
            target_irreps=self.num_features * self.interaction_irreps,
            avg_num_neighbors=self.avg_num_neighbors,
            activation=self.activation,
        )(
            node_feats=node_feats,
            edge_attrs=edge_attrs,
            receivers=receivers,
            senders=senders,
        )

        # Normalize the interaction output: explicit epsilon if configured,
        # otherwise 1/sqrt(avg_num_neighbors).
        if self.epsilon is not None:
            node_feats *= self.epsilon
        else:
            node_feats /= jnp.sqrt(self.avg_num_neighbors)

        node_feats = profile(f"{self.name}: node_feats after interaction", node_feats)

        if self.first:
            # Selector TensorProduct: species-indexed linear map. It is applied
            # only in the first layer, which has no skip connection (sc is
            # still None here).
            node_feats = e3nn.haiku.Linear(
                self.num_features * self.interaction_irreps,
                num_indexed_weights=self.num_species,
                name="skip_tp_first",
            )(node_specie, node_feats)
            node_feats = profile(
                f"{self.name}: node_feats after skip_tp_first", node_feats
            )

        node_feats = EquivariantProductBasisBlock(
            target_irreps=self.num_features * self.hidden_irreps,
            correlation=self.correlation,
            num_species=self.num_species,
            symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
            off_diagonal=self.off_diagonal,
        )(node_feats=node_feats, node_specie=node_specie)
        node_feats = profile(f"{self.name}: node_feats after tensor power", node_feats)

        if sc is not None:
            node_feats = node_feats + sc  # [n_nodes, feature * hidden_irreps]

        if not self.last:
            node_outputs = LinearReadoutBlock(self.output_irreps)(
                node_feats
            )  # [n_nodes, output_irreps]
        else:  # Non linear readout for last layer
            node_outputs = NonLinearReadoutBlock(
                self.readout_mlp_irreps,
                self.output_irreps,
                activation=self.activation,
            )(
                node_feats
            )  # [n_nodes, output_irreps]

        node_outputs = profile(f"{self.name}: node_outputs", node_outputs)
        return node_outputs, node_feats