Skip to content

Commit

Permalink
Simplify LCAOrbitalSet::mw_evaluateVGLImplGEMM
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo committed Jan 22, 2023
1 parent 838af8d commit e9231a4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 62 deletions.
87 changes: 28 additions & 59 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,24 +133,6 @@ inline void LCAOrbitalSet::evaluate_vgl_impl(const vgl_type& temp,
std::copy_n(temp.data(4), output_size, d2psi.data());
}

/** Copy multi-walker value/gradient/Laplacian data from temp into phi_vgl_v.
 *
 * Both arrays are indexed as [component][walker][orbital]. The number of
 * walkers and the number of entries copied per walker are taken from the
 * destination phi_vgl_v, so only the leading phi_vgl_v.size(2) entries of each
 * row of temp are transferred.
 *
 * @param temp       source array; NOTE(review): presumably at least as large as
 *                   phi_vgl_v in every dimension - not checked here.
 * @param phi_vgl_v  destination array, overwritten for components 0, 1..DIM and 4.
 */
inline void LCAOrbitalSet::evaluate_vgl_mw_impl(const OffloadMWVGLArray& temp, OffloadMWVGLArray& phi_vgl_v) const
{
  const size_t output_size = phi_vgl_v.size(2);
  const size_t nw          = phi_vgl_v.size(1);

  // size_t index avoids the signed/unsigned comparison against nw.
  for (size_t iw = 0; iw < nw; iw++)
  {
    // component 0: orbital values
    std::copy_n(temp.data_at(0, iw, 0), output_size, phi_vgl_v.data_at(0, iw, 0));
    // components 1..DIM: one gradient direction each
    for (size_t idim = 0; idim < DIM; idim++)
      std::copy_n(temp.data_at(idim + 1, iw, 0), output_size, phi_vgl_v.data_at(idim + 1, iw, 0));
    // component 4: second-derivative term (same slot the single-walker path copies into d2psi)
    std::copy_n(temp.data_at(4, iw, 0), output_size, phi_vgl_v.data_at(4, iw, 0));
  }
}


inline void LCAOrbitalSet::evaluate_vgh_impl(const vgh_type& temp,
ValueVector& psi,
GradVector& dpsi,
Expand Down Expand Up @@ -372,7 +354,7 @@ void LCAOrbitalSet::mw_evaluateVGL(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVector<ValueVector>& d2psi_v_list) const
{
OffloadMWVGLArray phi_vgl_v;
phi_vgl_v.resize(5, spo_list.size(), OrbitalSetSize);
phi_vgl_v.resize(DIM_VGL, spo_list.size(), OrbitalSetSize);
mw_evaluateVGLImplGEMM(spo_list, P_list, iat, phi_vgl_v);

const size_t output_size = phi_vgl_v.size(2);
Expand All @@ -386,9 +368,7 @@ void LCAOrbitalSet::mw_evaluateVGL(const RefVectorWithLeader<SPOSet>& spo_list,
// grads are [dim, walker, orb] in phi_vgl_v
// [walker][orb, dim] in dpsi_v_list
for (size_t idim = 0; idim < DIM; idim++)
{
BLAS::copy(output_size, phi_vgl_v.data_at(idim + 1, iw, 0), 1, &dpsi_v_list[iw].get().data()[0][idim], DIM);
}
}
}

Expand All @@ -398,34 +378,32 @@ void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& sp
OffloadMWVGLArray& phi_vgl_v) const
{
// [5][NW][NumAO]
OffloadMWVGLArray Temp_mw;
Temp_mw.resize(5, spo_list.size(), BasisSetSize);
OffloadMWVGLArray Tempv_mw;
Tempv_mw.resize(5, spo_list.size(), OrbitalSetSize);
OffloadMWVGLArray basis_mw;
basis_mw.resize(DIM_VGL, spo_list.size(), BasisSetSize);

if (Identity)
{
myBasisSet->mw_evaluateVGL(P_list, iat, Temp_mw);
evaluate_vgl_mw_impl(Temp_mw, phi_vgl_v);
myBasisSet->mw_evaluateVGL(P_list, iat, basis_mw);
// output_size can be smaller than BasisSetSize
const size_t output_size = phi_vgl_v.size(2);
const size_t nw = phi_vgl_v.size(1);

for (size_t idim = 0; idim < DIM_VGL; idim++)
for (int iw = 0; iw < nw; iw++)
std::copy_n(basis_mw.data_at(idim, iw, 0), output_size, phi_vgl_v.data_at(idim, iw, 0));
}
else
{
ValueMatrix C_partial_view(C->data(), OrbitalSetSize, BasisSetSize);
myBasisSet->mw_evaluateVGL(P_list, iat, Temp_mw);
for (int idim = 0; idim < DIM_VGL; idim++)
{
constexpr char transa = 't';
constexpr char transb = 'n';
constexpr double zone(1);
constexpr double zero(0);
BLAS::gemm(transa, transb,
C_partial_view.rows(), // MOs
spo_list.size(), // walkers
C_partial_view.cols(), // AOs
zone, C_partial_view.data(), C_partial_view.cols(), Temp_mw.data_at(idim, 0, 0), C_partial_view.cols(),
zero, Tempv_mw.data_at(idim, 0, 0), C_partial_view.rows());
}
evaluate_vgl_mw_impl(Tempv_mw, phi_vgl_v);
const size_t requested_orb_size = phi_vgl_v.size(2);
assert(requested_orb_size <= OrbitalSetSize);
ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize);
myBasisSet->mw_evaluateVGL(P_list, iat, basis_mw);
BLAS::gemm('T', 'N',
requested_orb_size, // MOs
spo_list.size() * DIM_VGL, // walkers * DIM_VGL
BasisSetSize, // AOs
1, C_partial_view.data(), BasisSetSize, basis_mw.data(), BasisSetSize, 0, phi_vgl_v.data(),
requested_orb_size);
}
}

Expand Down Expand Up @@ -459,27 +437,18 @@ void LCAOrbitalSet::mw_evaluateVGLandDetRatioGrads(const RefVectorWithLeader<SPO
assert(this == &spo_list.getLeader());
assert(phi_vgl_v.size(0) == DIM_VGL);
assert(phi_vgl_v.size(1) == spo_list.size());
const size_t nw = spo_list.size();
const size_t norb_requested = phi_vgl_v.size(2);
// object to hold gradient

GradVector dphi_v(norb_requested);
mw_evaluateVGLImplGEMM(spo_list, P_list, iat, phi_vgl_v);

const size_t nw = spo_list.size();
const size_t norb_requested = phi_vgl_v.size(2);
for (int iw = 0; iw < nw; iw++)
{
// create data objects to hold values of wave function and second derivative
// phi_vgl_v.data_at(0, iw, 0) constructs another vector which shares the memory location of another containers data.
// specifically phi_vgl_v wf value information for a specific walker
ValueVector phi_v(phi_vgl_v.data_at(0, iw, 0), norb_requested);
GradVector dphi_v(norb_requested);
ratios[iw] = simd::dot(invRow_ptr_list[iw], phi_vgl_v.data_at(0, iw, 0), norb_requested);
GradType dphi;
for (size_t idim = 0; idim < DIM; idim++)
{
ValueType* phi_g = phi_vgl_v.data_at(idim + 1, iw, 0);
for (size_t iorb = 0; iorb < norb_requested; iorb++)
dphi_v[iorb][idim] = phi_g[iorb];
}
ratios[iw] = simd::dot(invRow_ptr_list[iw], phi_v.data(), norb_requested);
grads[iw] = simd::dot(invRow_ptr_list[iw], dphi_v.data(), norb_requested) / ratios[iw];
dphi[idim] = simd::dot(invRow_ptr_list[iw], phi_vgl_v.data_at(idim + 1, iw, 0), norb_requested);
grads[iw] = dphi;
}
}

Expand Down
4 changes: 1 addition & 3 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,14 @@ struct LCAOrbitalSet : public SPOSet
vghgh_type Tempghv;

private:
///helper functions to handl Identity
///helper functions to handle Identity
void evaluate_vgl_impl(const vgl_type& temp, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi) const;

void evaluate_vgl_impl(const vgl_type& temp,
int i,
ValueMatrix& logdet,
GradMatrix& dlogdet,
ValueMatrix& d2logdet) const;
// function to unpack vgl_type when working with batched code.
inline void evaluate_vgl_mw_impl(const OffloadMWVGLArray& temp, OffloadMWVGLArray& phi_vgl_v) const;
///These two functions unpack the data in vgh_type temp object into wavefunction friendly data structures.


Expand Down

0 comments on commit e9231a4

Please sign in to comment.