diff --git a/docs/api.rst b/docs/api.rst index a69a656c7..f2c147089 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -41,7 +41,6 @@ Utilities .. autofunction:: torchani.utils.map2central .. autoclass:: torchani.utils.ChemicalSymbolsToInts :members: -.. autofunction:: torchani.utils.hessian .. autofunction:: torchani.utils.vibrational_analysis .. autofunction:: torchani.utils.get_atomic_masses diff --git a/examples/jit.py b/examples/jit.py index ce2a7fd77..bbf9b6cad 100644 --- a/examples/jit.py +++ b/examples/jit.py @@ -69,7 +69,7 @@ # # - uses double as dtype instead of float # - don't care about periodic boundary condition -# - in addition to energies, allow returnsing optionally forces, and hessians +# - in addition to energies, allow returning optionally forces # - when indexing atom species, use its index in the periodic table instead of 0, 1, 2, 3, ... # # you could do the following: @@ -81,34 +81,28 @@ def __init__(self): # self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double() # self.model = torchani.models.ANI1ccx(periodic_table_index=True).double() - def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False, - return_hessians: bool = False) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: - if return_forces or return_hessians: + def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False) -> Tuple[Tensor, Optional[Tensor]]: + if return_forces: coordinates.requires_grad_(True) energies = self.model((species, coordinates)).energies forces: Optional[Tensor] = None # noqa: E701 - hessians: Optional[Tensor] = None - if return_forces or return_hessians: - grad = torch.autograd.grad([energies.sum()], [coordinates], create_graph=return_hessians)[0] + if return_forces: + grad = torch.autograd.grad([energies.sum()], [coordinates])[0] assert grad is not None forces = -grad - if return_hessians: - hessians = torchani.utils.hessian(coordinates, forces=forces) - return energies, forces, hessians + return energies, forces custom_model = CustomModule() compiled_custom_model = torch.jit.script(custom_model) torch.jit.save(compiled_custom_model, 'compiled_custom_model.pt') loaded_compiled_custom_model = torch.jit.load('compiled_custom_model.pt') -energies, forces, hessians = custom_model(species, coordinates, True, True) -energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(species, coordinates, True, True) +energies, forces = custom_model(species, coordinates, True) +energies_jit, forces_jit = loaded_compiled_custom_model(species, coordinates, True) print('Energy, eager mode vs loaded jit:', energies.item(), energies_jit.item()) print() print('Force, eager mode vs loaded jit:\n', forces.squeeze(0), '\n', forces_jit.squeeze(0)) print() -torch.set_printoptions(sci_mode=False, linewidth=1000) -print('Hessian, eager mode vs loaded jit:\n', hessians.squeeze(0), '\n', hessians_jit.squeeze(0)) diff --git a/examples/vibration_analysis.py b/examples/vibration_analysis.py index 0b41af8ce..7df98ab51 100644 --- a/examples/vibration_analysis.py +++ b/examples/vibration_analysis.py @@ -47,18 +47,12 @@ masses = torchani.utils.get_atomic_masses(species) ############################################################################### -# To do vibration analysis, we first need to generate a graph that computes -# energies from species and coordinates. The code to generate a graph of energy -# is the same as the code to compute energy: -energies = model((species, coordinates)).energies +# We can use :func:`torch.autograd.functional.hessian` to compute hessian: +hessian = torch.autograd.functional.hessian(lambda x: model((species, x)).energies, coordinates) ############################################################################### -# We can now use the energy graph to compute analytical Hessian matrix: -hessian = torchani.utils.hessian(coordinates, energies=energies) - -############################################################################### -# The Hessian matrix should have shape `(1, 9, 9)`, where 1 means there is only -# one molecule to compute, 9 means `3 atoms * 3D space = 9 degree of freedom`. +# The Hessian matrix should have shape `(1, 3, 3, 1, 3, 3)`, where 1 means there +# is only one molecule to compute, 3 means 3 atoms and 3D space. print(hessian.shape) ############################################################################### diff --git a/tests/test_utils.py b/tests/test_utils.py index 188de67e0..406bf0b55 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,4 @@ import unittest -import torch import torchani @@ -10,9 +9,6 @@ def testChemicalSymbolsToInts(self): self.assertEqual(len(str2i), 6) self.assertListEqual(str2i('BACCC').tolist(), [1, 0, 2, 2, 2]) - def testHessianJIT(self): - torch.jit.script(torchani.utils.hessian) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_vibrational.py b/tests/test_vibrational.py index c0f7f28a5..ab69fff5e 100644 --- a/tests/test_vibrational.py +++ b/tests/test_vibrational.py @@ -39,8 +39,7 @@ def testVibrationalWavenumbers(self): # compute vibrational by torchani species = model.species_to_tensor(molecule.get_chemical_symbols()).unsqueeze(0) coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_grad_(True) - _, energies = model((species, coordinates)) - hessian = torchani.utils.hessian(coordinates, energies=energies) + hessian = torch.autograd.functional.hessian(lambda x: model((species, x)).energies, coordinates) freq2, modes2, _, _ = torchani.utils.vibrational_analysis(masses[species], hessian) freq2 = freq2[6:].float() modes2 = modes2[6:] diff --git a/torchani/utils.py b/torchani/utils.py index f8a0aa089..e4fc49a56 100644 --- a/torchani/utils.py +++ b/torchani/utils.py @@ -241,43 +241,6 @@ def __len__(self): return len(self.rev_species) -def _get_derivatives_not_none(x: Tensor, y: Tensor, retain_graph: Optional[bool] = None, create_graph: bool = False) -> Tensor: - ret = torch.autograd.grad([y.sum()], [x], retain_graph=retain_graph, create_graph=create_graph)[0] - assert ret is not None - return ret - - -def hessian(coordinates: Tensor, energies: Optional[Tensor] = None, forces: Optional[Tensor] = None) -> Tensor: - """Compute analytical hessian from the energy graph or force graph. - - Arguments: - coordinates (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)` - energies (:class:`torch.Tensor`): Tensor of shape `(molecules,)`, if specified, - then `forces` must be `None`. This energies must be computed from - `coordinates` in a graph. - forces (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`, if specified, - then `energies` must be `None`. This forces must be computed from - `coordinates` in a graph. - - Returns: - :class:`torch.Tensor`: Tensor of shape `(molecules, 3A, 3A)` where A is the number of - atoms in each molecule - """ - if energies is None and forces is None: - raise ValueError('Energies or forces must be specified') - if energies is not None and forces is not None: - raise ValueError('Energies or forces can not be specified at the same time') - if forces is None: - assert energies is not None - forces = -_get_derivatives_not_none(coordinates, energies, create_graph=True) - flattened_force = forces.flatten(start_dim=1) - force_components = flattened_force.unbind(dim=1) - return -torch.stack([ - _get_derivatives_not_none(coordinates, f, retain_graph=True).flatten(start_dim=1) - for f in force_components - ], dim=1) - - class FreqsModes(NamedTuple): freqs: Tensor modes: Tensor @@ -317,6 +280,8 @@ def vibrational_analysis(masses, hessian, mode_type='MDU', unit='cm^-1'): raise ValueError('Only meV and cm^-1 are supported right now') assert hessian.shape[0] == 1, 'Currently only supporting computing one molecule a time' + degree_of_freedom = hessian.shape[1] * hessian.shape[2] + hessian = hessian.reshape(1, degree_of_freedom, degree_of_freedom) # Solving the eigenvalue problem: Hq = w^2 * T q # where H is the Hessian matrix, q is the normal coordinates, # T = diag(m1, m1, m1, m2, m2, m2, ....) is the mass @@ -423,6 +388,5 @@ def get_atomic_masses(species): """.strip().split() -__all__ = ['pad_atomic_properties', 'present_species', 'hessian', - 'vibrational_analysis', 'strip_redundant_padding', - 'ChemicalSymbolsToInts', 'get_atomic_masses'] +__all__ = ['pad_atomic_properties', 'present_species', 'vibrational_analysis', + 'strip_redundant_padding', 'ChemicalSymbolsToInts', 'get_atomic_masses']