
update rearrange
JasonLeeJSL committed Dec 6, 2023
1 parent 6554d67 commit be6620a
Showing 5 changed files with 539 additions and 131 deletions.
4 changes: 4 additions & 0 deletions d4ft/integral/gto/tensorization.py
@@ -154,8 +154,12 @@ def tensorize(
for i in range(4)
]
)
import time
t1 = time.time()
t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
jax.block_until_ready(t_abcd)
t2 = time.time()
print("Current abab cuda time =",t2-t1)
if not cgto:
return t_abcd
counts_abcd_i = idx_counts[:, 4]
207 changes: 189 additions & 18 deletions d4ft/native/obara_saika/eri_kernel.cu
@@ -83,7 +83,9 @@ void Hartree_32::Gpu(cudaStream_t stream,

void Hartree_64::Gpu(cudaStream_t stream,
Array<const int>& N,
Array<const int>& screened_length,
Array<const int>& thread_load,
Array<const int64_t>& thread_num,
Array<const int64_t>& screened_length,
Array<const int>& n,
Array<const double>& r,
Array<const double>& z,
@@ -95,20 +97,170 @@ void Hartree_64::Gpu(cudaStream_t stream,
Array<const int>& sorted_ab_idx,
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
//Array<const int>& screened_idx_offset,
Array<const int>& ab_thread_num,
Array<const int>& ab_thread_offset,
Array<const double>& pgto_coeff,
Array<const double>& pgto_normalization_factor,
Array<const int>& pgto_idx_to_cgto_idx,
Array<const double>& rdm1,
Array<const int>& n_cgto,
Array<const int>& n_pgto,
Array<double>& output) {
// Prescreening
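// The screened pair count and the total thread count live in device memory,
// so the two cudaMemcpy calls below read them back to the host before the
// scratch buffers and the hemi::parallel_for bounds can be sized.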
int* idx_4c;
int idx_length;
cudaMemcpy(&idx_length, screened_length.ptr, sizeof(int), cudaMemcpyDeviceToHost);
cudaMalloc((void **)&idx_4c, 2 * idx_length * sizeof(int));
std::cout<<idx_length<<std::endl;
int* thread_ab_index;
int64_t thread_length;
cudaMemcpy(&thread_length, thread_num.ptr, sizeof(int64_t), cudaMemcpyDeviceToHost);
cudaMalloc((void **)&thread_ab_index, thread_length * sizeof(int));
std::cout<<thread_length<<std::endl;
int num_cd = sorted_cd_idx.spec->shape[0];
int* ncd;
double* rcd;
double* zcd;
cudaMalloc((void **)&ncd, 3 * 2 * num_cd * sizeof(int));
cudaMalloc((void **)&rcd, 3 * 2 * num_cd * sizeof(double));
cudaMalloc((void **)&zcd, 2 * num_cd * sizeof(double));
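// Per screened cd pair we cache the data of both primitives: 6 ints of angular
// momenta (ncd), 6 doubles of center coordinates (rcd) and 2 exponents (zcd),
// filled by the gather kernel further down.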

// Assign threads to the screened ab pairs: ab_thread_num[i] consecutive threads are mapped to the i-th surviving ab index
hemi::ExecutionPolicy ep;
ep.setStream(stream);
hemi::parallel_for(ep, 0, ab_thread_num.spec->shape[0], [=] HEMI_LAMBDA(int index) {
for(int i = 0; i < ab_thread_num.ptr[index]; i++ ){
int loc;
loc = ab_thread_offset.ptr[index] + i;
thread_ab_index[loc] = index;
// output.ptr[loc] = sorted_ab_idx.ptr[index]; // ab
// output.ptr[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
}
__syncthreads();
});
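// After this pass thread_ab_index[t] tells global thread t which screened ab
// pair it works on; ab_thread_offset is presumably the exclusive prefix sum of
// ab_thread_num, so each ab pair owns a contiguous range of thread ids.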

// get ncd, rcd, zcd in cd order
hemi::parallel_for(ep, 0, sorted_cd_idx.spec->shape[0], [=] HEMI_LAMBDA(int index) {
int cd;
int c,d;
cd = sorted_cd_idx.ptr[index];
triu_ij_from_index(N.ptr[0], cd, &c, &d);
ncd[6*index + 0] = n.ptr[0 * N.ptr[0] + c];
ncd[6*index + 1] = n.ptr[1 * N.ptr[0] + c];
ncd[6*index + 2] = n.ptr[2 * N.ptr[0] + c];
ncd[6*index + 3] = n.ptr[0 * N.ptr[0] + d];
ncd[6*index + 4] = n.ptr[1 * N.ptr[0] + d];
ncd[6*index + 5] = n.ptr[2 * N.ptr[0] + d];
rcd[6*index + 0] = r.ptr[0 * N.ptr[0] + c];
rcd[6*index + 1] = r.ptr[1 * N.ptr[0] + c];
rcd[6*index + 2] = r.ptr[2 * N.ptr[0] + c];
rcd[6*index + 3] = r.ptr[0 * N.ptr[0] + d];
rcd[6*index + 4] = r.ptr[1 * N.ptr[0] + d];
rcd[6*index + 5] = r.ptr[2 * N.ptr[0] + d];
zcd[2*index + 0] = z.ptr[c];
zcd[2*index + 1] = z.ptr[d];
__syncthreads();
});
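// ncd / rcd / zcd are now packed per cd pair, so the contraction loop below can
// look up the cd-side primitive data directly by cd_index instead of gathering
// it from the global n / r / z arrays through triu_ij_from_index each time.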

cudaMemset(output.ptr, 0, sizeof(double));
// Now that we have the screened ab and cd pairs, compute each ERI and contract it into the output.
// The contraction needs: 1. the symmetry count, 2. the pgto normalization factors, 3. the pgto coefficients, 4. rdm1 (built from the MO coefficients).
hemi::parallel_for(ep, 0, thread_length, [=] HEMI_LAMBDA(int64_t index) {
int a, b, c, d; // pgto 4c idx
int i, j, k, l; // cgto 4c idx
int ab_index, cd_index;
int ab, cd;
double eri_result;
double Na, Nb, Nc, Nd;
double Ca, Cb, Cc, Cd;
double Mab, Mcd;
int count;
int nax, nay, naz, nbx, nby, nbz, ncx, ncy, ncz, ndx, ndy, ndz;
double rax, ray, raz, rbx, rby, rbz, rcx, rcy, rcz, rdx, rdy, rdz;
double za, zb, zc, zd;

ab_index = thread_ab_index[index];
ab = sorted_ab_idx.ptr[ab_index];
triu_ij_from_index(N.ptr[0], ab, &a, &b);
nax = n.ptr[0 * N.ptr[0] + a];
nay = n.ptr[1 * N.ptr[0] + a];
naz = n.ptr[2 * N.ptr[0] + a];
nbx = n.ptr[0 * N.ptr[0] + b];
nby = n.ptr[1 * N.ptr[0] + b];
nbz = n.ptr[2 * N.ptr[0] + b];
rax = r.ptr[0 * N.ptr[0] + a];
ray = r.ptr[1 * N.ptr[0] + a];
raz = r.ptr[2 * N.ptr[0] + a];
rbx = r.ptr[0 * N.ptr[0] + b];
rby = r.ptr[1 * N.ptr[0] + b];
rbz = r.ptr[2 * N.ptr[0] + b];
za = z.ptr[a];
zb = z.ptr[b];
Ca = pgto_coeff.ptr[a];
Cb = pgto_coeff.ptr[b];
Na = pgto_normalization_factor.ptr[a];
Nb = pgto_normalization_factor.ptr[b];
i = pgto_idx_to_cgto_idx.ptr[a];
j = pgto_idx_to_cgto_idx.ptr[b];
Mab = rdm1.ptr[i*n_cgto.ptr[0] + j];
eri_result = 0;
for(int cur_ptr = 0; cur_ptr < thread_load.ptr[0]; cur_ptr++ ){
cd_index = screened_cd_idx_start.ptr[ab_index] + index % ab_thread_num.ptr[ab_index] + cur_ptr * ab_thread_num.ptr[ab_index];
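// The ab_thread_num[ab] threads assigned to this ab pair start at distinct
// lanes (index % ab_thread_num covers 0 .. ab_thread_num - 1 across the group)
// and stride through the surviving cd pairs; thread_load is presumably chosen
// so that thread_load * ab_thread_num reaches from screened_cd_idx_start[ab]
// up to num_cd, with the bound check below dropping the overhang.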
if(cd_index < num_cd){
ncx = ncd[6*cd_index + 0];
ncy = ncd[6*cd_index + 1];
ncz = ncd[6*cd_index + 2];
ndx = ncd[6*cd_index + 3];
ndy = ncd[6*cd_index + 4];
ndz = ncd[6*cd_index + 5];
rcx = rcd[6*cd_index + 0];
rcy = rcd[6*cd_index + 1];
rcz = rcd[6*cd_index + 2];
rdx = rcd[6*cd_index + 3];
rdy = rcd[6*cd_index + 4];
rdz = rcd[6*cd_index + 5];
zc = zcd[2*cd_index + 0];
zd = zcd[2*cd_index + 1];

cd = sorted_cd_idx.ptr[cd_index];
triu_ij_from_index(N.ptr[0], cd, &c, &d);
get_symmetry_count(a, b, c, d, &count);
double dcount = static_cast<double>(count);
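// Only symmetry-unique (ab|cd) quadruplets are enumerated, so each integral is
// weighted by the size of its permutation orbit (up to the usual 8-fold ERI
// symmetry), which is presumably what get_symmetry_count returns.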
Cc = pgto_coeff.ptr[c];
Cd = pgto_coeff.ptr[d];
Nc = pgto_normalization_factor.ptr[c];
Nd = pgto_normalization_factor.ptr[d];
k = pgto_idx_to_cgto_idx.ptr[c];
l = pgto_idx_to_cgto_idx.ptr[d];
Mcd = rdm1.ptr[k*n_cgto.ptr[0] + l];
eri_result += eri<double>(nax, nay, naz, // a
nbx, nby, nbz, // b
ncx, ncy, ncz, // c
ndx, ndy, ndz, // d
rax, ray, raz, // a
rbx, rby, rbz, // b
rcx, rcy, rcz, // c
rdx, rdy, rdz, // d
za, zb, zc, zd, // z
min_a.ptr, min_c.ptr,
max_ab.ptr, max_cd.ptr, Ms.ptr) * dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd;
// eri_result += dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd;
}
}
atomicAdd(output.ptr, eri_result);
__syncthreads();
});
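// Each thread accumulates its partial contraction in eri_result and performs a
// single atomicAdd into the scalar output (zeroed by the cudaMemset above), so
// this kernel directly returns the rdm1-contracted value instead of the
// four-index ERI tensor. Note that idx_4c, thread_ab_index, ncd, rcd and zcd
// are not cudaFree'd in the code shown here. (A small host-side sketch of the
// thread-to-cd schedule appears after this file's diff.)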



// Build abcd index list version
/*
int* ab_cd_idx;
int64_t ab_cd_idx_length;
cudaMemcpy(&ab_cd_idx_length, screened_length.ptr, sizeof(int64_t), cudaMemcpyDeviceToHost);
cudaMalloc((void **)&ab_cd_idx, 2 * ab_cd_idx_length * sizeof(int));
std::cout<<ab_cd_idx_length<<std::endl;
int num_cd = sorted_cd_idx.spec->shape[0];
// cudaEvent_t start, stop;
// cudaEventCreate(&start);
// cudaEventCreate(&stop);
// cudaEventRecord(start);
// Pre-screen, result is (ab_index, cd_index), i.e. (ab, cd)
hemi::ExecutionPolicy ep;
@@ -117,40 +269,48 @@ void Hartree_64::Gpu(cudaStream_t stream,
for(int i = screened_cd_idx_start.ptr[index]; i < num_cd; i++ ){
int loc;
loc = screened_idx_offset.ptr[index] + i - screened_cd_idx_start.ptr[index];
idx_4c[loc] = sorted_ab_idx.ptr[index]; // ab
idx_4c[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
ab_cd_idx[loc] = sorted_ab_idx.ptr[index]; // ab
ab_cd_idx[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
// output.ptr[loc] = sorted_ab_idx.ptr[index]; // ab
// output.ptr[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
}
__syncthreads();
});
cudaMemset(output.ptr, 0, sizeof(double));
// Now we have ab cd, we can compute eri and contract it to output
// For contract, we need 1. count 2. pgto normalization coeff 3. pgto coeff 4.rdm1 (Mocoeff)
hemi::parallel_for(ep, 0, idx_length, [=] HEMI_LAMBDA(int index) {
hemi::parallel_for(ep, 0, ab_cd_idx_length, [=] HEMI_LAMBDA(int64_t index) {
int a, b, c, d; // pgto 4c idx
int i, j, k, l; // cgto 4c idx
double eri_result;
double Na, Nb, Nc, Nd;
double Ca, Cb, Cc, Cd;
double Mab, Mcd;
int count;
triu_ij_from_index(N.ptr[0], idx_4c[index], &a, &b);
triu_ij_from_index(N.ptr[0], idx_4c[index + screened_length.ptr[0]], &c, &d);
triu_ij_from_index(N.ptr[0], ab_cd_idx[index], &a, &b);
triu_ij_from_index(N.ptr[0], ab_cd_idx[index + screened_length.ptr[0]], &c, &d);
get_symmetry_count(a, b, c, d, &count);
double dcount = static_cast<double>(count);
Ca = pgto_coeff.ptr[a];
// pgto coeff
Ca = pgto_coeff.ptr[a];
Cb = pgto_coeff.ptr[b];
Cc = pgto_coeff.ptr[c];
Cd = pgto_coeff.ptr[d];
// pgto normalization factor
Na = pgto_normalization_factor.ptr[a];
Nb = pgto_normalization_factor.ptr[b];
Nc = pgto_normalization_factor.ptr[c];
Nd = pgto_normalization_factor.ptr[d];
// cgto i j k l
i = pgto_idx_to_cgto_idx.ptr[a];
j = pgto_idx_to_cgto_idx.ptr[b];
k = pgto_idx_to_cgto_idx.ptr[c];
l = pgto_idx_to_cgto_idx.ptr[d];
// rdm1_ab, rdm1_cd
Mab = rdm1.ptr[i*n_cgto.ptr[0] + j];
Mcd = rdm1.ptr[k*n_cgto.ptr[0] + k];
Mcd = rdm1.ptr[k*n_cgto.ptr[0] + l];
eri_result = eri<double>(n.ptr[0 * N.ptr[0] + a], n.ptr[1 * N.ptr[0] + a], n.ptr[2 * N.ptr[0] + a], // a
n.ptr[0 * N.ptr[0] + b], n.ptr[1 * N.ptr[0] + b], n.ptr[2 * N.ptr[0] + b], // b
n.ptr[0 * N.ptr[0] + c], n.ptr[1 * N.ptr[0] + c], n.ptr[2 * N.ptr[0] + c], // c
@@ -162,10 +322,21 @@ void Hartree_64::Gpu(cudaStream_t stream,
z.ptr[a], z.ptr[b], z.ptr[c], z.ptr[d], // z
min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr);
eri_result = eri_result * dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd;
// eri_result = Mab;
// output.ptr[index] = eri_result;
// eri_result = eri_result * dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd; // * Mab * Mcd;
// prod result from rdm1
atomicAdd(output.ptr, eri_result);
__syncthreads();
});

// cudaEventRecord(stop);
// cudaEventSynchronize(stop);
// float milliseconds = 0;
// cudaEventElapsedTime(&milliseconds, start, stop);
// std::cout << "Elapsed time: " << milliseconds << " ms" << std::endl;
// cudaEventDestroy(start);
// cudaEventDestroy(stop);
// std::cout<<index_4c.spec->shape[0]<<std::endl;
// hemi::ExecutionPolicy ep;
// ep.setStream(stream);
@@ -189,7 +360,7 @@ void Hartree_64::Gpu(cudaStream_t stream,
// r.ptr[0 * N.ptr[0] + l], r.ptr[1 * N.ptr[0] + l], r.ptr[2 * N.ptr[0] + l], // d
// z.ptr[i], z.ptr[j], z.ptr[k], z.ptr[l], // z
// min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr);
// });
// });*/
}

// template <typename FLOAT>
@@ -246,7 +417,7 @@ void Hartree_64_uncontracted::Gpu(cudaStream_t stream,
// std::cout<<index_4c.spec->shape[0]<<std::endl;
hemi::ExecutionPolicy ep;
ep.setStream(stream);
hemi::parallel_for(ep, 0, index_4c.spec->shape[0], [=] HEMI_LAMBDA(int index) {
hemi::parallel_for(ep, 0, index_4c.spec->shape[0], [=] HEMI_LAMBDA(int64_t index) {
int i, j, k, l, ij, kl;
// triu_ij_from_index(num_unique_ij(N.ptr[0]), index_4c.ptr[index], &ij, &kl);
// triu_ij_from_index(N.ptr[0], ij, &i, &j);
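To make the work partition in Hartree_64::Gpu easier to follow, here is a minimal host-side sketch of the thread-to-cd-pair schedule, under the assumption that ab_thread_offset is the exclusive prefix sum of ab_thread_num; all numeric values are hypothetical and the variable names simply mirror the kernel's.

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical screening result for two surviving ab pairs.
  std::vector<int> ab_thread_num    = {3, 2};  // threads assigned per ab pair
  std::vector<int> ab_thread_offset = {0, 3};  // exclusive prefix sum of the above
  std::vector<int> cd_start         = {1, 4};  // screened_cd_idx_start per ab pair
  const int num_cd = 8;                        // number of surviving cd pairs
  const int thread_load = 3;                   // cd pairs handled per thread

  for (int ab = 0; ab < 2; ++ab) {
    for (int i = 0; i < ab_thread_num[ab]; ++i) {
      int index = ab_thread_offset[ab] + i;    // global thread id
      int lane  = index % ab_thread_num[ab];   // distinct lane within the ab group
      std::printf("thread %d (ab %d): cd_index", index, ab);
      for (int cur = 0; cur < thread_load; ++cur) {
        int cd_index = cd_start[ab] + lane + cur * ab_thread_num[ab];
        if (cd_index < num_cd) std::printf(" %d", cd_index);  // same bound check as the kernel
      }
      std::printf("\n");
    }
  }
  return 0;
}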
35 changes: 25 additions & 10 deletions d4ft/native/obara_saika/eri_kernel.h
@@ -88,7 +88,9 @@ class Hartree_64 {
public:
// template <typename FLOAT>
static auto ShapeInference(const Spec<int>& shape1,
const Spec<int>& shape10,
const Spec<int>& shape21,
const Spec<int64_t>& shape10,
const Spec<int64_t>& shape22,
const Spec<int>& shape2,
const Spec<double>& shape3,
const Spec<double>& shape4,
@@ -100,12 +102,15 @@ class Hartree_64 {
const Spec<int>& shape11,
const Spec<int>& shape12,
const Spec<int>& shape13,
const Spec<int>& shape14,
// const Spec<int>& shape14,
const Spec<int>& shape23,
const Spec<int>& shape24,
const Spec<double>& shape15,
const Spec<double>& shape16,
const Spec<int>& shape17,
const Spec<double>& shape18,
const Spec<int>& shape19) {
const Spec<int>& shape19,
const Spec<int>& shape20) {
// double n2 = shape4.shape[0]*(shape4.shape[0]+1)/2;
// double n4 = n2*(n2+1)/2;
// int n4_int = static_cast<int>(n4);
@@ -119,33 +124,40 @@ class Hartree_64 {
// std::memcpy(out.ptr, arg1.ptr, sizeof(float) * arg1.spec->Size());
// }
// template <typename FLOAT>
static void Cpu(Array<const int>& N,
Array<const int>& screened_length,
static void Cpu(Array<const int>& N,
Array<const int>& thread_load,
Array<const int64_t>& thread_num,
Array<const int64_t>& screened_length,
Array<const int>& n,
Array<const double>& r,
Array<const double>& z,
Array<const int>& min_a,
Array<const int>& min_c,
Array<const int>& max_ab,
Array<const int>& max_cd,
Array<const int>& Ms,
Array<const int>& Ms,
Array<const int>& sorted_ab_idx,
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
// Array<const int>& screened_idx_offset,
Array<const int>& ab_thread_num,
Array<const int>& ab_thread_offset,
Array<const double>& pgto_coeff,
Array<const double>& pgto_normalization_factor,
Array<const int>& pgto_idx_to_cgto_idx,
Array<const double>& rdm1,
Array<const int>& n_cgto,
Array<const int>& n_pgto,
Array<double>& output){
// std::memcpy(output.ptr, outshape.ptr, sizeof(float) * outshape.spec->Size());
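// (The CPU path is left as a no-op stub; only the GPU implementation in
//  eri_kernel.cu is populated.)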
}

// template <typename FLOAT>
static void Gpu(cudaStream_t stream,
Array<const int>& N,
Array<const int>& screened_length,
Array<const int>& N,
Array<const int>& thread_load,
Array<const int64_t>& thread_num,
Array<const int64_t>& screened_length,
Array<const int>& n,
Array<const double>& r,
Array<const double>& z,
@@ -157,12 +169,15 @@ class Hartree_64 {
Array<const int>& sorted_ab_idx,
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
// Array<const int>& screened_idx_offset,
Array<const int>& ab_thread_num,
Array<const int>& ab_thread_offset,
Array<const double>& pgto_coeff,
Array<const double>& pgto_normalization_factor,
Array<const int>& pgto_idx_to_cgto_idx,
Array<const double>& rdm1,
Array<const int>& n_cgto,
Array<const int>& n_pgto,
Array<double>& output);
};

Binary file modified d4ft/native/obara_saika/eri_kernel.so
Binary file not shown.
