Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A better PyTorch wrapper #19

Merged
merged 14 commits into from
Apr 26, 2021
61 changes: 18 additions & 43 deletions pytorch/SymmetryFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,62 +114,37 @@ class CustomANISymmetryFunctions : public torch::CustomClassHolder {
class GradANISymmetryFunction : public torch::autograd::Function<GradANISymmetryFunction> {

public:
static torch::autograd::tensor_list forward(torch::autograd::AutogradContext *ctx,
int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

const auto symFunc = torch::intrusive_ptr<CustomANISymmetryFunctions>::make(
numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions);
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext *ctx,
const torch::intrusive_ptr<CustomANISymmetryFunctions>& symFunc,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

ctx->saved_data["symFunc"] = symFunc;

return symFunc->forward(positions, periodicBoxVectors);
};

static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, const torch::autograd::tensor_list& grads) {
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext *ctx,
const torch::autograd::tensor_list& grads) {

const auto symFunc = ctx->saved_data["symFunc"].toCustomClass<CustomANISymmetryFunctions>();
torch::Tensor positionsGrad = symFunc->backward(grads);
ctx->saved_data.erase("symFunc");

return { torch::Tensor(), // numSpecies
torch::Tensor(), // Rcr
torch::Tensor(), // Rca
torch::Tensor(), // EtaR
torch::Tensor(), // ShfR
torch::Tensor(), // EtaA
torch::Tensor(), // Zeta
torch::Tensor(), // ShfA
torch::Tensor(), // ShfZ
torch::Tensor(), // atomSpecies
positionsGrad, // positions
torch::Tensor()}; // periodicBoxVectors
return { torch::Tensor(), // symFunc
positionsGrad, // positions
torch::Tensor() }; // periodicBoxVectors
};
};

static torch::autograd::tensor_list ANISymmetryFunctionsOp(int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

return GradANISymmetryFunction::apply(numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions, periodicBoxVectors);
static torch::autograd::tensor_list ANISymmetryFunctionsOp(
const torch::optional<torch::intrusive_ptr<CustomANISymmetryFunctions>>& symFunc,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

return GradANISymmetryFunction::apply(*symFunc, positions, periodicBoxVectors);
}

TORCH_LIBRARY(NNPOps, m) {
Expand Down
19 changes: 14 additions & 5 deletions pytorch/SymmetryFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from torchani.aev import SpeciesAEV

torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))
torch.classes.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))


class TorchANISymmetryFunctions(torch.nn.Module):
"""Optimized TorchANI symmetry functions
Expand Down Expand Up @@ -61,12 +63,13 @@ class TorchANISymmetryFunctions(torch.nn.Module):

>>> print(energy, forces)
"""
holder: Optional[torch.classes.NNPOps.CustomANISymmetryFunctions]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call it something more descriptive like symmetryFunctions? In my example, the name holder was a ValueHolder object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still working on cleaning up the code, but to call an instance symmetryFunctions is misleading and might be confused with an actual reference to the function symFunc.


def __init__(self, symmFunc: torchani.AEVComputer):
"""
Arguments:
symmFunc: the instance of torchani.AEVComputer (https://aiqm.github.io/torchani/api.html#torchani.AEVComputer)
"""

super().__init__()

self.numSpecies = symmFunc.num_species
Expand All @@ -79,6 +82,8 @@ def __init__(self, symmFunc: torchani.AEVComputer):
self.ShfA = symmFunc.ShfA[0, 0, :, 0].tolist()
self.ShfZ = symmFunc.ShfZ[0, 0, 0, :].tolist()

self.holder = None

self.triu_index = torch.tensor([0]) # A dummy variable to make TorchScript happy ;)

def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],
Expand All @@ -100,7 +105,6 @@ def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],
species, positions = speciesAndPositions
if species.shape[0] != 1:
raise ValueError('Batched molecule computation is not supported')
species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript
if species.shape + (3,) != positions.shape:
raise ValueError('Inconsistent shapes of "species" and "positions"')
if cell is not None:
Expand All @@ -113,10 +117,15 @@ def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],
if pbc_ != [True, True, True]:
raise ValueError('Only fully periodic systems are supported, i.e. pbc = [True, True, True]')

if self.holder is None:
SymClass = torch.classes.NNPOps.CustomANISymmetryFunctions
species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript
self.holder = SymClass(self.numSpecies, self.Rcr, self.Rca, self.EtaR, self.ShfR,
self.EtaA, self.Zeta, self.ShfA, self.ShfZ,
species_, positions)

symFunc = torch.ops.NNPOps.ANISymmetryFunctions
radial, angular = symFunc(self.numSpecies, self.Rcr, self.Rca, self.EtaR, self.ShfR,
self.EtaA, self.Zeta, self.ShfA, self.ShfZ,
species_, positions[0], cell)
radial, angular = symFunc(self.holder, positions[0], cell)
features = torch.cat((radial, angular), dim=1).unsqueeze(0)

return SpeciesAEV(species, features)