A better PyTorch wrapper (#19)
* Move CustomANISymmetryFunctions construction to ANISymmetryFunctionsOp

* Move CustomANISymmetryFunctions construction to TorchANISymmetryFunctions.forward

* Move CustomANISymmetryFunctions construction to TorchANISymmetryFunctions.forward

* Create NNPOps::ANISymmetryFunctions namespace

* Simplify names

* Simplify types

* Fix typo

* Implement Holder::is_initialized

* Don't use Optional[Holder]

* Fix serialization

* Update the benchmark

* Update the build instructions

* Fix the constructor
Raimondas Galvelis authored Apr 26, 2021
1 parent e540a63 commit 3d0bab6
Showing 5 changed files with 146 additions and 124 deletions.
4 changes: 1 addition & 3 deletions pytorch/BenchmarkTorchANISymmetryFunctions.py
@@ -40,7 +40,7 @@
sum_aev.backward()
grad = positions.grad.clone()

N = 40000
N = 100000
start = time.time()
for _ in range(N):
aev = symmFunc(speciesPositions).aevs
@@ -55,7 +55,5 @@

aev_error = torch.max(torch.abs(aev - aev_ref))
grad_error = torch.max(torch.abs(grad - grad_ref))
print(aev_error)
print(grad_error)
assert aev_error < 0.0002
assert grad_error < 0.007
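
For reference, the timing pattern these hunks modify looks roughly like this (a minimal sketch, not the full script: `symmFunc` and `speciesPositions` stand for the wrapper instance and the prepared `(species, positions)` tuple set up earlier in the file, and the `torch.cuda.synchronize()` calls are an assumption for fair GPU timing):

```python
import time
import torch

# Hedged sketch of the benchmark loop; `symmFunc` and `speciesPositions`
# are assumed to be prepared as in the script above.
N = 100000
torch.cuda.synchronize()            # assumed: flush pending GPU work before timing
start = time.time()
for _ in range(N):
    aev = symmFunc(speciesPositions).aevs
torch.cuda.synchronize()
elapsed = time.time() - start
print(f'{N / elapsed:.1f} AEV evaluations/s')
```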
33 changes: 13 additions & 20 deletions pytorch/README.md
@@ -44,39 +44,32 @@ print(energy, forces)

### Build & install

- Create a *Conda* environment
- Get the source code
```bash
$ conda create -n nnpops \
-c pytorch \
-c conda-forge \
cmake \
git \
gxx_linux-64 \
make \
mdtraj \
pytest \
python=3.8 \
pytorch=1.6 \
torchani=2.2
$ conda activate nnpops
$ git clone https://github.com/openmm/NNPOps.git
```
- Get the source code

- Create a *Conda* environment
```bash
$ git clone https://github.com/peastman/NNPOps.git
$ cd NNPOps
$ conda env create -f pytorch/environment.yml
$ conda activate nnpops
```

- Configure, build, and install
```bash
$ mkdir build
$ cd build
$ cmake ../NNPOps/pytorch \
$ cmake ../pytorch \
-DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
-DCMAKE_CUDA_HOST_COMPILER=$CXX \
-DTorch_DIR=$CONDA_PREFIX/lib/python3.8/site-packages/torch/share/cmake/Torch \
-DTorch_DIR=$CONDA_PREFIX/lib/python3.9/site-packages/torch/share/cmake/Torch \
-DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
$ make install
```
- Optional: run tests
- Optional: run tests and benchmarks
```bash
$ cd ../NNPOps/pytorch
$ cd ../pytorch
$ pytest TestSymmetryFunctions.py
$ python BenchmarkTorchANISymmetryFunctions.py
```
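
After installation, a quick functional check could look like the following (a hedged sketch: the import path and the `ANI2x` usage are assumptions based on the usage example earlier in this README, and a CUDA device is assumed):

```python
import torch
import torchani

# Import path is an assumption; adjust to where the package was installed.
from SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device('cuda')
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
# Swap the stock AEV computer for the optimized implementation.
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer)

species = torch.tensor([[8, 1, 1]], device=device)            # one water molecule
positions = torch.tensor([[[ 0.00, 0.00, 0.00],
                           [ 0.96, 0.00, 0.00],
                           [-0.25, 0.93, 0.00]]], device=device)
energy = nnp((species, positions)).energies
print(energy)
```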
198 changes: 103 additions & 95 deletions pytorch/SymmetryFunctions.cpp
@@ -32,31 +32,52 @@
throw std::runtime_error(std::string("Encountered error ")+cudaGetErrorName(result)+" at "+__FILE__+":"+std::to_string(__LINE__));\
}

class CustomANISymmetryFunctions : public torch::CustomClassHolder {
namespace NNPOps {
namespace ANISymmetryFunctions {

class Holder;
using std::vector;
using HolderPtr = torch::intrusive_ptr<Holder>;
using torch::Tensor;
using torch::optional;
using Context = torch::autograd::AutogradContext;
using torch::autograd::tensor_list;

class Holder : public torch::CustomClassHolder {
public:
CustomANISymmetryFunctions(int64_t numSpecies_,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies_,
const torch::Tensor& positions) : torch::CustomClassHolder() {

// Constructor for an uninitialized object
// Note: this is needed for serialization
Holder() : torch::CustomClassHolder() {};

Holder(int64_t numSpecies_,
double Rcr,
double Rca,
const vector<double>& EtaR,
const vector<double>& ShfR,
const vector<double>& EtaA,
const vector<double>& Zeta,
const vector<double>& ShfA,
const vector<double>& ShfZ,
const vector<int64_t>& atomSpecies_,
const Tensor& positions) : torch::CustomClassHolder() {

// Construct an uninitialized object
// Note: this is needed for Python bindings
if (numSpecies_ == 0)
return;

tensorOptions = torch::TensorOptions().device(positions.device()); // Data type is float by default
int numAtoms = atomSpecies_.size();
int numSpecies = numSpecies_;
const std::vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());
const vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());

std::vector<RadialFunction> radialFunctions;
vector<RadialFunction> radialFunctions;
for (const float eta: EtaR)
for (const float rs: ShfR)
radialFunctions.push_back({eta, rs});

std::vector<AngularFunction> angularFunctions;
vector<AngularFunction> angularFunctions;
for (const float eta: EtaA)
for (const float zeta: Zeta)
for (const float rs: ShfA)
@@ -77,11 +98,11 @@ class CustomANISymmetryFunctions : public torch::CustomClassHolder {
positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);
};

torch::autograd::tensor_list forward(const torch::Tensor& positions_, const torch::optional<torch::Tensor>& periodicBoxVectors_) {
tensor_list forward(const Tensor& positions_, const optional<Tensor>& periodicBoxVectors_) {

const torch::Tensor positions = positions_.to(tensorOptions);
const Tensor positions = positions_.to(tensorOptions);

torch::Tensor periodicBoxVectors;
Tensor periodicBoxVectors;
float* periodicBoxVectorsPtr = nullptr;
if (periodicBoxVectors_) {
periodicBoxVectors = periodicBoxVectors_->to(tensorOptions);
@@ -93,99 +114,86 @@
return {radial, angular};
};

torch::Tensor backward(const torch::autograd::tensor_list& grads) {
Tensor backward(const tensor_list& grads) {

const torch::Tensor radialGrad = grads[0].clone();
const torch::Tensor angularGrad = grads[1].clone();
const Tensor radialGrad = grads[0].clone();
const Tensor angularGrad = grads[1].clone();

symFunc->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());

return positionsGrad;
}
};

bool is_initialized() {
return bool(symFunc);
};

private:
torch::TensorOptions tensorOptions;
std::shared_ptr<ANISymmetryFunctions> symFunc;
torch::Tensor radial;
torch::Tensor angular;
torch::Tensor positionsGrad;
std::shared_ptr<::ANISymmetryFunctions> symFunc;
Tensor radial;
Tensor angular;
Tensor positionsGrad;
};

class GradANISymmetryFunction : public torch::autograd::Function<GradANISymmetryFunction> {
class AutogradFunctions : public torch::autograd::Function<AutogradFunctions> {

public:
static torch::autograd::tensor_list forward(torch::autograd::AutogradContext *ctx,
int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

const auto symFunc = torch::intrusive_ptr<CustomANISymmetryFunctions>::make(
numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions);
ctx->saved_data["symFunc"] = symFunc;

return symFunc->forward(positions, periodicBoxVectors);
static tensor_list forward(Context *ctx,
const HolderPtr& holder,
const Tensor& positions,
const optional<Tensor>& periodicBoxVectors) {

ctx->saved_data["holder"] = holder;

return holder->forward(positions, periodicBoxVectors);
};

static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, const torch::autograd::tensor_list& grads) {

const auto symFunc = ctx->saved_data["symFunc"].toCustomClass<CustomANISymmetryFunctions>();
torch::Tensor positionsGrad = symFunc->backward(grads);
ctx->saved_data.erase("symFunc");

return { torch::Tensor(), // numSpecies
torch::Tensor(), // Rcr
torch::Tensor(), // Rca
torch::Tensor(), // EtaR
torch::Tensor(), // ShfR
torch::Tensor(), // EtaA
torch::Tensor(), // Zeta
torch::Tensor(), // ShfA
torch::Tensor(), // ShfZ
torch::Tensor(), // atomSpecies
positionsGrad, // positions
torch::Tensor()}; // periodicBoxVectors
static tensor_list backward(Context *ctx, const tensor_list& grads) {

const auto holder = ctx->saved_data["holder"].toCustomClass<Holder>();
Tensor positionsGrad = holder->backward(grads);
ctx->saved_data.erase("holder");

return { Tensor(), // holder
positionsGrad, // positions
Tensor() }; // periodicBoxVectors
};
};

static torch::autograd::tensor_list ANISymmetryFunctionsOp(int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

return GradANISymmetryFunction::apply(numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions, periodicBoxVectors);
tensor_list operation(const optional<HolderPtr>& holder,
const Tensor& positions,
const optional<Tensor>& periodicBoxVectors) {

return AutogradFunctions::apply(*holder, positions, periodicBoxVectors);
}

TORCH_LIBRARY(NNPOpsANISymmetryFunctions, m) {
m.class_<Holder>("Holder")
.def(torch::init<int64_t, // numSpecies
double, // Rcr
double, // Rca
const vector<double>&, // EtaR
const vector<double>&, // ShfR
const vector<double>&, // EtaA
const vector<double>&, // Zeta
const vector<double>&, // ShfA
const vector<double>&, // ShfZ
const vector<int64_t>&, // atomSpecies
const Tensor&>()) // positions
.def("forward", &Holder::forward)
.def("backward", &Holder::backward)
.def("is_initialized", &Holder::is_initialized)
.def_pickle(
// __getstate__
// Note: nothing is saved during serialization
[](const HolderPtr& self) -> int64_t { return 0; },
// __setstate__
// Note: a new uninitialized object is created during deserialization
[](int64_t state) -> HolderPtr { return HolderPtr::make(); }
);
m.def("operation", operation);
}

TORCH_LIBRARY(NNPOps, m) {
m.class_<CustomANISymmetryFunctions>("CustomANISymmetryFunctions")
.def(torch::init<int64_t, // numSpecies
double, // Rcr
double, // Rca
const std::vector<double>&, // EtaR
const std::vector<double>&, // ShfR
const std::vector<double>&, // EtaA
const std::vector<double>&, // Zeta
const std::vector<double>&, // ShfA
const std::vector<double>&, // ShfZ
const std::vector<int64_t>&, // atomSpecies
const torch::Tensor&>()) // positions
.def("forward", &CustomANISymmetryFunctions::forward)
.def("backward", &CustomANISymmetryFunctions::backward);
m.def("ANISymmetryFunctions", ANISymmetryFunctionsOp);
}
} // namespace ANISymmetryFunctions
} // namespace NNPOps
23 changes: 17 additions & 6 deletions pytorch/SymmetryFunctions.py
@@ -29,6 +29,10 @@
from torchani.aev import SpeciesAEV

torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))
torch.classes.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))

Holder = torch.classes.NNPOpsANISymmetryFunctions.Holder
operation = torch.ops.NNPOpsANISymmetryFunctions.operation

class TorchANISymmetryFunctions(torch.nn.Module):
"""Optimized TorchANI symmetry functions
@@ -66,7 +70,6 @@ def __init__(self, symmFunc: torchani.AEVComputer):
Arguments:
symmFunc: the instance of torchani.AEVComputer (https://aiqm.github.io/torchani/api.html#torchani.AEVComputer)
"""

super().__init__()

self.numSpecies = symmFunc.num_species
@@ -79,6 +82,10 @@ def __init__(self, symmFunc: torchani.AEVComputer):
self.ShfA = symmFunc.ShfA[0, 0, :, 0].tolist()
self.ShfZ = symmFunc.ShfZ[0, 0, 0, :].tolist()

# Create an uninitialized holder
self.holder = Holder(0, 0, 0, [], [], [], [], [], [], [], Tensor())
assert not self.holder.is_initialized()

self.triu_index = torch.tensor([0]) # A dummy variable to make TorchScript happy ;)

def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],
@@ -100,7 +107,6 @@ def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],
species, positions = speciesAndPositions
if species.shape[0] != 1:
raise ValueError('Batched molecule computation is not supported')
species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript
if species.shape + (3,) != positions.shape:
raise ValueError('Inconsistent shapes of "species" and "positions"')
if cell is not None:
@@ -113,10 +119,15 @@
if pbc_ != [True, True, True]:
raise ValueError('Only fully periodic systems are supported, i.e. pbc = [True, True, True]')

symFunc = torch.ops.NNPOps.ANISymmetryFunctions
radial, angular = symFunc(self.numSpecies, self.Rcr, self.Rca, self.EtaR, self.ShfR,
self.EtaA, self.Zeta, self.ShfA, self.ShfZ,
species_, positions[0], cell)
if not self.holder.is_initialized():
species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript
self.holder = Holder(self.numSpecies, self.Rcr, self.Rca,
self.EtaR, self.ShfR,
self.EtaA, self.Zeta, self.ShfA, self.ShfZ,
species_, positions)
assert self.holder.is_initialized()

radial, angular = operation(self.holder, positions[0], cell)
features = torch.cat((radial, angular), dim=1).unsqueeze(0)

return SpeciesAEV(species, features)
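
The `def_pickle` hooks in `SymmetryFunctions.cpp`, together with the lazy holder construction above, are what make the wrapper serializable; the intended round trip is roughly the following (a sketch, assuming `symmFunc` is a `TorchANISymmetryFunctions` instance and `species`/`positions` match the shapes it expects):

```python
import torch

# Hedged sketch of the serialization round trip the holder enables;
# `symmFunc`, `species`, and `positions` are assumed to exist.
scripted = torch.jit.script(symmFunc)
torch.jit.save(scripted, 'symmFunc.pt')      # __getstate__ stores nothing

restored = torch.jit.load('symmFunc.pt')     # __setstate__ yields an uninitialized holder
# The holder is reconstructed lazily on the first forward() call.
aevs = restored((species, positions)).aevs
```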
12 changes: 12 additions & 0 deletions pytorch/environment.yml
@@ -0,0 +1,12 @@
name: nnpops
channels:
- conda-forge
dependencies:
- cmake
- gxx_linux-64
- make
- mdtraj
- torchani 2.2
- pytest
- python 3.9
- pytorch 1.8.0
