diff --git a/docs/source/priors.rst b/docs/source/priors.rst index 65c16b46e..1261c6b93 100644 --- a/docs/source/priors.rst +++ b/docs/source/priors.rst @@ -31,11 +31,12 @@ It is possible to configure more than one prior in this way: .. code:: yaml - prior_model: - Atomref: {} # No additional arguments - Coulomb: - alpha: 1 - max_num_neighbors: 10 + prior_model: + Atomref: {} # No additional arguments + Coulomb: + lower_switch_distance: 4 + upper_switch_distance: 8 + max_num_neighbors: 128 diff --git a/tests/test_priors.py b/tests/test_priors.py index 1c00af60d..61c2166d5 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -88,11 +88,12 @@ def test_coulomb(dtype): types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types distance_scale = 1e-9 # Convert nm to meters energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules - alpha = 1.8 + lower_switch_distance = 0.9 + upper_switch_distance = 1.3 # Use the Coulomb class to compute the energy. - coulomb = Coulomb(alpha, 5, distance_scale=distance_scale, energy_scale=energy_scale) + coulomb = Coulomb(lower_switch_distance, upper_switch_distance, 5, distance_scale=distance_scale, energy_scale=energy_scale) energy = coulomb.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), extra_args={'partial_charges':charge})[0] # Compare to the expected value. @@ -100,7 +101,12 @@ def test_coulomb(dtype): def compute_interaction(pos1, pos2, z1, z2): delta = pos1-pos2 r = torch.sqrt(torch.dot(delta, delta)) - return torch.erf(alpha*r)*138.935*z1*z2/r + if r < lower_switch_distance: + return 0 + energy = 138.935*z1*z2/r + if r < upper_switch_distance: + energy *= 0.5-0.5*torch.cos(torch.pi*(r-lower_switch_distance)/(upper_switch_distance-lower_switch_distance)) + return energy expected = 0 for i in range(len(pos)): diff --git a/torchmdnet/priors/coulomb.py b/torchmdnet/priors/coulomb.py index 69943de1f..449e2c530 100644 --- a/torchmdnet/priors/coulomb.py +++ b/torchmdnet/priors/coulomb.py @@ -8,13 +8,15 @@ from typing import Optional, Dict class Coulomb(BasePrior): - """This class implements a Coulomb potential, scaled by :math:`\\textrm{erf}(\\textrm{alpha}*r)` to reduce its + """This class implements a Coulomb potential, scaled by a cosine switching function to reduce its effect at short distances. Parameters ---------- - alpha : float - Scaling factor for the error function. + lower_switch_distance : float + distance below which the interaction strength is zero. + upper_switch_distance : float + distance above which the interaction has full strength max_num_neighbors : int Maximum number of neighbors per atom allowed. distance_scale : float, optional @@ -31,20 +33,22 @@ class Coulomb(BasePrior): The Dataset used with this class must include a `partial_charges` field for each sample, and provide `distance_scale` and `energy_scale` attributes if they are not explicitly passed as arguments. """ - def __init__(self, alpha, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None): + def __init__(self, lower_switch_distance, upper_switch_distance, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None): super(Coulomb, self).__init__() if distance_scale is None: distance_scale = dataset.distance_scale if energy_scale is None: energy_scale = dataset.energy_scale self.distance = OptimizedDistance(0, torch.inf, max_num_pairs=-max_num_neighbors) - self.alpha = alpha + self.lower_switch_distance = lower_switch_distance + self.upper_switch_distance = upper_switch_distance self.max_num_neighbors = max_num_neighbors self.distance_scale = float(distance_scale) self.energy_scale = float(energy_scale) self.initial_box = box_vecs def get_init_args(self): - return {'alpha': self.alpha, + return {'lower_switch_distance': self.lower_switch_distance, + 'upper_switch_distance': self.upper_switch_distance, 'max_num_neighbors': self.max_num_neighbors, 'distance_scale': self.distance_scale, 'energy_scale': self.energy_scale, @@ -78,14 +82,16 @@ def post_reduce(self, y, z, pos, batch, box: Optional[torch.Tensor] = None, extr """ # Convert to nm and calculate distance. x = 1e9*self.distance_scale*pos - alpha = self.alpha/(1e9*self.distance_scale) box = box if box is not None else self.initial_box edge_index, distance, _ = self.distance(x, batch, box=box) # Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair # appears twice. q = extra_args['partial_charges'][edge_index] - energy = torch.erf(alpha*distance)*q[0]*q[1]/distance + lower = torch.tensor(self.lower_switch_distance) + upper = torch.tensor(self.upper_switch_distance) + phase = (torch.max(lower, torch.min(upper, distance))-lower)/(upper-lower) + energy = (0.5-0.5*torch.cos(torch.pi*phase))*q[0]*q[1]/distance energy = 0.5*(2.30707e-28/self.energy_scale/self.distance_scale)*scatter(energy, batch[edge_index[0]], dim=0, reduce="sum") energy = energy.reshape(y.shape) return y + energy diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 0b169f96c..a51cfe45f 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -74,7 +74,7 @@ def get_argparse(): # model architecture parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train') parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model') - parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "alpha"=1}', action="extend", nargs="*") + parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*") # architectural args parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')