-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CFConv layer for SchNet #18
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
8e2cee0
Began implementing CFConv
peastman a53b834
Continuing CPU implementation of CFConv
peastman 8d30131
Gradients for CPU implementation of CFConv
peastman d2faee4
Added tests for periodic boundary conditions
peastman 269a8b4
Added more documentation
peastman a6b6c7a
Beginning of CUDA implementation
peastman 8e6fbdb
CUDA derivatives of CFConv
peastman d65ac36
Simplified API to allow more optimizations
peastman 8ad4d56
Restructured neighbor list to improve performance
peastman ebcc354
Further optimizations
peastman 9b43134
More optimizations
peastman 6045e2a
Added README
peastman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
#include "CudaCFConv.h" | ||
#include <cmath> | ||
#include <cstdio> | ||
#include <ctime> | ||
#include <fstream> | ||
#include <iostream> | ||
#include <map> | ||
#include <random> | ||
#include <stdexcept> | ||
#include <string> | ||
#include <vector> | ||
|
||
using namespace std; | ||
|
||
void computeBoxVectors(float a, float b, float c, float alpha, float beta, float gamma, vector<float>& periodicBoxVectors) { | ||
float bx = b*cos(gamma); | ||
float by = b*sin(gamma); | ||
float cx = c*cos(beta); | ||
float cy = c*(cos(alpha)-cos(beta)*cos(gamma))/sin(gamma); | ||
float cz = sqrt(c*c-cx*cx-cy*cy); | ||
float scale1 = std::round(cy/by); | ||
cx -= bx*scale1; | ||
cy -= by*scale1; | ||
float scale2 = std::round(cx/a); | ||
cx -= a*scale2; | ||
float scale3 = std::round(bx/a); | ||
bx -= a*scale3; | ||
periodicBoxVectors.push_back(a); | ||
periodicBoxVectors.push_back(0); | ||
periodicBoxVectors.push_back(0); | ||
periodicBoxVectors.push_back(bx); | ||
periodicBoxVectors.push_back(by); | ||
periodicBoxVectors.push_back(0); | ||
periodicBoxVectors.push_back(cx); | ||
periodicBoxVectors.push_back(cy); | ||
periodicBoxVectors.push_back(cz); | ||
} | ||
|
||
void loadPdb(string filename, vector<float>& positions, vector<float>& periodicBoxVectors) { | ||
ifstream file(filename); | ||
if (!file.is_open()) | ||
throw runtime_error("Failed to open PDB file"); | ||
string line; | ||
while (getline(file, line)) { | ||
if (line.rfind("ATOM", 0) == 0 || line.rfind("HETATM", 0) == 0) { | ||
positions.push_back(stof(line.substr(30, 8))); | ||
positions.push_back(stof(line.substr(38, 8))); | ||
positions.push_back(stof(line.substr(46, 8))); | ||
} | ||
if (line.rfind("CRYST1", 0) == 0) { | ||
float a = stof(line.substr(6, 9)); | ||
float b = stof(line.substr(15, 9)); | ||
float c = stof(line.substr(24, 9)); | ||
float alpha = stof(line.substr(33, 7))*M_PI/180; | ||
float beta = stof(line.substr(40, 7))*M_PI/180; | ||
float gamma = stof(line.substr(47, 7))*M_PI/180; | ||
computeBoxVectors(a, b, c, alpha, beta, gamma, periodicBoxVectors); | ||
} | ||
} | ||
file.close(); | ||
} | ||
|
||
void runBenchmark(int iterations, vector<float>& positions, vector<float>& periodicBoxVectors) { | ||
int numAtoms = positions.size()/3; | ||
int width = 128; | ||
int numGaussians = 50; | ||
float cutoff = 10; | ||
|
||
// Generate random weights and biases. We don't care about the values, since they | ||
// don't affect speed. | ||
|
||
vector<float> w1, b1, w2, b2; | ||
std::default_random_engine generator(0); | ||
std::normal_distribution<double> distribution(0, 1); | ||
for (int i = 0; i < width; i++) { | ||
b1.push_back(distribution(generator)); | ||
b2.push_back(distribution(generator)); | ||
for (int j = 0; j < numGaussians; j++) | ||
w1.push_back(distribution(generator)); | ||
for (int j = 0; j < width; j++) | ||
w2.push_back(distribution(generator)); | ||
} | ||
|
||
// Allocate all the memory we will need. | ||
|
||
CudaCFConvNeighbors neighbors(numAtoms, cutoff, periodicBoxVectors.size() > 0); | ||
CudaCFConv cfconv(numAtoms, width, numGaussians, cutoff, periodicBoxVectors.size() > 0, 0.2, w1.data(), b1.data(), w2.data(), b2.data()); | ||
float *positionsData, *vectorsData, *input, *output, *inputDerivs, *outputDerivs, *positionDerivs; | ||
cudaMalloc(&positionsData, positions.size()*sizeof(float)); | ||
cudaMalloc(&vectorsData, 9*sizeof(float)); | ||
cudaMallocManaged(&input, numAtoms*width*sizeof(float)); | ||
cudaMalloc(&output, numAtoms*width*sizeof(float)); | ||
cudaMalloc(&inputDerivs, numAtoms*width*sizeof(float)); | ||
cudaMalloc(&outputDerivs, numAtoms*width*sizeof(float)); | ||
cudaMalloc(&positionDerivs, positions.size()*sizeof(float)); | ||
cudaMemcpy(positionsData, positions.data(), positions.size()*sizeof(float), cudaMemcpyDefault); | ||
cudaMemcpy(vectorsData, periodicBoxVectors.data(), periodicBoxVectors.size()*sizeof(float), cudaMemcpyDefault); | ||
for (int i = 0; i < width; i++) | ||
for (int j = 0; j < numAtoms; j++) | ||
input[j*width+i] = distribution(generator); | ||
|
||
// Run the benchmark. | ||
|
||
clock_t start = clock(); | ||
for (int i = 0; i < iterations; i++) { | ||
neighbors.build(positionsData, vectorsData); | ||
for (int j = 0; j < 6; j++) { | ||
cfconv.compute(neighbors, positionsData, vectorsData, input, output); | ||
cfconv.backprop(neighbors, positionsData, vectorsData, input, outputDerivs, inputDerivs, positionDerivs); | ||
} | ||
} | ||
cudaDeviceSynchronize(); | ||
clock_t finish = clock(); | ||
double duration = (double) (finish-start)/CLOCKS_PER_SEC; | ||
printf(" %f sec\n", duration); | ||
printf(" %f ms/iteration\n", duration/iterations*1000); | ||
|
||
// Release device memory. | ||
|
||
cudaFree(positionsData); | ||
cudaFree(vectorsData); | ||
cudaFree(input); | ||
cudaFree(output); | ||
cudaFree(inputDerivs); | ||
cudaFree(outputDerivs); | ||
cudaFree(positionDerivs); | ||
} | ||
|
||
int main(int argc, char* argv[]) { | ||
try { | ||
if (argc != 3) | ||
throw runtime_error("Expected two command line arguments"); | ||
vector<float> positions, periodicBoxVectors; | ||
loadPdb(argv[1], positions, periodicBoxVectors); | ||
runBenchmark(stoi(argv[2]), positions, periodicBoxVectors); | ||
} | ||
catch (const exception& e) { | ||
cout << e.what() << endl; | ||
return 1; | ||
} | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
#ifndef CFCONV | ||
#define CFCONV | ||
|
||
/** | ||
* Copyright (c) 2020 Stanford University and the Authors | ||
* Authors: Peter Eastman | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in all | ||
* copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
* SOFTWARE. | ||
*/ | ||
|
||
/** | ||
* This class represents a neighbor list for use in computing CFConv layers. | ||
* A single object can be used for all layers in the models. Create it at the | ||
* same time you create the model, call build() every time the atom positions | ||
* change, and pass it to the methods of CFConv objects that do computations | ||
* based on positions. | ||
* | ||
* This is an abstract class. Subclasses provide implementations on particular | ||
* types of hardware. | ||
*/ | ||
class CFConvNeighbors { | ||
public: | ||
/** | ||
* Create an object for computing neighbor lists. | ||
* | ||
* @param numAtoms the number of atoms in the system being modeled | ||
* @param cutoff the cutoff distance | ||
* @param periodic whether to use periodic boundary conditions. | ||
*/ | ||
CFConvNeighbors(int numAtoms, float cutoff, bool periodic) : numAtoms(numAtoms), cutoff(cutoff), periodic(periodic) { | ||
} | ||
virtual ~CFConvNeighbors() { | ||
} | ||
/** | ||
* Rebuild the neighbor list based on a new set of positions. | ||
* | ||
* @param positions an array of shape [numAtoms][3] containing the position of each atom | ||
* @param periodicBoxVectors an array of shape [3][3] containing the periodic box vectors. If periodic boundary conditions are | ||
* not used, this is ignored and may be NULL. | ||
*/ | ||
virtual void build(const float* positions, const float* periodicBoxVectors) = 0; | ||
/** | ||
* Get whether the periodic box vectors specified in the most recent call to build() described | ||
* a triclinic (not rectangular) box. | ||
*/ | ||
virtual bool getTriclinic() const = 0; | ||
/** | ||
* Get the number of atoms in the system. | ||
*/ | ||
int getNumAtoms() const { | ||
return numAtoms; | ||
} | ||
/** | ||
* Get the cutoff distance. | ||
*/ | ||
float getCutoff() const { | ||
return cutoff; | ||
} | ||
/** | ||
* Get whether to apply periodic boundary conditions. | ||
*/ | ||
bool getPeriodic() const { | ||
return periodic; | ||
} | ||
private: | ||
int numAtoms; | ||
float cutoff; | ||
bool periodic; | ||
}; | ||
|
||
/** | ||
* This class computes the continuous filter convolution (cfconv) function used in SchNet. | ||
* Create an instance of this class at the same time you create the model and then reuse it | ||
* for every calculation on that model. | ||
* | ||
* For each pair of atoms, it performs the following calculations: | ||
* | ||
* 1. Compute a set of Gaussian basis functions describing the distance between them. | ||
* 2. Pass them through a dense layer. | ||
* 3. Apply a shifted softplus activation function. | ||
* 4. Pass the result through a second dense layer. | ||
* 5. Apply a cosine cutoff function to make interactions smoothly go to zero at the cutoff. | ||
* | ||
* For each atom, the output is the sum over all neighbors of the above calculation multiplied | ||
* by the neighbor's input vector. | ||
* | ||
* This calculation is designed to match the behavior of SchNetPack. It is similar but not | ||
* identical to that described in the original SchNet publication. | ||
* | ||
* This is an abstract class. Subclasses provide implementations on particular | ||
* types of hardware. | ||
*/ | ||
class CFConv { | ||
public: | ||
/** | ||
* Construct on object for computing continuous filter convolution (cfconv) functions. | ||
* | ||
* @param numAtoms the number of atoms in the system | ||
* @param width the number of elements in the input and output vectors | ||
* @param numGaussians the number of Gaussian basis functions to use, uniformly spaced between 0 and cutoff | ||
* @param cutoff the cutoff distance | ||
* @param periodic whether to apply periodic boundary conditions | ||
* @param gaussianWidth the width of the Gaussian basis functions | ||
*/ | ||
CFConv(int numAtoms, int width, int numGaussians, float cutoff, bool periodic, float gaussianWidth) : | ||
numAtoms(numAtoms), width(width), numGaussians(numGaussians), cutoff(cutoff), periodic(periodic), gaussianWidth(gaussianWidth) { | ||
} | ||
virtual ~CFConv() { | ||
} | ||
/** | ||
* Compute the output of the layer. | ||
* | ||
* @param neighbors a neighbor list for accelerating the calculation. You must have already called | ||
* build() on the neighbor list with the same positions and box vectors. | ||
* @param positions an array of shape [numAtoms][3] containing the positions of each atom | ||
* @param periodicBoxVectors an array of shape [3][3] containing the periodic box vectors. If periodic boundary conditions are | ||
* not used, this is ignored and may be NULL. | ||
* @param input an array of shape [numAtoms][width] containing the input vectors | ||
* @param output an array of shape [numAtoms][width] to store the output vectors into | ||
*/ | ||
virtual void compute(const CFConvNeighbors& neighbors, const float* positions, const float* periodicBoxVectors, | ||
const float* input, float* output) = 0; | ||
/** | ||
* Given the derivatives of some function E (typically energy) with respect to the outputs, backpropagate them | ||
* to find the derivates of E with respect to the inputs and atom positions. | ||
* | ||
* @param neighbors a neighbor list for accelerating the calculation. You must have already called | ||
* build() on the neighbor list with the same positions and box vectors. | ||
* @param positions an array of shape [numAtoms][3] containing the positions of each atom | ||
* @param periodicBoxVectors an array of shape [3][3] containing the periodic box vectors. If periodic boundary conditions are | ||
* not used, this is ignored and may be NULL. | ||
* @param input an array of shape [numAtoms][width] containing the input vectors | ||
* @param outputDeriv an array of shape [numAtoms][width] containing the derivative of E with respect to each output value | ||
* @param inputDeriv an array of shape [numAtoms][width] to store the derivative of E with respect to each input value into | ||
* @param positionDeriv an array of shape [numAtoms][3] to store the derivative of E with respect to the atom positions into | ||
*/ | ||
virtual void backprop(const CFConvNeighbors& neighbors, const float* positions, const float* periodicBoxVectors, | ||
const float* input, const float* outputDeriv, float* inputDeriv, float* positionDeriv) = 0; | ||
/** | ||
* Get the number of atoms in the system. | ||
*/ | ||
int getNumAtoms() const { | ||
return numAtoms; | ||
} | ||
/** | ||
* Get the number of elements in the input and output vectors. | ||
*/ | ||
int getWidth() const { | ||
return width; | ||
} | ||
/** | ||
* Get the number of Gaussian basis functions. | ||
*/ | ||
int getNumGaussians() const { | ||
return numGaussians; | ||
} | ||
/** | ||
* Get the cutoff distance. | ||
*/ | ||
float getCutoff() const { | ||
return cutoff; | ||
} | ||
/** | ||
* Get whether to apply periodic boundary conditions. | ||
*/ | ||
bool getPeriodic() const { | ||
return periodic; | ||
} | ||
/** | ||
* Get the width of the Gaussian basis functions. | ||
*/ | ||
float getGaussianWidth() const { | ||
return gaussianWidth; | ||
} | ||
protected: | ||
const int numAtoms, width, numGaussians; | ||
const float cutoff, gaussianWidth; | ||
const bool periodic; | ||
}; | ||
|
||
#endif |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
w1
,b1
, etc. is described above, but they are missing here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, I should remove those. The subclass constructors take those arguments, but they don't pass them on to the parent class.