Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizations to CudaANISymmetryFunctions #7

Merged
merged 1 commit into from
Oct 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 47 additions & 35 deletions ani/CudaANISymmetryFunctions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -233,23 +233,20 @@ template <bool PERIODIC, bool TRICLINIC, bool TORCHANI>
__global__ void computeAngularFunctions(int numAtoms, int numSpecies, int numAngular, float angularCutoff, float* angular,
int* neighbors, int* neighborCount, const float3* positions, const float* periodicBoxVectors,
const AngularFunction* angularFunctions, const int* atomSpecies, const int* angularIndex) {
const int warp = (blockIdx.x*blockDim.x+threadIdx.x)/32;
const int indexInWarp = threadIdx.x%32;
const int numWarps = (gridDim.x*blockDim.x)/32;
const float3 invBoxSize = (PERIODIC ? make_float3(1/periodicBoxVectors[0], 1/periodicBoxVectors[4], 1/periodicBoxVectors[8]) : make_float3(0, 0, 0));
const int c1 = numAngular;
const int c2 = numSpecies*(numSpecies+1)*c1/2;

// Each warp loops over atoms.
// Each thread block loops over atoms.

for (int atom1 = warp; atom1 < numAtoms; atom1 += numWarps) {
for (int atom1 = blockIdx.x; atom1 < numAtoms; atom1 += gridDim.x) {
float3 pos1 = positions[atom1];
int numNeighbors = neighborCount[atom1];

// The threads in the warp loop over pairs of atoms from the neighbor list.
// The threads in the block loop over pairs of atoms from the neighbor list.

int numPairs = numNeighbors*(numNeighbors-1)/2;
for (int pair = indexInWarp; pair < numPairs; pair += 32) {
for (int pair = threadIdx.x; pair < numPairs; pair += blockDim.x) {
int i = (int) floorf(numNeighbors-0.5f-sqrtf((numNeighbors-0.5f)*(numNeighbors-0.5f)-2*pair));
int j = pair - i*numNeighbors + (i+1)*(i+2)/2;
int atom2 = neighbors[atom1*numAtoms + i];
Expand Down Expand Up @@ -300,16 +297,30 @@ void CudaANISymmetryFunctions::computeSymmetryFunctions(const float* positions,
// Record the positions and periodic box vectors.

CHECK_RESULT(cudaMemcpyAsync(this->positions, positions, 3*numAtoms*sizeof(float), cudaMemcpyDefault));
if (periodic)
CHECK_RESULT(cudaMemcpyAsync(this->periodicBoxVectors, periodicBoxVectors, 9*sizeof(float), cudaMemcpyDefault));
float* hostBoxVectors;
if (periodic) {
// We'll need to access the box vectors on both host and device. Figure out the most
// efficient way of doing that.

cudaPointerAttributes attrib;
cudaError_t result = cudaPointerGetAttributes(&attrib, periodicBoxVectors);
if (result != cudaSuccess || attrib.hostPointer == 0) {
CHECK_RESULT(cudaMemcpy(this->periodicBoxVectors, periodicBoxVectors, 9*sizeof(float), cudaMemcpyDefault));
hostBoxVectors = this->periodicBoxVectors;
}
else {
CHECK_RESULT(cudaMemcpyAsync(this->periodicBoxVectors, periodicBoxVectors, 9*sizeof(float), cudaMemcpyDefault));
hostBoxVectors = (float*) attrib.hostPointer;
}
}

// Determine whether we have a rectangular or triclinic periodic box.

triclinic = false;
if (periodic)
for (int i = 0 ; i < 3; i++)
for (int j = 0; j < 3; j++)
if (i != j && periodicBoxVectors[3*i+j] != 0)
if (i != j && hostBoxVectors[3*i+j] != 0)
triclinic = true;

// Clear the output arrays.
Expand All @@ -320,7 +331,7 @@ void CudaANISymmetryFunctions::computeSymmetryFunctions(const float* positions,
// Compute the symmetry functions.

int blockSize = 128;
int numBlocks = min(maxBlocks, (int) ceil(numAtoms/4.0));
int numBlocks = min(maxBlocks, numAtoms);
int numRadial = radialFunctions.size();
int numAngular = angularFunctions.size();
if (periodic) {
Expand Down Expand Up @@ -362,58 +373,62 @@ __global__ void backpropRadialFunctions(int numAtoms, int numSpecies, int numRad
float radialCutoff, const float* radialDeriv, float* positionDeriv,
const float3* positions, const float* periodicBoxVectors,
const RadialFunction* radialFunctions, const int* atomSpecies, bool torchani) {
const int warp = (blockIdx.x*blockDim.x+threadIdx.x)/32;
const int indexInWarp = threadIdx.x%32;
const int numWarps = (gridDim.x*blockDim.x)/32;
const float3 invBoxSize = (PERIODIC ? make_float3(1/periodicBoxVectors[0], 1/periodicBoxVectors[4], 1/periodicBoxVectors[8]) : make_float3(0, 0, 0));
const int c1 = numRadial;
const int c2 = numSpecies*c1;
const float radialCutoff2 = radialCutoff*radialCutoff;
const float globalScale = (torchani ? 0.25f : 1.0f);

// Each warp loops over atoms.
// Each thread block loops over atoms.

for (int atom1 = warp; atom1 < numAtoms; atom1 += numWarps) {
for (int atom1 = blockIdx.x; atom1 < numAtoms; atom1 += gridDim.x) {
float3 pos1 = positions[atom1];
float3 posDeriv1 = make_float3(0, 0, 0);

// The threads in the warp loop over second atoms.
// The threads in the block loop over second atoms.

for (int atom2 = indexInWarp; atom2 < numAtoms; atom2 += 32) {
for (int atom2 = atom1+1+threadIdx.x; atom2 < numAtoms; atom2 += blockDim.x) {
float3 pos2 = positions[atom2];
float3 delta;
float r2;
computeDisplacement<PERIODIC, TRICLINIC>(pos1, pos2, delta, r2, periodicBoxVectors, invBoxSize);

// Compute the derivatives of the symmetry functions.

if (r2 < radialCutoff2 && atom1 != atom2) {
if (r2 < radialCutoff2) {
float r = sqrtf(r2);
float rInv = 1/r;
float cutoff = cutoffFunction(r, radialCutoff);
float dCdR = cutoffDeriv(r, radialCutoff);
float3 posDeriv2 = make_float3(0, 0, 0);
for (int i = 0; i < numRadial; i++) {
const RadialFunction fn = radialFunctions[i];
float shifted = r-fn.rs;
float expTerm = expf(-fn.eta*shifted*shifted);
float dVdR = dCdR*expTerm - cutoff*2*fn.eta*shifted*expTerm;
float dEdV = radialDeriv[atom1*c2 + atomSpecies[atom2]*c1 + i] + radialDeriv[atom2*c2 + atomSpecies[atom1]*c1 + i];
float scale = globalScale * dEdV * dVdR * rInv;
posDeriv1.x += scale*delta.x;
posDeriv1.y += scale*delta.y;
posDeriv1.z += scale*delta.z;
posDeriv1.x -= scale*delta.x;
posDeriv1.y -= scale*delta.y;
posDeriv1.z -= scale*delta.z;
posDeriv2.x += scale*delta.x;
posDeriv2.y += scale*delta.y;
posDeriv2.z += scale*delta.z;
}
atomicAdd(&positionDeriv[3*atom2], posDeriv2.x);
atomicAdd(&positionDeriv[3*atom2+1], posDeriv2.y);
atomicAdd(&positionDeriv[3*atom2+2], posDeriv2.z);
}
}
for (int offset = 16; offset > 0; offset /= 2) {
posDeriv1.x += __shfl_down_sync(0xFFFFFFFF, posDeriv1.x, offset);
posDeriv1.y += __shfl_down_sync(0xFFFFFFFF, posDeriv1.y, offset);
posDeriv1.z += __shfl_down_sync(0xFFFFFFFF, posDeriv1.z, offset);
}
if (indexInWarp == 0) {
positionDeriv[3*atom1] -= posDeriv1.x;
positionDeriv[3*atom1+1] -= posDeriv1.y;
positionDeriv[3*atom1+2] -= posDeriv1.z;
if (threadIdx.x%32 == 0) {
atomicAdd(&positionDeriv[3*atom1], posDeriv1.x);
atomicAdd(&positionDeriv[3*atom1+1], posDeriv1.y);
atomicAdd(&positionDeriv[3*atom1+2], posDeriv1.z);
}
}
}
Expand All @@ -423,24 +438,21 @@ __global__ void backpropAngularFunctions(int numAtoms, int numSpecies, int numAn
const float* angularDeriv, float* positionDeriv, int* neighbors, int* neighborCount,
const float3* positions, const float* periodicBoxVectors, const AngularFunction* angularFunctions,
const int* atomSpecies, const int* angularIndex) {
const int warp = (blockIdx.x*blockDim.x+threadIdx.x)/32;
const int indexInWarp = threadIdx.x%32;
const int numWarps = (gridDim.x*blockDim.x)/32;
const float3 invBoxSize = (PERIODIC ? make_float3(1/periodicBoxVectors[0], 1/periodicBoxVectors[4], 1/periodicBoxVectors[8]) : make_float3(0, 0, 0));
const int c1 = numAngular;
const int c2 = numSpecies*(numSpecies+1)*c1/2;

// Each warp loops over atoms.
// Each thread block loops over atoms.

for (int atom1 = warp; atom1 < numAtoms; atom1 += numWarps) {
for (int atom1 = blockIdx.x; atom1 < numAtoms; atom1 += gridDim.x) {
float3 pos1 = positions[atom1];
float3 posDeriv1 = make_float3(0, 0, 0);
int numNeighbors = neighborCount[atom1];

// The threads in the warp loop over pairs of atoms from the neighbor list.
// The threads in the block loop over pairs of atoms from the neighbor list.

int numPairs = numNeighbors*(numNeighbors-1)/2;
for (int pair = indexInWarp; pair < numPairs; pair += 32) {
for (int pair = threadIdx.x; pair < numPairs; pair += blockDim.x) {
int i = (int) floorf(numNeighbors-0.5f-sqrtf((numNeighbors-0.5f)*(numNeighbors-0.5f)-2*pair));
int j = pair - i*numNeighbors + (i+1)*(i+2)/2;
int atom2 = neighbors[atom1*numAtoms + i];
Expand Down Expand Up @@ -538,7 +550,7 @@ __global__ void backpropAngularFunctions(int numAtoms, int numSpecies, int numAn
posDeriv1.y += __shfl_down_sync(0xFFFFFFFF, posDeriv1.y, offset);
posDeriv1.z += __shfl_down_sync(0xFFFFFFFF, posDeriv1.z, offset);
}
if (indexInWarp == 0) {
if (threadIdx.x%32 == 0) {
atomicAdd(&positionDeriv[3*atom1], posDeriv1.x);
atomicAdd(&positionDeriv[3*atom1+1], posDeriv1.y);
atomicAdd(&positionDeriv[3*atom1+2], posDeriv1.z);
Expand All @@ -556,7 +568,7 @@ void CudaANISymmetryFunctions::backprop(const float* radialDeriv, const float* a
// Backpropagate through the symmetry functions.

int blockSize = 128;
int numBlocks = min(maxBlocks, (int) ceil(numAtoms/4.0));
int numBlocks = min(maxBlocks, numAtoms);
int numRadial = radialFunctions.size();
int numAngular = angularFunctions.size();
if (periodic) {
Expand Down