Improve the wrapper of ANISymmetryFunctions #47

Merged · 16 commits · Jan 24, 2022
README.md: 2 changes (1 addition & 1 deletion)
@@ -92,7 +92,7 @@ positions = torch.tensor(molecule.xyz * 10, dtype=torch.float32, requires_grad=T
# Construct ANI-2x and replace its operations with the optimized ones
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp.species_converter = TorchANISpeciesConverter(nnp.species_converter, species).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.species_converter, nnp.aev_computer, species).to(device)
nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, species).to(device)
nnp.energy_shifter = TorchANIEnergyShifter(nnp.species_converter, nnp.energy_shifter, species).to(device)

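For context (not part of the diff): once the components have been swapped in, the model is driven like a stock TorchANI model. A minimal sketch, assuming the standard TorchANI calling convention of passing a (species, positions) tuple and reading .energies:

# Evaluate the optimized model and derive forces from the energy (sketch)
energy = nnp((species, positions)).energies
energy.backward()
forces = -positions.grad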
src/pytorch/OptimizedTorchANI.py: 2 changes (1 addition & 1 deletion)
@@ -40,7 +40,7 @@ def __init__(self, model: BuiltinModel, atomicNumbers: Tensor) -> None:

# Optimize the components of an ANI model
self.species_converter = TorchANISpeciesConverter(model.species_converter, atomicNumbers)
self.aev_computer = TorchANISymmetryFunctions(model.aev_computer)
self.aev_computer = TorchANISymmetryFunctions(model.species_converter, model.aev_computer, atomicNumbers)
self.neural_networks = TorchANIBatchedNN(model.species_converter, model.neural_networks, atomicNumbers)
self.energy_shifter = TorchANIEnergyShifter(model.species_converter, model.energy_shifter, atomicNumbers)

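Equivalently to wiring the pieces by hand, the whole replacement can go through this class; a hedged sketch, assuming OptimizedTorchANI is exported at the top level of the NNPOps package and reusing the torchani model, device and species tensor from the README example above:

from NNPOps import OptimizedTorchANI

# Swaps species_converter, aev_computer, neural_networks and energy_shifter in one step
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp = OptimizedTorchANI(nnp, species).to(device)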
src/pytorch/SymmetryFunctions.cpp: 260 changes (162 additions & 98 deletions)
@@ -21,10 +21,11 @@
* SOFTWARE.
*/

#include <stdexcept>
#include <torch/script.h>
#include <torch/serialize/archive.h>
#include "CpuANISymmetryFunctions.h"
#ifdef ENABLE_CUDA
#include <stdexcept>
#include <c10/cuda/CUDAStream.h>
#include "CudaANISymmetryFunctions.h"

@@ -38,21 +39,24 @@ namespace NNPOps {
namespace ANISymmetryFunctions {

class Holder;
using std::vector;
using HolderPtr = torch::intrusive_ptr<Holder>;
using torch::Tensor;
using torch::optional;
using Context = torch::autograd::AutogradContext;
using HolderPtr = torch::intrusive_ptr<Holder>;
using std::string;
using std::vector;
using torch::autograd::tensor_list;
#ifdef ENABLE_CUDA
using torch::cuda::CUDAStream;
using torch::cuda::getCurrentCUDAStream;
#endif
using torch::Device;
using torch::IValue;
using torch::optional;
using torch::Tensor;
using torch::TensorOptions;

class Holder : public torch::CustomClassHolder {
public:

// Constructor for an uninitialized object
// Note: this is needed for serialization
Holder() : torch::CustomClassHolder() {};

Holder(int64_t numSpecies_,
Holder(int64_t numSpecies,
double Rcr,
double Rca,
const vector<double>& EtaR,
@@ -61,105 +65,174 @@ class Holder : public torch::CustomClassHolder {
const vector<double>& Zeta,
const vector<double>& ShfA,
const vector<double>& ShfZ,
const vector<int64_t>& atomSpecies_,
const Tensor& positions) : torch::CustomClassHolder() {

// Construct an uninitialized object
// Note: this is needed for Python bindings
if (numSpecies_ == 0)
return;

tensorOptions = torch::TensorOptions().device(positions.device()); // Data type of float by default
int numAtoms = atomSpecies_.size();
int numSpecies = numSpecies_;
const vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());

vector<RadialFunction> radialFunctions;
for (const float eta: EtaR)
for (const float rs: ShfR)
radialFunctions.push_back({eta, rs});

vector<AngularFunction> angularFunctions;
for (const float eta: EtaA)
for (const float zeta: Zeta)
for (const float rs: ShfA)
for (const float thetas: ShfZ)
angularFunctions.push_back({eta, rs, zeta, thetas});

const torch::Device& device = tensorOptions.device();
if (device.is_cpu())
symFunc = std::make_shared<CpuANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);
#ifdef ENABLE_CUDA
else if (device.is_cuda()) {
// PyTorch allows choosing the GPU with "torch.device", but it doesn't set it as the default one.
CHECK_CUDA_RESULT(cudaSetDevice(device.index()));
symFunc = std::make_shared<CudaANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);
const vector<int64_t>& atomSpecies) :

torch::CustomClassHolder(),
numSpecies(numSpecies),
Rcr(Rcr), Rca(Rca),
EtaR(EtaR), ShfR(ShfR), EtaA(EtaA), Zeta(Zeta), ShfA(ShfA), ShfZ(ShfZ),
atomSpecies(atomSpecies),
device(torch::kCPU),
impl(nullptr)
{};

tensor_list forward(const Tensor& positions, const optional<Tensor>& cellOpt) {

if (positions.scalar_type() != torch::kFloat32)
throw std::runtime_error("The type of \"positions\" has to be float32");
if (positions.dim() != 2)
throw std::runtime_error("The shape of \"positions\" has to have 2 dimensions");
if (positions.size(0) != atomSpecies.size())
throw std::runtime_error("The size of the 1nd dimension of \"positions\" has to be " + std::to_string(atomSpecies.size()));
if (positions.size(1) != 3)
throw std::runtime_error("The size of the 2nd dimension of \"positions\" has to be 3");

Tensor cell;
float* cellPtr = nullptr;
if (cellOpt) {
cell = *cellOpt;

if (cell.scalar_type() != torch::kFloat32)
throw std::runtime_error("The type of \"cell\" has to be float32");
if (cell.dim() != 2)
throw std::runtime_error("The shape of \"cell\" has to have 2 dimensions");
if (cell.size(0) != 3)
throw std::runtime_error("The size of the 1nd dimension of \"cell\" has to be 3");
if (cell.size(1) != 3)
throw std::runtime_error("The size of the 2nd dimension of \"cell\" has to be 3");
if (cell.device() != positions.device())
throw std::runtime_error("\"cell\" has to be on the same device as \"positions\"");

cellPtr = cell.data_ptr<float>();
}
#endif
else
throw std::runtime_error("Unsupported device: " + device.str());

radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, tensorOptions);
angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, tensorOptions);
positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);
if (!impl) {
device = positions.device();

int numAtoms = atomSpecies.size();
const vector<int> atomSpecies_(atomSpecies.begin(), atomSpecies.end()); // vector<int64_t> --> vector<int>

vector<RadialFunction> radialFunctions;
for (const float eta: EtaR)
for (const float rs: ShfR)
radialFunctions.push_back({eta, rs});

vector<AngularFunction> angularFunctions;
for (const float eta: EtaA)
for (const float zeta: Zeta)
for (const float rs: ShfA)
for (const float thetas: ShfZ)
angularFunctions.push_back({eta, rs, zeta, thetas});

if (device.is_cpu()) {
impl = std::make_shared<CpuANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies_, radialFunctions, angularFunctions, true);
#ifdef ENABLE_CUDA
cudaSymFunc = dynamic_cast<CudaANISymmetryFunctions*>(symFunc.get());
} else if (device.is_cuda()) {
// PyTorch allows choosing the GPU with "torch.device", but it doesn't set it as the default one.
CHECK_CUDA_RESULT(cudaSetDevice(device.index()));
impl = std::make_shared<CudaANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies_, radialFunctions, angularFunctions, true);
#endif
};

tensor_list forward(const Tensor& positions_, const optional<Tensor>& periodicBoxVectors_) {
} else
throw std::runtime_error("Unsupported device: " + device.str());

const Tensor positions = positions_.to(tensorOptions);
const TensorOptions tensorOptions = TensorOptions().device(device); // Data type of float by default
radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, tensorOptions);
angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, tensorOptions);
positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);

Tensor periodicBoxVectors;
float* periodicBoxVectorsPtr = nullptr;
if (periodicBoxVectors_) {
periodicBoxVectors = periodicBoxVectors_->to(tensorOptions);
float* periodicBoxVectorsPtr = periodicBoxVectors.data_ptr<float>();
#ifdef ENABLE_CUDA
cudaImpl = dynamic_cast<CudaANISymmetryFunctions*>(impl.get());
#endif
}

if (positions.device() != device)
throw std::runtime_error("The device of \"positions\" has changed");

#ifdef ENABLE_CUDA
if (cudaSymFunc) {
const torch::cuda::CUDAStream stream = torch::cuda::getCurrentCUDAStream(tensorOptions.device().index());
cudaSymFunc->setStream(stream.stream());
if (cudaImpl) {
const CUDAStream stream = getCurrentCUDAStream(device.index());
cudaImpl->setStream(stream.stream());
}
#endif

symFunc->computeSymmetryFunctions(positions.data_ptr<float>(), periodicBoxVectorsPtr, radial.data_ptr<float>(), angular.data_ptr<float>());
impl->computeSymmetryFunctions(positions.data_ptr<float>(), cellPtr, radial.data_ptr<float>(), angular.data_ptr<float>());

return {radial, angular};
};

Tensor backward(const tensor_list& grads) {
tensor_list backward(const tensor_list& grads) {

const Tensor radialGrad = grads[0].clone();
const Tensor angularGrad = grads[1].clone();

#ifdef ENABLE_CUDA
if (cudaSymFunc) {
const torch::cuda::CUDAStream stream = torch::cuda::getCurrentCUDAStream(tensorOptions.device().index());
cudaSymFunc->setStream(stream.stream());
if (cudaImpl) {
const CUDAStream stream = getCurrentCUDAStream(device.index());
cudaImpl->setStream(stream.stream());
}
#endif

symFunc->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());
impl->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());

return positionsGrad;
return { Tensor(), positionsGrad, Tensor() }; // empty grad for the holder and periodicBoxVectors
};

bool is_initialized() {
return bool(symFunc);
static const string serialize(const HolderPtr& self) {

torch::serialize::OutputArchive archive;
archive.write("numSpecies", self->numSpecies);
archive.write("Rcr", self->Rcr);
archive.write("Rca", self->Rca);
archive.write("EtaR", self->EtaR);
archive.write("ShfR", self->ShfR);
archive.write("EtaA", self->EtaA);
archive.write("Zeta", self->Zeta);
archive.write("ShfA", self->ShfA);
archive.write("ShfZ", self->ShfZ);
archive.write("atomSpecies", self->atomSpecies);

std::stringstream stream;
archive.save_to(stream);
return stream.str();
};

static HolderPtr deserialize(const string& state) {

std::stringstream stream(state);
torch::serialize::InputArchive archive;
archive.load_from(stream, torch::kCPU);

IValue numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies;
archive.read("numSpecies", numSpecies);
archive.read("Rcr", Rcr);
archive.read("Rca", Rca);
archive.read("EtaR", EtaR);
archive.read("ShfR", ShfR);
archive.read("EtaA", EtaA);
archive.read("Zeta", Zeta);
archive.read("ShfA", ShfA);
archive.read("ShfZ", ShfZ);
archive.read("atomSpecies", atomSpecies);

return HolderPtr::make(numSpecies.toInt(), Rcr.toDouble(), Rca.toDouble(),
EtaR.toDoubleVector(), ShfR.toDoubleVector(), EtaA.toDoubleVector(),
Zeta.toDoubleVector(), ShfA.toDoubleVector(), ShfZ.toDoubleVector(),
atomSpecies.toIntVector());
}

private:
torch::TensorOptions tensorOptions;
std::shared_ptr<::ANISymmetryFunctions> symFunc;
int numSpecies;
double Rcr, Rca;
vector<double> EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ;
vector<int64_t> atomSpecies;
TensorOptions tensorOptions;
Device device;
std::shared_ptr<::ANISymmetryFunctions> impl;
Tensor radial;
Tensor angular;
Tensor positionsGrad;
#ifdef ENABLE_CUDA
CudaANISymmetryFunctions* cudaSymFunc;
CudaANISymmetryFunctions* cudaImpl;
#endif
};

@@ -179,12 +252,9 @@ class AutogradFunctions : public torch::autograd::Function<AutogradFunctions> {
static tensor_list backward(Context *ctx, const tensor_list& grads) {

const auto holder = ctx->saved_data["holder"].toCustomClass<Holder>();
Tensor positionsGrad = holder->backward(grads);
ctx->saved_data.erase("holder");

return { Tensor(), // holder
positionsGrad, // positions
Tensor() }; // periodicBoxVectors
return holder->backward(grads);
};
};

@@ -197,27 +267,21 @@ tensor_list operation(const optional<HolderPtr>& holder,

TORCH_LIBRARY(NNPOpsANISymmetryFunctions, m) {
m.class_<Holder>("Holder")
.def(torch::init<int64_t, // numSpecies
double, // Rcr
double, // Rca
const vector<double>&, // EtaR
const vector<double>&, // ShfR
const vector<double>&, // EtaA
const vector<double>&, // Zeta
const vector<double>&, // ShfA
const vector<double>&, // ShfZ
const vector<int64_t>&, // atomSpecies
const Tensor&>()) // positions
.def(torch::init<int64_t, // numSpecies
double, // Rcr
double, // Rca
const vector<double>&, // EtaR
const vector<double>&, // ShfR
const vector<double>&, // EtaA
const vector<double>&, // Zeta
const vector<double>&, // ShfA
const vector<double>&, // ShfZ
const vector<int64_t>&>()) // atomSpecies
.def("forward", &Holder::forward)
.def("backward", &Holder::backward)
.def("is_initialized", &Holder::is_initialized)
.def_pickle(
// __getstate__
// Note: nothing is saved during serialization
[](const HolderPtr& self) -> int64_t { return 0; },
// __setstate__
// Note: a new uninitialized object is created during deserialization
[](int64_t state) -> HolderPtr { return HolderPtr::make(); }
[](const HolderPtr& self) -> const string { return Holder::serialize(self); }, // __getstate__
[](const string& state) -> HolderPtr { return Holder::deserialize(state); } // __setstate__
);
m.def("operation", operation);
}
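The practical payoff of the new __getstate__/__setstate__ pair is that a module holding this custom class can be scripted, saved and reloaded with its parameters intact, whereas the old pickle hooks discarded everything and restored an uninitialized holder. A minimal round-trip sketch in Python, assuming the wrapper is importable as NNPOps.SymmetryFunctions.TorchANISymmetryFunctions and that ANI-2x with the replaced AEV computer remains TorchScript-compatible:

import torch
import torchani
from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions  # import path is an assumption

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
species = torch.tensor([[8, 1, 1]], device=device)  # hypothetical molecule (water), atomic numbers

# Build ANI-2x and replace its AEV computer with the optimized wrapper
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.species_converter, nnp.aev_computer, species).to(device)

scripted = torch.jit.script(nnp)
torch.jit.save(scripted, 'model.pt')   # __getstate__ serializes the holder's hyperparameters into an archive string
restored = torch.jit.load('model.pt')  # __setstate__ reconstructs a configured Holder from that string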