Skip to content

Commit

Permalink
CFConv layer for SchNet (#18)
Browse files Browse the repository at this point in the history
* Began implementing CFConv

* Continuing CPU implementation of CFConv

* Gradients for CPU implementation of CFConv

* Added tests for periodic boundary conditions

* Added more documentation

* Beginning of CUDA implementation

* CUDA derivatives of CFConv

* Simplified API to allow more optimizations

* Restructured neighbor list to improve performance

* Further optimizations

* More optimizations

* Added README
  • Loading branch information
peastman authored Nov 24, 2020
1 parent f9716fd commit 668cf83
Show file tree
Hide file tree
Showing 11 changed files with 1,788 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ani/ANISymmetryFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class ANISymmetryFunctions {
numAtoms(numAtoms), numSpecies(numSpecies), radialCutoff(radialCutoff), angularCutoff(angularCutoff),
periodic(periodic), atomSpecies(atomSpecies), radialFunctions(radialFunctions), angularFunctions(angularFunctions), torchani(torchani) {
}
virtual ~ANISymmetryFunctions() {
}
/**
* Compute the symmetry functions.
*
Expand Down
142 changes: 142 additions & 0 deletions schnet/BenchmarkCudaCFConv.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#include "CudaCFConv.h"
#include <chrono>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <fstream>
#include <iostream>
#include <map>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

using namespace std;

/**
 * Convert unit cell lengths and angles into three box vectors and append them to a vector.
 *
 * The first vector lies along x, the second lies in the xy plane, and the third is general.
 * Whole multiples of the earlier vectors are then subtracted from the later ones (using
 * rounded component ratios) before the result is stored.
 *
 * @param a                   length of the first cell edge
 * @param b                   length of the second cell edge
 * @param c                   length of the third cell edge
 * @param alpha               cell angle alpha, in radians
 * @param beta                cell angle beta, in radians
 * @param gamma               cell angle gamma, in radians
 * @param periodicBoxVectors  nine values are appended to this vector, in the order
 *                            (a, 0, 0), (bx, by, 0), (cx, cy, cz)
 */
void computeBoxVectors(float a, float b, float c, float alpha, float beta, float gamma, std::vector<float>& periodicBoxVectors) {
    // Raw components of the second and third vectors.
    float bx = b*std::cos(gamma);
    const float by = b*std::sin(gamma);
    float cx = c*std::cos(beta);
    float cy = c*(std::cos(alpha)-std::cos(beta)*std::cos(gamma))/std::sin(gamma);
    const float cz = std::sqrt(c*c-cx*cx-cy*cy);

    // Remove whole multiples of the second vector from the third one.
    const float multiplesOfB = std::round(cy/by);
    cx -= bx*multiplesOfB;
    cy -= by*multiplesOfB;

    // Remove whole multiples of the first vector from the third and then the second one.
    const float multiplesOfAInC = std::round(cx/a);
    cx -= a*multiplesOfAInC;
    const float multiplesOfAInB = std::round(bx/a);
    bx -= a*multiplesOfAInB;

    const float components[9] = {a, 0, 0, bx, by, 0, cx, cy, cz};
    for (int i = 0; i < 9; i++)
        periodicBoxVectors.push_back(components[i]);
}

void loadPdb(string filename, vector<float>& positions, vector<float>& periodicBoxVectors) {
ifstream file(filename);
if (!file.is_open())
throw runtime_error("Failed to open PDB file");
string line;
while (getline(file, line)) {
if (line.rfind("ATOM", 0) == 0 || line.rfind("HETATM", 0) == 0) {
positions.push_back(stof(line.substr(30, 8)));
positions.push_back(stof(line.substr(38, 8)));
positions.push_back(stof(line.substr(46, 8)));
}
if (line.rfind("CRYST1", 0) == 0) {
float a = stof(line.substr(6, 9));
float b = stof(line.substr(15, 9));
float c = stof(line.substr(24, 9));
float alpha = stof(line.substr(33, 7))*M_PI/180;
float beta = stof(line.substr(40, 7))*M_PI/180;
float gamma = stof(line.substr(47, 7))*M_PI/180;
computeBoxVectors(a, b, c, alpha, beta, gamma, periodicBoxVectors);
}
}
file.close();
}

void runBenchmark(int iterations, vector<float>& positions, vector<float>& periodicBoxVectors) {
int numAtoms = positions.size()/3;
int width = 128;
int numGaussians = 50;
float cutoff = 10;

// Generate random weights and biases. We don't care about the values, since they
// don't affect speed.

vector<float> w1, b1, w2, b2;
std::default_random_engine generator(0);
std::normal_distribution<double> distribution(0, 1);
for (int i = 0; i < width; i++) {
b1.push_back(distribution(generator));
b2.push_back(distribution(generator));
for (int j = 0; j < numGaussians; j++)
w1.push_back(distribution(generator));
for (int j = 0; j < width; j++)
w2.push_back(distribution(generator));
}

// Allocate all the memory we will need.

CudaCFConvNeighbors neighbors(numAtoms, cutoff, periodicBoxVectors.size() > 0);
CudaCFConv cfconv(numAtoms, width, numGaussians, cutoff, periodicBoxVectors.size() > 0, 0.2, w1.data(), b1.data(), w2.data(), b2.data());
float *positionsData, *vectorsData, *input, *output, *inputDerivs, *outputDerivs, *positionDerivs;
cudaMalloc(&positionsData, positions.size()*sizeof(float));
cudaMalloc(&vectorsData, 9*sizeof(float));
cudaMallocManaged(&input, numAtoms*width*sizeof(float));
cudaMalloc(&output, numAtoms*width*sizeof(float));
cudaMalloc(&inputDerivs, numAtoms*width*sizeof(float));
cudaMalloc(&outputDerivs, numAtoms*width*sizeof(float));
cudaMalloc(&positionDerivs, positions.size()*sizeof(float));
cudaMemcpy(positionsData, positions.data(), positions.size()*sizeof(float), cudaMemcpyDefault);
cudaMemcpy(vectorsData, periodicBoxVectors.data(), periodicBoxVectors.size()*sizeof(float), cudaMemcpyDefault);
for (int i = 0; i < width; i++)
for (int j = 0; j < numAtoms; j++)
input[j*width+i] = distribution(generator);

// Run the benchmark.

clock_t start = clock();
for (int i = 0; i < iterations; i++) {
neighbors.build(positionsData, vectorsData);
for (int j = 0; j < 6; j++) {
cfconv.compute(neighbors, positionsData, vectorsData, input, output);
cfconv.backprop(neighbors, positionsData, vectorsData, input, outputDerivs, inputDerivs, positionDerivs);
}
}
cudaDeviceSynchronize();
clock_t finish = clock();
double duration = (double) (finish-start)/CLOCKS_PER_SEC;
printf(" %f sec\n", duration);
printf(" %f ms/iteration\n", duration/iterations*1000);

// Release device memory.

cudaFree(positionsData);
cudaFree(vectorsData);
cudaFree(input);
cudaFree(output);
cudaFree(inputDerivs);
cudaFree(outputDerivs);
cudaFree(positionDerivs);
}

/**
 * Entry point of the benchmark.
 *
 * Expects two command line arguments: the path of a PDB file describing the system, and the
 * number of iterations to time.  Any error is reported on standard error, so that it cannot
 * be mistaken for benchmark output, and produces exit code 1.
 */
int main(int argc, char* argv[]) {
    try {
        if (argc != 3)
            throw std::runtime_error("Expected two command line arguments");

        // Validate the iteration count before doing any expensive work.  stoi() reports bad
        // input with exceptions whose message is just "stoi", and it silently ignores
        // trailing characters, so both cases are turned into a readable error here.
        int iterations = 0;
        const std::string iterationsText = argv[2];
        try {
            size_t charactersParsed = 0;
            iterations = std::stoi(iterationsText, &charactersParsed);
            if (charactersParsed != iterationsText.size())
                throw std::invalid_argument("trailing characters");
        }
        catch (const std::logic_error&) {
            throw std::runtime_error("The number of iterations must be an integer");
        }

        std::vector<float> positions, periodicBoxVectors;
        loadPdb(argv[1], positions, periodicBoxVectors);
        runBenchmark(iterations, positions, periodicBoxVectors);
    }
    catch (const std::exception& e) {
        std::cerr << e.what() << std::endl;
        return 1;
    }
    return 0;
}
197 changes: 197 additions & 0 deletions schnet/CFConv.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#ifndef CFCONV
#define CFCONV

/**
* Copyright (c) 2020 Stanford University and the Authors
* Authors: Peter Eastman
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

/**
 * A neighbor list used when computing CFConv layers.
 *
 * One object can serve every layer of a model.  Create it together with the model, call
 * build() whenever the atom positions change, and hand it to the CFConv methods whose
 * results depend on positions.
 *
 * This class is abstract.  Concrete subclasses implement it for particular types of
 * hardware.
 */
class CFConvNeighbors {
public:
    /**
     * Create an object for computing neighbor lists.
     *
     * @param numAtoms  the number of atoms in the system being modeled
     * @param cutoff    the cutoff distance
     * @param periodic  whether to use periodic boundary conditions
     */
    CFConvNeighbors(int numAtoms, float cutoff, bool periodic) :
            atomCount(numAtoms), cutoffDistance(cutoff), usePeriodic(periodic) {
    }

    virtual ~CFConvNeighbors() {
    }

    /**
     * Rebuild the neighbor list for a new set of positions.
     *
     * @param positions           an array of shape [numAtoms][3] holding the position of each atom
     * @param periodicBoxVectors  an array of shape [3][3] holding the periodic box vectors.  It is
     *                            ignored, and may be NULL, when periodic boundary conditions are
     *                            not used.
     */
    virtual void build(const float* positions, const float* periodicBoxVectors) = 0;

    /**
     * Report whether the box vectors passed to the most recent build() call described a
     * triclinic (not rectangular) box.
     */
    virtual bool getTriclinic() const = 0;

    /** Get the number of atoms in the system. */
    int getNumAtoms() const { return atomCount; }

    /** Get the cutoff distance. */
    float getCutoff() const { return cutoffDistance; }

    /** Get whether periodic boundary conditions are applied. */
    bool getPeriodic() const { return usePeriodic; }

private:
    int atomCount;
    float cutoffDistance;
    bool usePeriodic;
};

/**
 * This class computes the continuous filter convolution (cfconv) function used in SchNet.
 * Create an instance of this class at the same time you create the model and then reuse it
 * for every calculation on that model.
 *
 * For each pair of atoms, it performs the following calculations:
 *
 * 1. Compute a set of Gaussian basis functions describing the distance between them.
 * 2. Pass them through a dense layer.
 * 3. Apply a shifted softplus activation function.
 * 4. Pass the result through a second dense layer.
 * 5. Apply a cosine cutoff function to make interactions smoothly go to zero at the cutoff.
 *
 * For each atom, the output is the sum over all neighbors of the above calculation multiplied
 * by the neighbor's input vector.
 *
 * This calculation is designed to match the behavior of SchNetPack.  It is similar but not
 * identical to that described in the original SchNet publication.
 *
 * This is an abstract class.  Subclasses provide implementations on particular
 * types of hardware.
 */
class CFConv {
public:
    /**
     * Construct an object for computing continuous filter convolution (cfconv) functions.
     *
     * @param numAtoms       the number of atoms in the system
     * @param width          the number of elements in the input and output vectors
     * @param numGaussians   the number of Gaussian basis functions to use, uniformly spaced between 0 and cutoff
     * @param cutoff         the cutoff distance
     * @param periodic       whether to apply periodic boundary conditions
     * @param gaussianWidth  the width of the Gaussian basis functions
     */
    CFConv(int numAtoms, int width, int numGaussians, float cutoff, bool periodic, float gaussianWidth) :
            // Members are always initialized in the order they are declared, so the list is
            // written in that order (gaussianWidth before periodic).
            numAtoms(numAtoms), width(width), numGaussians(numGaussians), cutoff(cutoff), gaussianWidth(gaussianWidth), periodic(periodic) {
    }
    virtual ~CFConv() {
    }
    /**
     * Compute the output of the layer.
     *
     * @param neighbors           a neighbor list for accelerating the calculation.  You must have already called
     *                            build() on the neighbor list with the same positions and box vectors.
     * @param positions           an array of shape [numAtoms][3] containing the positions of each atom
     * @param periodicBoxVectors  an array of shape [3][3] containing the periodic box vectors.  If periodic boundary
     *                            conditions are not used, this is ignored and may be NULL.
     * @param input               an array of shape [numAtoms][width] containing the input vectors
     * @param output              an array of shape [numAtoms][width] to store the output vectors into
     */
    virtual void compute(const CFConvNeighbors& neighbors, const float* positions, const float* periodicBoxVectors,
                         const float* input, float* output) = 0;
    /**
     * Given the derivatives of some function E (typically energy) with respect to the outputs, backpropagate them
     * to find the derivatives of E with respect to the inputs and atom positions.
     *
     * @param neighbors           a neighbor list for accelerating the calculation.  You must have already called
     *                            build() on the neighbor list with the same positions and box vectors.
     * @param positions           an array of shape [numAtoms][3] containing the positions of each atom
     * @param periodicBoxVectors  an array of shape [3][3] containing the periodic box vectors.  If periodic boundary
     *                            conditions are not used, this is ignored and may be NULL.
     * @param input               an array of shape [numAtoms][width] containing the input vectors
     * @param outputDeriv         an array of shape [numAtoms][width] containing the derivative of E with respect to each output value
     * @param inputDeriv          an array of shape [numAtoms][width] to store the derivative of E with respect to each input value into
     * @param positionDeriv       an array of shape [numAtoms][3] to store the derivative of E with respect to the atom positions into
     */
    virtual void backprop(const CFConvNeighbors& neighbors, const float* positions, const float* periodicBoxVectors,
                          const float* input, const float* outputDeriv, float* inputDeriv, float* positionDeriv) = 0;
    /**
     * Get the number of atoms in the system.
     */
    int getNumAtoms() const {
        return numAtoms;
    }
    /**
     * Get the number of elements in the input and output vectors.
     */
    int getWidth() const {
        return width;
    }
    /**
     * Get the number of Gaussian basis functions.
     */
    int getNumGaussians() const {
        return numGaussians;
    }
    /**
     * Get the cutoff distance.
     */
    float getCutoff() const {
        return cutoff;
    }
    /**
     * Get whether to apply periodic boundary conditions.
     */
    bool getPeriodic() const {
        return periodic;
    }
    /**
     * Get the width of the Gaussian basis functions.
     */
    float getGaussianWidth() const {
        return gaussianWidth;
    }
protected:
    // Fixed at construction.  Subclasses may read these directly.
    const int numAtoms, width, numGaussians;
    const float cutoff, gaussianWidth;
    const bool periodic;
};

#endif
Loading

0 comments on commit 668cf83

Please sign in to comment.