Skip to content

Commit

Permalink
Merge pull request #2734 from mmorale3/spinorbit
Browse files Browse the repository at this point in the history
Noncollinear and spin-orbit implementation for k-point Hamiltonian in AFQMC
  • Loading branch information
ye-luo authored Oct 2, 2020
2 parents c747996 + e0886fc commit 5e7dc1f
Show file tree
Hide file tree
Showing 34 changed files with 985 additions and 516 deletions.
3 changes: 2 additions & 1 deletion external_codes/boost_multi/multi/memory/allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class allocator{
bool operator!=(allocator const& o) const{return mp_ != o.mp_;}
using void_pointer = typename std::pointer_traits<decltype(std::declval<memory_type*>()->allocate(0, 0))>::template rebind<void>;
pointer allocate(size_type n){
return static_cast<pointer>(static_cast<void_pointer>(mp_->allocate(n*sizeof(value_type), alignof(T))));
return static_cast<pointer>(static_cast<void_pointer>(mp_->allocate(n*sizeof(value_type), 16)));
//return static_cast<pointer>(static_cast<void_pointer>(mp_->allocate(n*sizeof(value_type), alignof(T))));
}
void deallocate(pointer p, size_type n){
mp_->deallocate(p, n*sizeof(value_type));
Expand Down
229 changes: 154 additions & 75 deletions src/AFQMC/HamiltonianOperations/KP3IndexFactorization.hpp

Large diffs are not rendered by default.

154 changes: 104 additions & 50 deletions src/AFQMC/HamiltonianOperations/KP3IndexFactorization_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class KP3IndexFactorization_batched
using CMatrix_cref = boost::multi::array_ref<ComplexType const, 2, const_pointer>;
using CVector_ref = ComplexVector_ref<pointer>;
using CMatrix_ref = ComplexMatrix_ref<pointer>;
using C3Tensor_ref = Complex3Tensor_ref<pointer>;
using C4Tensor_ref = ComplexArray_ref<4, pointer>;
using C3Tensor_cref = boost::multi::array_ref<ComplexType const, 3, const_pointer>;

using SpMatrix_cref = boost::multi::array_ref<SPComplexType const, 2, sp_pointer>;
Expand Down Expand Up @@ -338,50 +340,96 @@ class KP3IndexFactorization_batched
{
int nkpts = nopk.size();
int NMO = std::accumulate(nopk.begin(), nopk.end(), 0);
// in non-collinear case with SO, keep SO matrix here and add it
// for now, stay collinear
int npol = (walker_type == NONCOLLINEAR) ? 2 : 1;

CVector vMF_(vMF);
CVector P1D(iextensions<1u>{NMO * NMO});
fill_n(P1D.origin(), P1D.num_elements(), ComplexType(0));
vHS(vMF_, P1D);
CVector P0D(iextensions<1u>{NMO * NMO});
fill_n(P0D.origin(), P0D.num_elements(), ComplexType(0));
vHS(vMF_, P0D);
if (TG_.TG().size() > 1)
TG_.TG().all_reduce_in_place_n(to_address(P1D.origin()), P1D.num_elements(), std::plus<>());
TG_.TG().all_reduce_in_place_n(to_address(P0D.origin()), P0D.num_elements(), std::plus<>());

boost::multi::array<ComplexType, 2> P1({NMO, NMO});
copy_n(P1D.origin(), NMO * NMO, P1.origin());
boost::multi::array<ComplexType, 2> P0({NMO, NMO});
copy_n(P0D.origin(), NMO * NMO, P0.origin());

// add H1 + vn0 and symmetrize
using ma::conj;
boost::multi::array<ComplexType, 2> P1({npol*NMO, npol*NMO});
std::fill_n(P1.origin(), P1.num_elements(), ComplexType(0.0));

// add spin-dependent H1
for (int K = 0, nk0 = 0; K < nkpts; ++K)
{
for (int i = 0, I = nk0; i < nopk[K]; i++, I++)
{
for(int p=0; p<npol; ++p)
P1[p*NMO+I][p*NMO+I] += H1[K][p*nopk[K]+i][p*nopk[K]+i];
for (int j = i + 1, J = I + 1; j < nopk[K]; j++, J++)
{
for(int p=0; p<npol; ++p)
{
P1[p*NMO+I][p*NMO+J] += H1[K][p*nopk[K]+i][p*nopk[K]+j];
P1[p*NMO+J][p*NMO+I] += H1[K][p*nopk[K]+j][p*nopk[K]+i];
}
}
if(walker_type == NONCOLLINEAR) {
// offdiagonal piece
for (int j = 0, J = nk0; j < nopk[K]; j++, J++)
{
P1[I][NMO+J] += H1[K][i][nopk[K]+j];
P1[NMO+J][I] += H1[K][nopk[K]+j][i];
}
}
}
nk0 += nopk[K];
}

// add P0 (diagonal in spin)
for(int p=0; p<npol; ++p)
for(int I=0; I<NMO; I++)
for(int J=0; J<NMO; J++)
P1[p*NMO+I][p*NMO+J] += P0[I][J];

// add vn0 (diagonal in spin)
for (int K = 0, nk0 = 0; K < nkpts; ++K)
{
for (int i = 0, I = nk0; i < nopk[K]; i++, I++)
{
P1[I][I] += H1[K][i][i] + vn0[K][i][i];
for(int p=0; p<npol; ++p)
P1[p*NMO+I][p*NMO+I] += vn0[K][i][i];
for (int j = i + 1, J = I + 1; j < nopk[K]; j++, J++)
{
P1[I][J] += H1[K][i][j] + vn0[K][i][j];
P1[J][I] += H1[K][j][i] + vn0[K][j][i];
// This is really cutoff dependent!!!
#if MIXED_PRECISION
if (std::abs(P1[I][J] - ma::conj(P1[J][I])) * 2.0 > 1e-5)
for(int p=0; p<npol; ++p)
{
#else
if (std::abs(P1[I][J] - ma::conj(P1[J][I])) * 2.0 > 1e-6)
{
#endif
app_error() << " WARNING in getOneBodyPropagatorMatrix. H1 is not hermitian. \n";
app_error() << I << " " << J << " " << P1[I][J] << " " << P1[J][I] << " " << H1[K][i][j] << " "
<< H1[K][j][i] << " " << vn0[K][i][j] << " " << vn0[K][j][i] << std::endl;
//APP_ABORT("Error in getOneBodyPropagatorMatrix. H1 is not hermitian. \n");
P1[p*NMO+I][p*NMO+J] += vn0[K][i][j];
P1[p*NMO+J][p*NMO+I] += vn0[K][j][i];
}
P1[I][J] = 0.5 * (P1[I][J] + ma::conj(P1[J][I]));
P1[J][I] = ma::conj(P1[I][J]);
}
}
nk0 += nopk[K];
}

using ma::conj;
// symmetrize
for(int I=0; I<npol*NMO; I++)
{
for(int J=I+1; J<npol*NMO; J++)
{
// This is really cutoff dependent!!!
#if MIXED_PRECISION
if (std::abs(P1[I][J] - ma::conj(P1[J][I])) * 2.0 > 1e-5)
{
#else
if (std::abs(P1[I][J] - ma::conj(P1[J][I])) * 2.0 > 1e-6)
{
#endif
app_error() << " WARNING in getOneBodyPropagatorMatrix. H1 is not hermitian. \n";
app_error() << I << " " << J << " " << P1[I][J] << " " << P1[J][I] <<std::endl;
//<< H1[K][i][j] << " "
//<< H1[K][j][i] << " " << vn0[K][i][j] << " " << vn0[K][j][i] << std::endl;
}
P1[I][J] = 0.5 * (P1[I][J] + ma::conj(P1[J][I]));
P1[J][I] = ma::conj(P1[I][J]);
}
}
return P1;
}

Expand Down Expand Up @@ -438,6 +486,7 @@ class KP3IndexFactorization_batched

int nwalk = Gc.size(1);
int nspin = (walker_type == COLLINEAR ? 2 : 1);
int npol = (walker_type == NONCOLLINEAR ? 2 : 1);
int nmo_tot = std::accumulate(nopk.begin(), nopk.end(), 0);
int nmo_max = *std::max_element(nopk.begin(), nopk.end());
int nocca_tot = std::accumulate(nelpk[nd].begin(), nelpk[nd].begin() + nkpts, 0);
Expand Down Expand Up @@ -495,13 +544,13 @@ class KP3IndexFactorization_batched
for (int n = 0; n < nwalk; n++)
fill_n(E[n].origin(), 3, ComplexType(0.));

assert(Gc.num_elements() == nwalk * (nocca_tot + noccb_tot) * nmo_tot);
C3Tensor_cref G3Da(make_device_ptr(Gc.origin()), {nocca_tot, nmo_tot, nwalk});
assert(Gc.num_elements() == nwalk * (nocca_tot + noccb_tot) * npol * nmo_tot);
C3Tensor_cref G3Da(make_device_ptr(Gc.origin()), {nocca_tot*npol, nmo_tot, nwalk});
C3Tensor_cref G3Db(make_device_ptr(Gc.origin()) + G3Da.num_elements() * (nspin - 1), {noccb_tot, nmo_tot, nwalk});

// later on, rewrite routine to loop over spins, to avoid storage of both spin
// components simultaneously
Static4Tensor GKK({nspin, nkpts, nkpts, nwalk * nmo_max * nocc_max},
Static4Tensor GKK({nspin, nkpts, nkpts, nwalk * npol * nmo_max * nocc_max},
device_buffer_allocator->template get_allocator<SPComplexType>());
GKaKjw_to_GKKwaj(G3Da, GKK[0], nelpk[nd].sliced(0, nkpts), dev_nelpk[nd], dev_a0pk[nd]);
if (walker_type == COLLINEAR)
Expand All @@ -519,27 +568,29 @@ class KP3IndexFactorization_batched
for (int K = 0; K < nkpts; ++K)
{
#if defined(MIXED_PRECISION)
CMatrix_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()), {nocc_max, nmo_max});
for (int a = 0; a < nelpk[nd][K]; ++a)
ma::product(ComplexType(1.), ma::T(G3Da[na + a].sliced(nk, nk + nopk[K])), haj_K[a].sliced(0, nopk[K]),
ComplexType(1.), E({0, nwalk}, 0));
int ni(nopk[K]);
CMatrix_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()), {nocc_max, npol*nmo_max});
for (int a = 0; a < nelpk[nd][K]; ++a)
for (int pol = 0; pol < npol; ++pol)
ma::product(ComplexType(1.), ma::T(G3Da[(na + a)*npol+pol].sliced(nk, nk + ni)),
haj_K[a].sliced(pol*ni, pol*ni+ni), ComplexType(1.), E({0, nwalk}, 0));
na += nelpk[nd][K];
if (walker_type == COLLINEAR)
{
boost::multi::array_ref<ComplexType, 2, pointer> haj_Kb(haj_K.origin() + haj_K.num_elements(),
{nocc_max, nmo_max});
for (int b = 0; b < nelpk[nd][nkpts + K]; ++b)
ma::product(ComplexType(1.), ma::T(G3Db[nb + b].sliced(nk, nk + nopk[K])), haj_Kb[b].sliced(0, nopk[K]),
ma::product(ComplexType(1.), ma::T(G3Db[nb + b].sliced(nk, nk + ni)), haj_Kb[b].sliced(0, ni),
ComplexType(1.), E({0, nwalk}, 0));
nb += nelpk[nd][nkpts + K];
}
nk += nopk[K];
nk += ni;
#else
nk = nopk[K];
{
na = nelpk[nd][K];
CVector_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()), {nocc_max * nmo_max});
SpMatrix_ref Gaj(GKK[0][K][K].origin(), {nwalk, nocc_max * nmo_max});
CVector_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()), {nocc_max * npol * nmo_max});
SpMatrix_ref Gaj(GKK[0][K][K].origin(), {nwalk, nocc_max * npol * nmo_max});
ma::product(ComplexType(1.), Gaj, haj_K, ComplexType(1.), E({0, nwalk}, 0));
}
if (walker_type == COLLINEAR)
Expand Down Expand Up @@ -582,7 +633,7 @@ class KP3IndexFactorization_batched
// I WANT C++17!!!!!!
long mem_ank(0);
if (needs_copy)
mem_ank = nkpts * nocc_max * nchol_max * nmo_max;
mem_ank = nkpts * nocc_max * nchol_max * npol * nmo_max;
StaticVector LBuff(iextensions<1u>{2 * mem_ank},
device_buffer_allocator->template get_allocator<SPComplexType>());
sp_pointer LQptr(nullptr), LQmptr(nullptr);
Expand Down Expand Up @@ -660,8 +711,8 @@ class KP3IndexFactorization_batched

if (batch_cnt >= batch_size)
{
gemmBatched('T', 'N', nocc_max * nchol_max, nwalk * nocc_max, nmo_max, SPComplexType(1.0),
Aarray.data(), nmo_max, Barray.data(), nmo_max, SPComplexType(0.0), Carray.data(),
gemmBatched('T', 'N', nocc_max * nchol_max, nwalk * nocc_max, npol*nmo_max, SPComplexType(1.0),
Aarray.data(), npol*nmo_max, Barray.data(), npol*nmo_max, SPComplexType(0.0), Carray.data(),
nocc_max * nchol_max, Aarray.size());

copy_n(scl_factors.data(), scl_factors.size(), dev_scl_factors.origin());
Expand Down Expand Up @@ -691,8 +742,8 @@ class KP3IndexFactorization_batched

if (batch_cnt > 0)
{
gemmBatched('T', 'N', nocc_max * nchol_max, nwalk * nocc_max, nmo_max, SPComplexType(1.0), Aarray.data(),
nmo_max, Barray.data(), nmo_max, SPComplexType(0.0), Carray.data(), nocc_max * nchol_max,
gemmBatched('T', 'N', nocc_max * nchol_max, nwalk * nocc_max, npol*nmo_max, SPComplexType(1.0), Aarray.data(),
npol*nmo_max, Barray.data(), npol*nmo_max, SPComplexType(0.0), Carray.data(), nocc_max * nchol_max,
Aarray.size());

copy_n(scl_factors.data(), scl_factors.size(), dev_scl_factors.origin());
Expand Down Expand Up @@ -1322,6 +1373,7 @@ class KP3IndexFactorization_batched
assert(v.size(0) == 2 * local_nCV);
assert(v.size(1) == nwalk);
int nspin = (walker_type == COLLINEAR ? 2 : 1);
int npol = (walker_type == NONCOLLINEAR ? 2 : 1);
int nmo_tot = std::accumulate(nopk.begin(), nopk.end(), 0);
int nmo_max = *std::max_element(nopk.begin(), nopk.end());
int nocca_tot = std::accumulate(nelpk[nd].begin(), nelpk[nd].begin() + nkpts, 0);
Expand All @@ -1340,11 +1392,11 @@ class KP3IndexFactorization_batched
SPComplexType minusimhalfa(0.0, -0.5 * a * scl);
SPComplexType imhalfa(0.0, 0.5 * a * scl);

assert(G.num_elements() == nwalk * (nocca_tot + noccb_tot) * nmo_tot);
assert(G.num_elements() == nwalk * (nocca_tot + noccb_tot) * npol * nmo_tot);
// MAM: use reshape when available, then no need to deal with types
using GType = typename std::decay<MatA>::type::element;
boost::multi::array_ref<GType const, 3, decltype(make_device_ptr(G.origin()))> G3Da(make_device_ptr(G.origin()),
{nocca_tot, nmo_tot, nwalk});
{nocca_tot*npol, nmo_tot, nwalk});
boost::multi::array_ref<GType const, 3, decltype(make_device_ptr(G.origin()))> G3Db(make_device_ptr(G.origin()) +
G3Da.num_elements() *
(nspin - 1),
Expand All @@ -1358,7 +1410,7 @@ class KP3IndexFactorization_batched
size_t cnt(0);
Static3Tensor v1({nkpts + number_of_symmetric_Q, nchol_max, nwalk},
device_buffer_allocator->template get_allocator<SPComplexType>());
Static3Tensor GQ({nkpts, nkpts * nocc_max * nmo_max, nwalk},
Static3Tensor GQ({nkpts, nkpts * nocc_max * npol * nmo_max, nwalk},
device_buffer_allocator->template get_allocator<SPComplexType>());
fill_n(v1.origin(), v1.num_elements(), SPComplexType(0.0));
fill_n(GQ.origin(), GQ.num_elements(), SPComplexType(0.0));
Expand All @@ -1370,7 +1422,7 @@ class KP3IndexFactorization_batched
dev_a0pk[nd].sliced(nkpts, 2 * nkpts));

// can use productStridedBatched if LQKakn is changed to a 3Tensor array
int Kak = nkpts * nocc_max * nmo_max;
int Kak = nkpts * nocc_max * npol * nmo_max;
std::vector<sp_pointer> Aarray;
std::vector<sp_pointer> Barray;
std::vector<sp_pointer> Carray;
Expand Down Expand Up @@ -1530,30 +1582,32 @@ class KP3IndexFactorization_batched
template<class MatA, class MatB, class IVec, class IVec2>
void GKaKjw_to_GKKwaj(MatA const& GKaKj, MatB&& GKKaj, IVec&& nocc, IVec2&& dev_no, IVec2&& dev_a0)
{
int npol = (walker_type == NONCOLLINEAR) ? 2 : 1;
int nmo_max = *std::max_element(nopk.begin(), nopk.end());
// int nocc_max = *std::max_element(nocc.begin(),nocc.end());
int nmo_tot = GKaKj.size(1);
int nwalk = GKaKj.size(2);
int nkpts = nopk.size();
assert(GKKaj.num_elements() >= nkpts * nkpts * nwalk * nocc_max * nmo_max);
assert(GKKaj.num_elements() >= nkpts * nkpts * nwalk * nocc_max * npol * nmo_max);

using ma::KaKjw_to_KKwaj;
KaKjw_to_KKwaj(nwalk, nkpts, nmo_max, nmo_tot, nocc_max, dev_nopk.origin(), dev_i0pk.origin(), dev_no.origin(),
KaKjw_to_KKwaj(nwalk, nkpts, npol, nmo_max, nmo_tot, nocc_max, dev_nopk.origin(), dev_i0pk.origin(), dev_no.origin(),
dev_a0.origin(), GKaKj.origin(), GKKaj.origin());
}

template<class MatA, class MatB, class IVec, class IVec2>
void GKaKjw_to_GQKajw(MatA const& GKaKj, MatB&& GQKaj, IVec&& nocc, IVec2&& dev_no, IVec2&& dev_a0)
{
int npol = (walker_type == NONCOLLINEAR) ? 2 : 1;
int nmo_max = *std::max_element(nopk.begin(), nopk.end());
// int nocc_max = *std::max_element(nocc.begin(),nocc.end());
int nmo_tot = GKaKj.size(1);
int nwalk = GKaKj.size(2);
int nkpts = nopk.size();
assert(GQKaj.num_elements() >= nkpts * nkpts * nwalk * nocc_max * nmo_max);
assert(GQKaj.num_elements() >= nkpts * nkpts * nwalk * nocc_max * npol * nmo_max);

using ma::KaKjw_to_QKajw;
KaKjw_to_QKajw(nwalk, nkpts, nmo_max, nmo_tot, nocc_max, dev_nopk.origin(), dev_i0pk.origin(), dev_no.origin(),
KaKjw_to_QKajw(nwalk, nkpts, npol, nmo_max, nmo_tot, nocc_max, dev_nopk.origin(), dev_i0pk.origin(), dev_no.origin(),
dev_a0.origin(), dev_QKToK2.origin(), GKaKj.origin(), GQKaj.origin());
}

Expand Down
Loading

0 comments on commit 5e7dc1f

Please sign in to comment.