Adding PAINNStack (#273)
RylieWeaver authored Oct 10, 2024
1 parent abd07fb commit ab8891f
Showing 5 changed files with 337 additions and 3 deletions.
311 changes: 311 additions & 0 deletions hydragnn/models/PAINNStack.py
@@ -0,0 +1,311 @@
##############################################################################
# Copyright (c) 2024, Oak Ridge National Laboratory #
# All rights reserved. #
# #
# This file is part of HydraGNN and is distributed under a BSD 3-clause #
# license. For the licensing terms see the LICENSE file in the top-level #
# directory. #
# #
# SPDX-License-Identifier: BSD-3-Clause #
##############################################################################

# Adapted from the following:
# GitHub: https://github.com/nityasagarjena/PaiNN-model/blob/main/PaiNN/model.py
# Paper: https://arxiv.org/pdf/2102.03150


import torch
from torch import nn
from torch_geometric import nn as geom_nn
from torch.utils.checkpoint import checkpoint

from .Base import Base


class PAINNStack(Base):
"""
Generates angles, distances, to/from indices, radial basis
functions and spherical basis functions for learning.
"""

def __init__(
self,
# edge_dim: int, # To-Do: Add edge_features
num_radial: int,
radius: float,
*args,
**kwargs
):
# self.edge_dim = edge_dim
self.num_radial = num_radial
self.radius = radius

super().__init__(*args, **kwargs)

def _init_conv(self):
        last_layer = self.num_conv_layers == 1
        self.graph_convs.append(
            self.get_conv(self.input_dim, self.hidden_dim, last_layer)
        )
self.feature_layers.append(nn.Identity())
for i in range(self.num_conv_layers - 1):
last_layer = i == self.num_conv_layers - 2
conv = self.get_conv(self.hidden_dim, self.hidden_dim, last_layer)
self.graph_convs.append(conv)
self.feature_layers.append(nn.Identity())

def get_conv(self, input_dim, output_dim, last_layer=False):
hidden_dim = output_dim if input_dim == 1 else input_dim
        assert (
            hidden_dim > 1
        ), "PAINNStack requires hidden_dim > 1 between input_dim and output_dim."
self_inter = PainnMessage(
node_size=input_dim, edge_size=self.num_radial, cutoff=self.radius
)
cross_inter = PainnUpdate(node_size=input_dim, last_layer=last_layer)
"""
The following linear layers are to get the correct sizing of embeddings. This is
necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly
because node_scalar and node-vector are updated through a sum.
"""
node_embed_out = nn.Sequential(
nn.Linear(input_dim, output_dim),
nn.Tanh(),
nn.Linear(output_dim, output_dim),
) # Tanh activation is necessary to prevent exploding gradients when learning from random signals in test_graphs.py
vec_embed_out = nn.Linear(input_dim, output_dim) if not last_layer else None

if not last_layer:
return geom_nn.Sequential(
"x, v, pos, edge_index, diff, dist",
[
(self_inter, "x, v, edge_index, diff, dist -> x, v"),
(cross_inter, "x, v -> x, v"),
(node_embed_out, "x -> x"),
(vec_embed_out, "v -> v"),
(lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"),
],
)
else:
return geom_nn.Sequential(
"x, v, pos, edge_index, diff, dist",
[
(self_inter, "x, v, edge_index, diff, dist -> x, v"),
(
cross_inter,
"x, v -> x",
), # v is not updated in the last layer to avoid hanging gradients
(
node_embed_out,
"x -> x",
), # No need to embed down v because it's not used anymore
(lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"),
],
)

def forward(self, data):
        data, conv_args = self._conv_args(
            data
        )  # _conv_args attaches the vector features v to data (needed by PAINNStack)
x = data.x
v = data.v
pos = data.pos

        #### encoder part ####
for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
if not self.conv_checkpointing:
c, v, pos = conv(x=x, v=v, pos=pos, **conv_args) # Added v here
else:
c, v, pos = checkpoint( # Added v here
conv, use_reentrant=False, x=x, v=v, pos=pos, **conv_args
)
x = self.activation_function(feat_layer(c))

        #### multi-head decoder part ####
# shared dense layers for graph level output
if data.batch is None:
x_graph = x.mean(dim=0, keepdim=True)
else:
x_graph = geom_nn.global_mean_pool(x, data.batch.to(x.device))
outputs = []
outputs_var = []
for head_dim, headloc, type_head in zip(
self.head_dims, self.heads_NN, self.head_type
):
if type_head == "graph":
x_graph_head = self.graph_shared(x_graph)
output_head = headloc(x_graph_head)
outputs.append(output_head[:, :head_dim])
outputs_var.append(output_head[:, head_dim:] ** 2)
else:
if self.node_NN_type == "conv":
for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
c, v, pos = conv(x=x, v=v, pos=pos, **conv_args)
c = batch_norm(c)
x = self.activation_function(c)
x_node = x
else:
x_node = headloc(x=x, batch=data.batch)
outputs.append(x_node[:, :head_dim])
outputs_var.append(x_node[:, head_dim:] ** 2)
if self.var_output:
return outputs, outputs_var
return outputs

def _conv_args(self, data):
        assert (
            data.pos is not None
        ), "PAINNStack requires node positions (data.pos) to be set."

        # Calculate relative vectors and distances
        i, j = data.edge_index[0], data.edge_index[1]
        diff = data.pos[i] - data.pos[j]
        dist = diff.pow(2).sum(dim=-1).sqrt()

        # Instantiate tensor to hold the equivariant vector features
        v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device)
        data.v = v

        conv_args = {
            "edge_index": data.edge_index.t().to(torch.long),
            # Pass raw displacements; PainnMessage normalizes by dist internally,
            # so pre-normalizing here would scale the edge vectors by 1/dist twice.
            "diff": diff,
            "dist": dist,
        }

return data, conv_args


class PainnMessage(nn.Module):
"""Message function"""

def __init__(self, node_size: int, edge_size: int, cutoff: float):
super().__init__()

self.node_size = node_size
self.edge_size = edge_size
self.cutoff = cutoff

self.scalar_message_mlp = nn.Sequential(
nn.Linear(node_size, node_size),
nn.SiLU(),
nn.Linear(node_size, node_size * 3),
)

self.filter_layer = nn.Linear(edge_size, node_size * 3)

def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):
        # Build messages from source-node features (v_j, s_j), not target features (v_i, s_i)
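        # Shape conventions: node_scalar (N, F); node_vector (N, 3, F);
        # edge (E, 2) as (target i, source j); edge_diff (E, 3); edge_dist (E,)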
filter_weight = self.filter_layer(
sinc_expansion(edge_dist, self.edge_size, self.cutoff)
)
filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff).unsqueeze(
-1
)
scalar_out = self.scalar_message_mlp(node_scalar)
filter_out = filter_weight * scalar_out[edge[:, 1]]

gate_state_vector, gate_edge_vector, message_scalar = torch.split(
filter_out,
self.node_size,
dim=1,
)

# num_pairs * 3 * node_size, num_pairs * node_size
message_vector = node_vector[edge[:, 1]] * gate_state_vector.unsqueeze(1)
edge_vector = gate_edge_vector.unsqueeze(1) * (
edge_diff / edge_dist.unsqueeze(-1)
).unsqueeze(-1)
message_vector = message_vector + edge_vector

# sum message
residual_scalar = torch.zeros_like(node_scalar)
residual_vector = torch.zeros_like(node_vector)
residual_scalar.index_add_(0, edge[:, 0], message_scalar)
residual_vector.index_add_(0, edge[:, 0], message_vector)

# new node state
new_node_scalar = node_scalar + residual_scalar
new_node_vector = node_vector + residual_vector

return new_node_scalar, new_node_vector


class PainnUpdate(nn.Module):
"""Update function"""

def __init__(self, node_size: int, last_layer=False):
super().__init__()

self.update_U = nn.Linear(node_size, node_size)
self.update_V = nn.Linear(node_size, node_size)
self.last_layer = last_layer

if not self.last_layer:
self.update_mlp = nn.Sequential(
nn.Linear(node_size * 2, node_size),
nn.SiLU(),
nn.Linear(node_size, node_size * 3),
)
else:
self.update_mlp = nn.Sequential(
nn.Linear(node_size * 2, node_size),
nn.SiLU(),
nn.Linear(node_size, node_size * 2),
)

def forward(self, node_scalar, node_vector):
Uv = self.update_U(node_vector)
Vv = self.update_V(node_vector)

Vv_norm = torch.linalg.norm(Vv, dim=1)
mlp_input = torch.cat((Vv_norm, node_scalar), dim=1)
mlp_output = self.update_mlp(mlp_input)

if not self.last_layer:
a_vv, a_sv, a_ss = torch.split(
mlp_output,
node_vector.shape[-1],
dim=1,
)

delta_v = a_vv.unsqueeze(1) * Uv
inner_prod = torch.sum(Uv * Vv, dim=1)
delta_s = a_sv * inner_prod + a_ss

return node_scalar + delta_s, node_vector + delta_v
else:
a_sv, a_ss = torch.split(
mlp_output,
node_vector.shape[-1],
dim=1,
)

inner_prod = torch.sum(Uv * Vv, dim=1)
delta_s = a_sv * inner_prod + a_ss

return node_scalar + delta_s


def sinc_expansion(edge_dist: torch.Tensor, edge_size: int, cutoff: float):
"""
Calculate sinc radial basis function:
sin(n * pi * d / d_cut) / d
"""
n = torch.arange(edge_size, device=edge_dist.device) + 1
return torch.sin(
edge_dist.unsqueeze(-1) * n * torch.pi / cutoff
) / edge_dist.unsqueeze(-1)


def cosine_cutoff(edge_dist: torch.Tensor, cutoff: float):
"""
Calculate cutoff value based on distance.
    This uses the cosine Behler-Parrinello cutoff function:
    f(d) = 0.5 * (cos(pi * d / d_cut) + 1) for d < d_cut, and 0 otherwise
"""
return torch.where(
edge_dist < cutoff,
0.5 * (torch.cos(torch.pi * edge_dist / cutoff) + 1),
torch.tensor(0.0, device=edge_dist.device, dtype=edge_dist.dtype),
)
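
To make the data flow concrete, here is a minimal sketch (not part of the commit) that runs one message/update pass on a toy three-node graph. It assumes PainnMessage, PainnUpdate, sinc_expansion, and cosine_cutoff from the file above are in scope; all sizes and inputs are illustrative, not values HydraGNN uses.

import torch

node_size, edge_size, cutoff = 8, 6, 5.0
msg = PainnMessage(node_size=node_size, edge_size=edge_size, cutoff=cutoff)
upd = PainnUpdate(node_size=node_size)

pos = torch.randn(3, 3)                                # node positions (N, 3)
edge = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]])  # (E, 2) as (target i, source j)
x = torch.randn(3, node_size)                          # scalar features (N, F)
v = torch.zeros(3, 3, node_size)                       # vector features (N, 3, F) start at zero

diff = pos[edge[:, 0]] - pos[edge[:, 1]]               # raw displacements (E, 3)
dist = diff.norm(dim=-1)                               # distances (E,)

# Radial basis with smooth cutoff (PainnMessage applies the cutoff after its linear filter layer)
rbf = sinc_expansion(dist, edge_size, cutoff) * cosine_cutoff(dist, cutoff).unsqueeze(-1)
assert rbf.shape == (4, edge_size)

x, v = msg(x, v, edge, diff, dist)  # one message pass
x, v = upd(x, v)                    # one update pass
print(x.shape, v.shape)             # torch.Size([3, 8]) torch.Size([3, 3, 8])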
20 changes: 20 additions & 0 deletions hydragnn/models/create.py
@@ -24,6 +24,7 @@
from hydragnn.models.DIMEStack import DIMEStack
from hydragnn.models.EGCLStack import EGCLStack
from hydragnn.models.PNAEqStack import PNAEqStack
from hydragnn.models.PAINNStack import PAINNStack

from hydragnn.utils.distributed import get_device
from hydragnn.utils.profiling_and_tracing.time_utils import Timer
@@ -331,6 +332,25 @@ def create_model(
num_nodes=num_nodes,
)

elif model_type == "PAINN":
model = PAINNStack(
# edge_dim, # To-do add edge_features
num_radial,
radius,
input_dim,
hidden_dim,
output_dim,
output_type,
output_heads,
activation_function,
loss_function_type,
equivariance,
loss_weights=task_weights,
freeze_conv=freeze_conv,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
)

elif model_type == "PNAEq":
assert pna_deg is not None, "PNAEq requires degree input."
model = PNAEqStack(
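For reference, selecting the new stack through HydraGNN's JSON-driven model creation would look roughly like the sketch below. Only "model_type" is confirmed by this commit (see the config_utils.py change); the other key names mirror the PAINNStack constructor arguments and are assumptions about the config schema.

# Hypothetical architecture sub-config; key names other than "model_type"
# mirror the PAINNStack constructor and may differ from the actual JSON schema.
architecture = {
    "model_type": "PAINN",
    "radius": 5.0,    # cutoff radius passed through to PainnMessage
    "num_radial": 6,  # size of the sinc radial basis
}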
2 changes: 1 addition & 1 deletion hydragnn/utils/input_config_parsing/config_utils.py
@@ -113,7 +113,7 @@ def update_config(config, train_loader, val_loader, test_loader):


def update_config_equivariance(config):
- equivariant_models = ["EGNN", "SchNet", "PNAEq"]
+ equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN"]
if "equivariance" in config and config["equivariance"]:
assert (
config["model_type"] in equivariant_models
2 changes: 1 addition & 1 deletion tests/test_forces_equivariant.py
@@ -17,7 +17,7 @@

@pytest.mark.parametrize("example", ["LennardJones"])
@pytest.mark.parametrize(
"model_type", ["SchNet", "EGNN", "DimeNet", "PNAPlus", "PNAEq"]
"model_type", ["SchNet", "EGNN", "DimeNet", "PAINN", "PNAPlus"]
)
@pytest.mark.mpi_skip()
def pytest_examples(example, model_type):
5 changes: 4 additions & 1 deletion tests/test_graphs.py
@@ -148,6 +148,7 @@ def unittest_train_model(
"DimeNet": [0.50, 0.50],
"EGNN": [0.20, 0.20],
"PNAEq": [0.60, 0.60],
"PAINN": [0.60, 0.60],
}
if use_lengths and ("vector" not in ci_input):
thresholds["CGCNN"] = [0.175, 0.175]
@@ -208,6 +209,7 @@ def unittest_train_model(
"DimeNet",
"EGNN",
"PNAEq",
"PAINN",
],
)
@pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"])
@@ -222,7 +224,7 @@ def pytest_train_model_lengths(model_type, overwrite_data=False):


# Test across equivariant models
@pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "PNAEq"])
@pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "PNAEq", "PAINN"])
def pytest_train_equivariant_model(model_type, overwrite_data=False):
unittest_train_model(model_type, "ci_equivariant.json", False, overwrite_data)

@@ -246,6 +248,7 @@ def pytest_train_model_vectoroutput(model_type, overwrite_data=False):
"DimeNet",
"EGNN",
"PNAEq",
"PAINN",
],
)
def pytest_train_model_conv_head(model_type, overwrite_data=False):
