Allow appending extra values to embedding vector #289

Open

wants to merge 5 commits into main

Changes from 2 commits
11 changes: 11 additions & 0 deletions tests/test_model.py
@@ -227,3 +227,14 @@ def test_gradients(model_name):
    torch.autograd.gradcheck(
        model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
    )


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_batch", [True, False])
def test_extra_embedding(model_name, use_batch):
    z, pos, batch = create_example_batch()
    args = load_example_args(model_name, prior_model=None)
    args["extra_embedding"] = ["atomic", "global"]
    model = create_model(args)
    batch = batch if use_batch else None
    model(z, pos, batch=batch, extra_args={"atomic": torch.rand(6), "global": torch.rand(2)})
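Note on the shapes in this test: "atomic" carries one value per atom (6 atoms in the example batch), while "global" carries one value per molecule (2 molecules); `TorchMD_Net.forward` expands any per-molecule tensor to per-atom form through `batch`. A minimal sketch of that logic, assuming a hypothetical 3+3 atom split:

```python
import torch

# Hypothetical batch mirroring create_example_batch(): 6 atoms, 2 molecules.
z = torch.tensor([1, 6, 8, 1, 1, 6])      # atomic numbers, shape (6,)
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # molecule index of each atom

atomic = torch.rand(6)   # per-atom field: already matches z.shape, used as-is
global_ = torch.rand(2)  # per-molecule field: expanded below

# forward() applies this expansion to any field whose shape differs from z.shape:
expanded = global_[batch]  # shape (6,): each atom receives its molecule's value
assert expanded.shape == z.shape
```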
21 changes: 19 additions & 2 deletions torchmdnet/models/model.py
@@ -38,6 +38,12 @@ def create_model(args, prior_model=None, mean=None, std=None):
        args["static_shapes"] = False
    if "vector_cutoff" not in args:
        args["vector_cutoff"] = False
    if "extra_embedding" not in args:
        extra_embedding = None
    elif isinstance(args["extra_embedding"], str):
        extra_embedding = [args["extra_embedding"]]
    else:
        extra_embedding = args["extra_embedding"]

    shared_args = dict(
        hidden_channels=args["embedding_dimension"],
@@ -57,6 +63,7 @@
            else None
        ),
        dtype=dtype,
        extra_embedding=extra_embedding
    )

    # representation network
@@ -370,7 +377,7 @@ def forward(
                If this is omitted, periodic boundary conditions are not applied.
            q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
            s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
-           extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.
+           extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the model.

        Returns:
            Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
@@ -380,9 +387,19 @@

        if self.derivative:
            pos.requires_grad_(True)
        if self.representation_model.extra_embedding is None:
            extra_embedding_args = None
        else:
            extra = []
            for arg in self.representation_model.extra_embedding:
                t = extra_args[arg]
                if t.shape != z.shape:
                    t = t[batch]
                extra.append(t)
            extra_embedding_args = tuple(extra)
        # run the potentially wrapped representation model
        x, v, z, pos, batch = self.representation_model(
-           z, pos, batch, box=box, q=q, s=s
+           z, pos, batch, box=box, q=q, s=s, extra_embedding_args=extra_embedding_args
        )
        # apply the output network
        x = self.output_model.pre_reduce(x, v, z, pos, batch)
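A distilled restatement of the `create_model` normalization above (field names are hypothetical): a missing key disables the feature, a bare string becomes a one-element list, and a list is passed through to the representation model unchanged.

```python
def normalize_extra_embedding(args):
    # Same branching as in create_model above.
    if "extra_embedding" not in args:
        return None  # feature disabled
    if isinstance(args["extra_embedding"], str):
        return [args["extra_embedding"]]  # single field name -> one-element list
    return args["extra_embedding"]  # already a list/tuple of field names

assert normalize_extra_embedding({}) is None
assert normalize_extra_embedding({"extra_embedding": "partial_charge"}) == ["partial_charge"]
assert normalize_extra_embedding({"extra_embedding": ["partial_charge", "spin_label"]}) == ["partial_charge", "spin_label"]
```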
24 changes: 21 additions & 3 deletions torchmdnet/models/tensornet.py
@@ -120,6 +120,9 @@ class TensorNet(nn.Module):
            (default: :obj:`True`)
        check_errors (bool, optional): Whether to check for errors in the distance module.
            (default: :obj:`True`)
        extra_embedding (tuple, optional): The names of extra fields to append to the
            embedding vector for each atom.
            (default: :obj:`None`)
    """

    def __init__(
@@ -139,6 +142,7 @@ def __init__(
        check_errors=True,
        dtype=torch.float32,
        box_vecs=None,
        extra_embedding=None
    ):
        super(TensorNet, self).__init__()

@@ -163,6 +167,7 @@
        self.activation = activation
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.extra_embedding = extra_embedding
        act_class = act_class_mapping[activation]
        self.distance_expansion = rbf_class_mapping[rbf_type](
            cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
@@ -176,6 +181,7 @@
            trainable_rbf,
            max_z,
            dtype,
            extra_embedding
        )

        self.layers = nn.ModuleList()
@@ -228,6 +234,7 @@ def forward(
        box: Optional[Tensor] = None,
        q: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
        extra_embedding_args: Optional[Tuple[Tensor]] = None
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
        # Obtain graph, with distances and relative position vectors
        edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
@@ -258,7 +265,7 @@
        # Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
        # I avoid dividing by zero by setting the weight of self edges and self loops to 1
        edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
-       X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
+       X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr, extra_embedding_args)
        for layer in self.layers:
            X = layer(X, edge_index, edge_weight, edge_attr, q)
        I, A, S = decompose_tensor(X)
@@ -287,6 +294,7 @@ def __init__(
        trainable_rbf=False,
        max_z=128,
        dtype=torch.float32,
        extra_embedding=None
    ):
        super(TensorEmbedding, self).__init__()

@@ -297,6 +305,10 @@
        self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
        self.max_z = max_z
        self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype)
        if extra_embedding is not None:
            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
        else:
            self.reshape_embedding = None
        self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
        self.act = activation()
        self.linears_tensor = nn.ModuleList()
@@ -319,15 +331,20 @@ def reset_parameters(self):
        self.distance_proj2.reset_parameters()
        self.distance_proj3.reset_parameters()
        self.emb.reset_parameters()
        if self.reshape_embedding is not None:
            self.reshape_embedding.reset_parameters()
        self.emb2.reset_parameters()
        for linear in self.linears_tensor:
            linear.reset_parameters()
        for linear in self.linears_scalar:
            linear.reset_parameters()
        self.init_norm.reset_parameters()

-   def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor:
+   def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor, extra_embedding_args: Optional[Tuple[Tensor]]) -> Tensor:
        Z = self.emb(z)
        if self.reshape_embedding is not None:
            Z = torch.cat((Z,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
            Z = self.reshape_embedding(Z)
        Zij = self.emb2(
            Z.index_select(0, edge_index.t().reshape(-1)).view(
                -1, self.hidden_channels * 2
@@ -362,8 +379,9 @@ def forward(
        edge_weight: Tensor,
        edge_vec_norm: Tensor,
        edge_attr: Tensor,
        extra_embedding_args: Optional[Tuple[Tensor]]
    ) -> Tensor:
-       Zij = self._get_atomic_number_message(z, edge_index)
+       Zij = self._get_atomic_number_message(z, edge_index, extra_embedding_args)
        Iij, Aij, Sij = self._get_tensor_messages(
            Zij, edge_weight, edge_vec_norm, edge_attr
        )
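The embedding change here follows the same pattern the TorchMD_ET/GN/T models and optimize.py adopt below: the scalar extras are concatenated onto the learned per-atom embedding, and a linear layer projects the widened vector back to `hidden_channels`, leaving every downstream layer untouched. A self-contained sketch with toy sizes (all names illustrative, not from the PR):

```python
import torch
import torch.nn as nn

hidden_channels, num_extra, max_z = 8, 2, 128  # toy sizes

emb = nn.Embedding(max_z, hidden_channels)
reshape_embedding = nn.Linear(hidden_channels + num_extra, hidden_channels)

z = torch.tensor([1, 6, 8])              # 3 atoms
extras = (torch.rand(3), torch.rand(3))  # one per-atom tensor per extra field

x = emb(z)  # (3, hidden_channels)
x = torch.cat((x,) + tuple(t.unsqueeze(1) for t in extras), dim=1)  # (3, hidden_channels + num_extra)
x = reshape_embedding(x)  # back to (3, hidden_channels)
assert x.shape == (3, hidden_channels)
```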
16 changes: 15 additions & 1 deletion torchmdnet/models/torchmd_et.py
@@ -79,7 +79,9 @@ class TorchMD_ET(nn.Module):
            (default: :obj:`False`)
        check_errors (bool, optional): Whether to check for errors in the distance module.
            (default: :obj:`True`)
        extra_embedding (tuple, optional): The names of extra fields to append to the
            embedding vector for each atom.
            (default: :obj:`None`)
    """

    def __init__(
@@ -102,6 +104,7 @@ def __init__(
        box_vecs=None,
        vector_cutoff=False,
        dtype=torch.float32,
        extra_embedding=None
    ):
        super(TorchMD_ET, self).__init__()

@@ -133,10 +136,15 @@
        self.cutoff_upper = cutoff_upper
        self.max_z = max_z
        self.dtype = dtype
        self.extra_embedding = extra_embedding

        act_class = act_class_mapping[activation]

        self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
        if extra_embedding is not None:
            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
        else:
            self.reshape_embedding = None

        self.distance = OptimizedDistance(
            cutoff_lower,
@@ -181,6 +189,8 @@

    def reset_parameters(self):
        self.embedding.reset_parameters()
        if self.reshape_embedding is not None:
            self.reshape_embedding.reset_parameters()
        self.distance_expansion.reset_parameters()
        if self.neighbor_embedding is not None:
            self.neighbor_embedding.reset_parameters()
@@ -196,8 +206,12 @@ def forward(
        box: Optional[Tensor] = None,
        q: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
        extra_embedding_args: Optional[Tuple[Tensor]] = None
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        x = self.embedding(z)
        if self.reshape_embedding is not None:
            x = torch.cat((x,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
            x = self.reshape_embedding(x)

        edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
        # This assert must be here to convince TorchScript that edge_vec is not None
16 changes: 15 additions & 1 deletion torchmdnet/models/torchmd_gn.py
@@ -86,7 +86,9 @@ class TorchMD_GN(nn.Module):
            (default: :obj:`None`)
        check_errors (bool, optional): Whether to check for errors in the distance module.
            (default: :obj:`True`)
        extra_embedding (tuple, optional): The names of extra fields to append to the
            embedding vector for each atom.
            (default: :obj:`None`)
    """

    def __init__(
@@ -107,6 +109,7 @@ def __init__(
        aggr="add",
        dtype=torch.float32,
        box_vecs=None,
        extra_embedding=None
    ):
        super(TorchMD_GN, self).__init__()

@@ -136,10 +139,15 @@
        self.cutoff_upper = cutoff_upper
        self.max_z = max_z
        self.aggr = aggr
        self.extra_embedding = extra_embedding

        act_class = act_class_mapping[activation]

        self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
        if extra_embedding is not None:
            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
        else:
            self.reshape_embedding = None

        self.distance = OptimizedDistance(
            cutoff_lower,
@@ -184,6 +192,8 @@

    def reset_parameters(self):
        self.embedding.reset_parameters()
        if self.reshape_embedding is not None:
            self.reshape_embedding.reset_parameters()
        self.distance_expansion.reset_parameters()
        if self.neighbor_embedding is not None:
            self.neighbor_embedding.reset_parameters()
@@ -198,8 +208,12 @@ def forward(
        box: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
        q: Optional[Tensor] = None,
        extra_embedding_args: Optional[Tuple[Tensor]] = None
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
        x = self.embedding(z)
        if self.reshape_embedding is not None:
            x = torch.cat((x,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
            x = self.reshape_embedding(x)

        edge_index, edge_weight, _ = self.distance(pos, batch, box)
        edge_attr = self.distance_expansion(edge_weight)
16 changes: 15 additions & 1 deletion torchmdnet/models/torchmd_t.py
@@ -76,7 +76,9 @@ class TorchMD_T(nn.Module):
            (default: :obj:`None`)
        check_errors (bool, optional): Whether to check for errors in the distance module.
            (default: :obj:`True`)
        extra_embedding (tuple, optional): The names of extra fields to append to the
            embedding vector for each atom.
            (default: :obj:`None`)
    """

    def __init__(
@@ -98,6 +100,7 @@ def __init__(
        max_num_neighbors=32,
        dtype=torch.float,
        box_vecs=None,
        extra_embedding=None
    ):
        super(TorchMD_T, self).__init__()

@@ -124,11 +127,16 @@
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.max_z = max_z
        self.extra_embedding = extra_embedding

        act_class = act_class_mapping[activation]
        attn_act_class = act_class_mapping[attn_activation]

        self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
        if extra_embedding is not None:
            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
        else:
            self.reshape_embedding = None

        self.distance = OptimizedDistance(
            cutoff_lower,
@@ -177,6 +185,8 @@

    def reset_parameters(self):
        self.embedding.reset_parameters()
        if self.reshape_embedding is not None:
            self.reshape_embedding.reset_parameters()
        self.distance_expansion.reset_parameters()
        if self.neighbor_embedding is not None:
            self.neighbor_embedding.reset_parameters()
@@ -192,8 +202,12 @@ def forward(
        box: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
        q: Optional[Tensor] = None,
        extra_embedding_args: Optional[Tuple[Tensor]] = None
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
        x = self.embedding(z)
        if self.reshape_embedding is not None:
            x = torch.cat((x,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
            x = self.reshape_embedding(x)

        edge_index, edge_weight, _ = self.distance(pos, batch, box)
        edge_attr = self.distance_expansion(edge_weight)
5 changes: 5 additions & 0 deletions torchmdnet/optimize.py
@@ -33,6 +33,7 @@ def __init__(self, model):

        super().__init__()
        self.model = model
        self.extra_embedding = model.extra_embedding

        self.neighbors = CFConvNeighbors(self.model.cutoff_upper)

@@ -58,12 +59,16 @@ def forward(
        box: Optional[pt.Tensor] = None,
        q: Optional[pt.Tensor] = None,
        s: Optional[pt.Tensor] = None,
        extra_embedding_args: Optional[Tuple[pt.Tensor]] = None
    ) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]:

        assert pt.all(batch == 0)
        assert box is None, "Box is not supported"

        x = self.model.embedding(z)
        if self.model.reshape_embedding is not None:
            x = pt.cat((x,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
            x = self.model.reshape_embedding(x)

        self.neighbors.build(pos)
        for inter, conv in zip(self.model.interactions, self.convs):
1 change: 1 addition & 0 deletions torchmdnet/scripts/train.py
@@ -80,6 +80,7 @@ def get_argparse():
    parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
    parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.')
    parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
    parser.add_argument('--extra-embedding', type=str, default=None, help='Extra fields of the dataset to pass to the model and append to the embedding vector.', action="extend", nargs="*")
    parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
    parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
    parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
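A minimal demonstration of how the new flag parses (field names are hypothetical; `action="extend"` with `nargs="*"` requires Python >= 3.8):

```python
import argparse

# Reproduces only the new option, not the full torchmd-net parser.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--extra-embedding", type=str, default=None,
    help="Extra fields of the dataset to pass to the model and append to the embedding vector.",
    action="extend", nargs="*",
)

args = parser.parse_args(["--extra-embedding", "partial_charge", "spin_label"])
print(args.extra_embedding)  # ['partial_charge', 'spin_label']

# "extend" also accumulates values across repeated flags:
args = parser.parse_args(["--extra-embedding", "partial_charge", "--extra-embedding", "spin_label"])
print(args.extra_embedding)  # ['partial_charge', 'spin_label']
```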