-
Notifications
You must be signed in to change notification settings - Fork 139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial LCAO vgl batched implementation using GEMM #4407
Changes from 27 commits
7aa3e32
2d53994
04b0615
1fb2dc9
b6078dd
069bdb1
8cd6b4c
a2e10c3
9439898
1fa8711
e5a8413
3ec7f15
523c670
9fca749
a0e58cf
e5b485b
1a51180
2737c07
5ea35e7
d25acfb
0e0fa6b
32758f7
bebebcf
ddb042f
2b170bf
838af8d
602599f
e35086d
c4b54d6
c43f8f3
7f96fa1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -346,6 +346,67 @@ void LCAOrbitalSet::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, | |
} | ||
} | ||
|
||
void LCAOrbitalSet::mw_evaluateVGL(const RefVectorWithLeader<SPOSet>& spo_list, | ||
const RefVectorWithLeader<ParticleSet>& P_list, | ||
int iat, | ||
const RefVector<ValueVector>& psi_v_list, | ||
const RefVector<GradVector>& dpsi_v_list, | ||
const RefVector<ValueVector>& d2psi_v_list) const | ||
{ | ||
OffloadMWVGLArray phi_vgl_v; | ||
phi_vgl_v.resize(DIM_VGL, spo_list.size(), OrbitalSetSize); | ||
mw_evaluateVGLImplGEMM(spo_list, P_list, iat, phi_vgl_v); | ||
|
||
const size_t output_size = phi_vgl_v.size(2); | ||
const size_t nw = phi_vgl_v.size(1); | ||
|
||
//TODO: make this cleaner? | ||
for (int iw = 0; iw < nw; iw++) | ||
{ | ||
std::copy_n(phi_vgl_v.data_at(0, iw, 0), output_size, psi_v_list[iw].get().data()); | ||
std::copy_n(phi_vgl_v.data_at(4, iw, 0), output_size, d2psi_v_list[iw].get().data()); | ||
// grads are [dim, walker, orb] in phi_vgl_v | ||
// [walker][orb, dim] in dpsi_v_list | ||
for (size_t idim = 0; idim < DIM; idim++) | ||
BLAS::copy(output_size, phi_vgl_v.data_at(idim + 1, iw, 0), 1, &dpsi_v_list[iw].get().data()[0][idim], DIM); | ||
} | ||
} | ||
|
||
void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& spo_list, | ||
const RefVectorWithLeader<ParticleSet>& P_list, | ||
int iat, | ||
OffloadMWVGLArray& phi_vgl_v) const | ||
{ | ||
// [5][NW][NumAO] | ||
OffloadMWVGLArray basis_mw; | ||
PDoakORNL marked this conversation as resolved.
Show resolved
Hide resolved
|
||
basis_mw.resize(DIM_VGL, spo_list.size(), BasisSetSize); | ||
|
||
if (Identity) | ||
{ | ||
myBasisSet->mw_evaluateVGL(P_list, iat, basis_mw); | ||
// output_size can be smaller than BasisSetSize | ||
const size_t output_size = phi_vgl_v.size(2); | ||
const size_t nw = phi_vgl_v.size(1); | ||
|
||
for (size_t idim = 0; idim < DIM_VGL; idim++) | ||
for (int iw = 0; iw < nw; iw++) | ||
std::copy_n(basis_mw.data_at(idim, iw, 0), output_size, phi_vgl_v.data_at(idim, iw, 0)); | ||
} | ||
else | ||
{ | ||
const size_t requested_orb_size = phi_vgl_v.size(2); | ||
assert(requested_orb_size <= OrbitalSetSize); | ||
ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize); | ||
myBasisSet->mw_evaluateVGL(P_list, iat, basis_mw); | ||
BLAS::gemm('T', 'N', | ||
requested_orb_size, // MOs | ||
spo_list.size() * DIM_VGL, // walkers * DIM_VGL | ||
BasisSetSize, // AOs | ||
1, C_partial_view.data(), BasisSetSize, basis_mw.data(), BasisSetSize, 0, phi_vgl_v.data(), | ||
requested_orb_size); | ||
} | ||
} | ||
|
||
void LCAOrbitalSet::evaluateDetRatios(const VirtualParticleSet& VP, | ||
ValueVector& psi, | ||
const ValueVector& psiinv, | ||
|
@@ -365,6 +426,34 @@ void LCAOrbitalSet::evaluateDetRatios(const VirtualParticleSet& VP, | |
} | ||
} | ||
|
||
void LCAOrbitalSet::mw_evaluateVGLandDetRatioGrads(const RefVectorWithLeader<SPOSet>& spo_list, | ||
const RefVectorWithLeader<ParticleSet>& P_list, | ||
int iat, | ||
const std::vector<const ValueType*>& invRow_ptr_list, | ||
OffloadMWVGLArray& phi_vgl_v, | ||
std::vector<ValueType>& ratios, | ||
std::vector<GradType>& grads) const | ||
{ | ||
assert(this == &spo_list.getLeader()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this needs test coverage. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will add once I have the shared resource set up. |
||
assert(phi_vgl_v.size(0) == DIM_VGL); | ||
assert(phi_vgl_v.size(1) == spo_list.size()); | ||
|
||
mw_evaluateVGLImplGEMM(spo_list, P_list, iat, phi_vgl_v); | ||
// Device data of phi_vgl_v must be up-to-date upon return | ||
phi_vgl_v.updateTo(); | ||
|
||
const size_t nw = spo_list.size(); | ||
const size_t norb_requested = phi_vgl_v.size(2); | ||
for (int iw = 0; iw < nw; iw++) | ||
{ | ||
ratios[iw] = simd::dot(invRow_ptr_list[iw], phi_vgl_v.data_at(0, iw, 0), norb_requested); | ||
GradType dphi; | ||
for (size_t idim = 0; idim < DIM; idim++) | ||
dphi[idim] = simd::dot(invRow_ptr_list[iw], phi_vgl_v.data_at(idim + 1, iw, 0), norb_requested) / ratios[iw]; | ||
grads[iw] = dphi; | ||
} | ||
} | ||
|
||
void LCAOrbitalSet::evaluateVGH(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, HessVector& dhpsi) | ||
{ | ||
//TAKE CARE OF IDENTITY | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,10 +30,11 @@ namespace qmcplusplus | |
struct LCAOrbitalSet : public SPOSet | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a class |
||
{ | ||
public: | ||
using basis_type = SoaBasisSetBase<ValueType>; | ||
using vgl_type = basis_type::vgl_type; | ||
using vgh_type = basis_type::vgh_type; | ||
using vghgh_type = basis_type::vghgh_type; | ||
using basis_type = SoaBasisSetBase<ValueType>; | ||
using vgl_type = basis_type::vgl_type; | ||
using vgh_type = basis_type::vgh_type; | ||
using vghgh_type = basis_type::vghgh_type; | ||
using OffloadMWVGLArray = Array<ValueType, 3, OffloadPinnedAllocator<ValueType>>; // [VGL, walker, Orbs] | ||
ye-luo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
///pointer to the basis set | ||
std::unique_ptr<basis_type> myBasisSet; | ||
|
@@ -78,11 +79,27 @@ struct LCAOrbitalSet : public SPOSet | |
|
||
void evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi) override; | ||
|
||
|
||
void mw_evaluateVGL(const RefVectorWithLeader<SPOSet>& spo_list, | ||
const RefVectorWithLeader<ParticleSet>& P_list, | ||
int iat, | ||
const RefVector<ValueVector>& psi_v_list, | ||
const RefVector<GradVector>& dpsi_v_list, | ||
const RefVector<ValueVector>& d2psi_v_list) const override; | ||
|
||
void evaluateDetRatios(const VirtualParticleSet& VP, | ||
ValueVector& psi, | ||
const ValueVector& psiinv, | ||
std::vector<ValueType>& ratios) override; | ||
|
||
void mw_evaluateVGLandDetRatioGrads(const RefVectorWithLeader<SPOSet>& spo_list, | ||
const RefVectorWithLeader<ParticleSet>& P_list, | ||
int iat, | ||
const std::vector<const ValueType*>& invRow_ptr_list, | ||
OffloadMWVGLArray& phi_vgl_v, | ||
std::vector<ValueType>& ratios, | ||
std::vector<GradType>& grads) const override; | ||
|
||
void evaluateVGH(const ParticleSet& P, | ||
int iat, | ||
ValueVector& psi, | ||
|
@@ -219,7 +236,7 @@ struct LCAOrbitalSet : public SPOSet | |
vghgh_type Tempghv; | ||
|
||
private: | ||
///helper functions to handl Identity | ||
///helper functions to handle Identity | ||
void evaluate_vgl_impl(const vgl_type& temp, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi) const; | ||
|
||
void evaluate_vgl_impl(const vgl_type& temp, | ||
|
@@ -228,6 +245,8 @@ struct LCAOrbitalSet : public SPOSet | |
GradMatrix& dlogdet, | ||
ValueMatrix& d2logdet) const; | ||
///These two functions unpack the data in vgh_type temp object into wavefunction friendly data structures. | ||
|
||
|
||
///This unpacks temp into vectors psi, dpsi, and d2psi. | ||
void evaluate_vgh_impl(const vgh_type& temp, ValueVector& psi, GradVector& dpsi, HessVector& d2psi) const; | ||
|
||
|
@@ -266,6 +285,11 @@ struct LCAOrbitalSet : public SPOSet | |
|
||
///Unpacks data in vgl object and calculates/places ionic gradient of a single row (phi_j(r)) into dlogdet. | ||
void evaluate_ionderiv_v_row_impl(const vgl_type& temp, GradVector& dlogdet) const; | ||
|
||
void mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& spo_list, | ||
const RefVectorWithLeader<ParticleSet>& P_list, | ||
int iat, | ||
OffloadMWVGLArray& phi_vgl_v) const; | ||
}; | ||
} // namespace qmcplusplus | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`Type` and `_type` are redundant here. The naming convention is that types are leading-capital mixed case, so `Value` already implies a type. Other types here would be properly named `Value`, `Vgl`, `Vgh`, `Vghgh`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess SoaBasisSetBase::value_type is being used there. The renaming needs to be done outside this PR.