Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use switching function for Coulomb prior #287

Merged
merged 3 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions docs/source/priors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ It is possible to configure more than one prior in this way:

.. code:: yaml

prior_model:
Atomref: {} # No additional arguments
Coulomb:
alpha: 1
max_num_neighbors: 10
prior_model:
Atomref: {} # No additional arguments
Coulomb:
lower_switch_distance: 4
upper_switch_distance: 8
max_num_neighbors: 128



Expand Down
12 changes: 9 additions & 3 deletions tests/test_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,25 @@ def test_coulomb(dtype):
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
distance_scale = 1e-9 # Convert nm to meters
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
alpha = 1.8
lower_switch_distance = 0.9
upper_switch_distance = 1.3

# Use the Coulomb class to compute the energy.

coulomb = Coulomb(alpha, 5, distance_scale=distance_scale, energy_scale=energy_scale)
coulomb = Coulomb(lower_switch_distance, upper_switch_distance, 5, distance_scale=distance_scale, energy_scale=energy_scale)
energy = coulomb.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), extra_args={'partial_charges':charge})[0]

# Compare to the expected value.

def compute_interaction(pos1, pos2, z1, z2):
delta = pos1-pos2
r = torch.sqrt(torch.dot(delta, delta))
return torch.erf(alpha*r)*138.935*z1*z2/r
if r < lower_switch_distance:
return 0
energy = 138.935*z1*z2/r
if r < upper_switch_distance:
energy *= 0.5-0.5*torch.cos(torch.pi*(r-lower_switch_distance)/(upper_switch_distance-lower_switch_distance))
return energy

expected = 0
for i in range(len(pos)):
Expand Down
22 changes: 14 additions & 8 deletions torchmdnet/priors/coulomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from typing import Optional, Dict

class Coulomb(BasePrior):
"""This class implements a Coulomb potential, scaled by :math:`\\textrm{erf}(\\textrm{alpha}*r)` to reduce its
"""This class implements a Coulomb potential, scaled by a cosine switching function to reduce its
effect at short distances.

Parameters
----------
alpha : float
Scaling factor for the error function.
lower_switch_distance : float
distance below which the interaction strength is zero.
upper_switch_distance : float
distance above which the interaction has full strength
max_num_neighbors : int
Maximum number of neighbors per atom allowed.
distance_scale : float, optional
Expand All @@ -31,20 +33,22 @@ class Coulomb(BasePrior):
The Dataset used with this class must include a `partial_charges` field for each sample, and provide
`distance_scale` and `energy_scale` attributes if they are not explicitly passed as arguments.
"""
def __init__(self, alpha, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None):
def __init__(self, lower_switch_distance, upper_switch_distance, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None):
super(Coulomb, self).__init__()
if distance_scale is None:
distance_scale = dataset.distance_scale
if energy_scale is None:
energy_scale = dataset.energy_scale
self.distance = OptimizedDistance(0, torch.inf, max_num_pairs=-max_num_neighbors)
self.alpha = alpha
self.lower_switch_distance = lower_switch_distance
self.upper_switch_distance = upper_switch_distance
self.max_num_neighbors = max_num_neighbors
self.distance_scale = float(distance_scale)
self.energy_scale = float(energy_scale)
self.initial_box = box_vecs
def get_init_args(self):
return {'alpha': self.alpha,
return {'lower_switch_distance': self.lower_switch_distance,
'upper_switch_distance': self.upper_switch_distance,
'max_num_neighbors': self.max_num_neighbors,
'distance_scale': self.distance_scale,
'energy_scale': self.energy_scale,
Expand Down Expand Up @@ -78,14 +82,16 @@ def post_reduce(self, y, z, pos, batch, box: Optional[torch.Tensor] = None, extr
"""
# Convert to nm and calculate distance.
x = 1e9*self.distance_scale*pos
alpha = self.alpha/(1e9*self.distance_scale)
box = box if box is not None else self.initial_box
edge_index, distance, _ = self.distance(x, batch, box=box)

# Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair
# appears twice.
q = extra_args['partial_charges'][edge_index]
energy = torch.erf(alpha*distance)*q[0]*q[1]/distance
lower = torch.tensor(self.lower_switch_distance)
upper = torch.tensor(self.upper_switch_distance)
phase = (torch.max(lower, torch.min(upper, distance))-lower)/(upper-lower)
energy = (0.5-0.5*torch.cos(torch.pi*phase))*q[0]*q[1]/distance
energy = 0.5*(2.30707e-28/self.energy_scale/self.distance_scale)*scatter(energy, batch[edge_index[0]], dim=0, reduce="sum")
energy = energy.reshape(y.shape)
return y + energy
2 changes: 1 addition & 1 deletion torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_argparse():
# model architecture
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "alpha"=1}', action="extend", nargs="*")
parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*")

# architectural args
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
Expand Down
Loading