Skip to content

Commit

Permalink
Use PyTorch autograd's hessian (#532)
Browse files Browse the repository at this point in the history
* Use PyTorch autograd's hessian

* fix test

* save

* clean

* save

* save

* drop hessian from jit example
  • Loading branch information
zasdfgbnm authored Nov 13, 2020
1 parent ea51fad commit bd9d888
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 71 deletions.
1 change: 0 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ Utilities
.. autofunction:: torchani.utils.map2central
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:
.. autofunction:: torchani.utils.hessian
.. autofunction:: torchani.utils.vibrational_analysis
.. autofunction:: torchani.utils.get_atomic_masses

Expand Down
22 changes: 8 additions & 14 deletions examples/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
#
# - uses double as dtype instead of float
# - don't care about periodic boundary condition
# - in addition to energies, allow returnsing optionally forces, and hessians
# - in addition to energies, allow returning optionally forces
# - when indexing atom species, use its index in the periodic table instead of 0, 1, 2, 3, ...
#
# you could do the following:
Expand All @@ -81,34 +81,28 @@ def __init__(self):
# self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
# self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()

def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False,
return_hessians: bool = False) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
if return_forces or return_hessians:
def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
if return_forces:
coordinates.requires_grad_(True)

energies = self.model((species, coordinates)).energies

forces: Optional[Tensor] = None # noqa: E701
hessians: Optional[Tensor] = None
if return_forces or return_hessians:
grad = torch.autograd.grad([energies.sum()], [coordinates], create_graph=return_hessians)[0]
if return_forces:
grad = torch.autograd.grad([energies.sum()], [coordinates])[0]
assert grad is not None
forces = -grad
if return_hessians:
hessians = torchani.utils.hessian(coordinates, forces=forces)
return energies, forces, hessians
return energies, forces


custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, 'compiled_custom_model.pt')
loaded_compiled_custom_model = torch.jit.load('compiled_custom_model.pt')
energies, forces, hessians = custom_model(species, coordinates, True, True)
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(species, coordinates, True, True)
energies, forces = custom_model(species, coordinates, True)
energies_jit, forces_jit = loaded_compiled_custom_model(species, coordinates, True)

print('Energy, eager mode vs loaded jit:', energies.item(), energies_jit.item())
print()
print('Force, eager mode vs loaded jit:\n', forces.squeeze(0), '\n', forces_jit.squeeze(0))
print()
torch.set_printoptions(sci_mode=False, linewidth=1000)
print('Hessian, eager mode vs loaded jit:\n', hessians.squeeze(0), '\n', hessians_jit.squeeze(0))
14 changes: 4 additions & 10 deletions examples/vibration_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,12 @@
masses = torchani.utils.get_atomic_masses(species)

###############################################################################
# To do vibration analysis, we first need to generate a graph that computes
# energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy:
energies = model((species, coordinates)).energies
# We can use :func:`torch.autograd.functional.hessian` to compute hessian:
hessian = torch.autograd.functional.hessian(lambda x: model((species, x)).energies, coordinates)

###############################################################################
# We can now use the energy graph to compute analytical Hessian matrix:
hessian = torchani.utils.hessian(coordinates, energies=energies)

###############################################################################
# The Hessian matrix should have shape `(1, 9, 9)`, where 1 means there is only
# one molecule to compute, 9 means `3 atoms * 3D space = 9 degree of freedom`.
# The Hessian matrix should have shape `(1, 3, 3, 1, 3, 3)`, where 1 means there
# is only one molecule to compute, 3 means 3 atoms and 3D space.
print(hessian.shape)

###############################################################################
Expand Down
4 changes: 0 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import unittest
import torch
import torchani


Expand All @@ -10,9 +9,6 @@ def testChemicalSymbolsToInts(self):
self.assertEqual(len(str2i), 6)
self.assertListEqual(str2i('BACCC').tolist(), [1, 0, 2, 2, 2])

def testHessianJIT(self):
torch.jit.script(torchani.utils.hessian)


if __name__ == '__main__':
unittest.main()
3 changes: 1 addition & 2 deletions tests/test_vibrational.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def testVibrationalWavenumbers(self):
# compute vibrational by torchani
species = model.species_to_tensor(molecule.get_chemical_symbols()).unsqueeze(0)
coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_grad_(True)
_, energies = model((species, coordinates))
hessian = torchani.utils.hessian(coordinates, energies=energies)
hessian = torch.autograd.functional.hessian(lambda x: model((species, x)).energies, coordinates)
freq2, modes2, _, _ = torchani.utils.vibrational_analysis(masses[species], hessian)
freq2 = freq2[6:].float()
modes2 = modes2[6:]
Expand Down
44 changes: 4 additions & 40 deletions torchani/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,43 +241,6 @@ def __len__(self):
return len(self.rev_species)


def _get_derivatives_not_none(x: Tensor, y: Tensor, retain_graph: Optional[bool] = None, create_graph: bool = False) -> Tensor:
ret = torch.autograd.grad([y.sum()], [x], retain_graph=retain_graph, create_graph=create_graph)[0]
assert ret is not None
return ret


def hessian(coordinates: Tensor, energies: Optional[Tensor] = None, forces: Optional[Tensor] = None) -> Tensor:
"""Compute analytical hessian from the energy graph or force graph.
Arguments:
coordinates (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`
energies (:class:`torch.Tensor`): Tensor of shape `(molecules,)`, if specified,
then `forces` must be `None`. This energies must be computed from
`coordinates` in a graph.
forces (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`, if specified,
then `energies` must be `None`. This forces must be computed from
`coordinates` in a graph.
Returns:
:class:`torch.Tensor`: Tensor of shape `(molecules, 3A, 3A)` where A is the number of
atoms in each molecule
"""
if energies is None and forces is None:
raise ValueError('Energies or forces must be specified')
if energies is not None and forces is not None:
raise ValueError('Energies or forces can not be specified at the same time')
if forces is None:
assert energies is not None
forces = -_get_derivatives_not_none(coordinates, energies, create_graph=True)
flattened_force = forces.flatten(start_dim=1)
force_components = flattened_force.unbind(dim=1)
return -torch.stack([
_get_derivatives_not_none(coordinates, f, retain_graph=True).flatten(start_dim=1)
for f in force_components
], dim=1)


class FreqsModes(NamedTuple):
freqs: Tensor
modes: Tensor
Expand Down Expand Up @@ -317,6 +280,8 @@ def vibrational_analysis(masses, hessian, mode_type='MDU', unit='cm^-1'):
raise ValueError('Only meV and cm^-1 are supported right now')

assert hessian.shape[0] == 1, 'Currently only supporting computing one molecule a time'
degree_of_freedom = hessian.shape[1] * hessian.shape[2]
hessian = hessian.reshape(1, degree_of_freedom, degree_of_freedom)
# Solving the eigenvalue problem: Hq = w^2 * T q
# where H is the Hessian matrix, q is the normal coordinates,
# T = diag(m1, m1, m1, m2, m2, m2, ....) is the mass
Expand Down Expand Up @@ -423,6 +388,5 @@ def get_atomic_masses(species):
""".strip().split()


__all__ = ['pad_atomic_properties', 'present_species', 'hessian',
'vibrational_analysis', 'strip_redundant_padding',
'ChemicalSymbolsToInts', 'get_atomic_masses']
__all__ = ['pad_atomic_properties', 'present_species', 'vibrational_analysis',
'strip_redundant_padding', 'ChemicalSymbolsToInts', 'get_atomic_masses']

1 comment on commit bd9d888

@IgnacioJPickering
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zasdfgbnm I remember a couple of months ago I tried pytorch's hessian implementation and it was significantly slower than your handwritten code in GPU, did you run any comparisons?

Please sign in to comment.