From e9231a4700a5beaba2110526ebdc9343fbd82df2 Mon Sep 17 00:00:00 2001 From: Ye Luo Date: Sat, 21 Jan 2023 20:36:59 -0600 Subject: [PATCH] Simplify LCAOrbitalSet::mw_evaluateVGLImplGEMM --- src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp | 87 +++++++-------------- src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h | 4 +- 2 files changed, 29 insertions(+), 62 deletions(-) diff --git a/src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp b/src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp index a607942dd17..3e1f1f174e5 100644 --- a/src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp +++ b/src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp @@ -133,24 +133,6 @@ inline void LCAOrbitalSet::evaluate_vgl_impl(const vgl_type& temp, std::copy_n(temp.data(4), output_size, d2psi.data()); } -inline void LCAOrbitalSet::evaluate_vgl_mw_impl(const OffloadMWVGLArray& temp, OffloadMWVGLArray& phi_vgl_v) const -{ - const size_t output_size = phi_vgl_v.size(2); - const size_t nw = phi_vgl_v.size(1); - - for (int iw = 0; iw < nw; iw++) - { - std::copy_n(temp.data_at(0, iw, 0), output_size, phi_vgl_v.data_at(0, iw, 0)); - for (size_t idim = 0; idim < DIM; idim++) - { - ValueType* phi_g = phi_vgl_v.data_at(idim + 1, iw, 0); - std::copy_n(temp.data_at(idim + 1, iw, 0), output_size, phi_vgl_v.data_at(idim + 1, iw, 0)); - } - std::copy_n(temp.data_at(4, iw, 0), output_size, phi_vgl_v.data_at(4, iw, 0)); - } -} - - inline void LCAOrbitalSet::evaluate_vgh_impl(const vgh_type& temp, ValueVector& psi, GradVector& dpsi, @@ -372,7 +354,7 @@ void LCAOrbitalSet::mw_evaluateVGL(const RefVectorWithLeader& spo_list, const RefVector& d2psi_v_list) const { OffloadMWVGLArray phi_vgl_v; - phi_vgl_v.resize(5, spo_list.size(), OrbitalSetSize); + phi_vgl_v.resize(DIM_VGL, spo_list.size(), OrbitalSetSize); mw_evaluateVGLImplGEMM(spo_list, P_list, iat, phi_vgl_v); const size_t output_size = phi_vgl_v.size(2); @@ -386,9 +368,7 @@ void LCAOrbitalSet::mw_evaluateVGL(const RefVectorWithLeader& spo_list, // grads are [dim, walker, orb] in phi_vgl_v // [walker][orb, dim] in dpsi_v_list for (size_t idim = 0; idim < DIM; idim++) - { BLAS::copy(output_size, phi_vgl_v.data_at(idim + 1, iw, 0), 1, &dpsi_v_list[iw].get().data()[0][idim], DIM); - } } } @@ -398,34 +378,32 @@ void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader& sp OffloadMWVGLArray& phi_vgl_v) const { // [5][NW][NumAO] - OffloadMWVGLArray Temp_mw; - Temp_mw.resize(5, spo_list.size(), BasisSetSize); - OffloadMWVGLArray Tempv_mw; - Tempv_mw.resize(5, spo_list.size(), OrbitalSetSize); + OffloadMWVGLArray basis_mw; + basis_mw.resize(DIM_VGL, spo_list.size(), BasisSetSize); if (Identity) { - myBasisSet->mw_evaluateVGL(P_list, iat, Temp_mw); - evaluate_vgl_mw_impl(Temp_mw, phi_vgl_v); + myBasisSet->mw_evaluateVGL(P_list, iat, basis_mw); + // output_size can be smaller than BasisSetSize + const size_t output_size = phi_vgl_v.size(2); + const size_t nw = phi_vgl_v.size(1); + + for (size_t idim = 0; idim < DIM_VGL; idim++) + for (int iw = 0; iw < nw; iw++) + std::copy_n(basis_mw.data_at(idim, iw, 0), output_size, phi_vgl_v.data_at(idim, iw, 0)); } else { - ValueMatrix C_partial_view(C->data(), OrbitalSetSize, BasisSetSize); - myBasisSet->mw_evaluateVGL(P_list, iat, Temp_mw); - for (int idim = 0; idim < DIM_VGL; idim++) - { - constexpr char transa = 't'; - constexpr char transb = 'n'; - constexpr double zone(1); - constexpr double zero(0); - BLAS::gemm(transa, transb, - C_partial_view.rows(), // MOs - spo_list.size(), // walkers - C_partial_view.cols(), // AOs - zone, C_partial_view.data(), C_partial_view.cols(), Temp_mw.data_at(idim, 0, 0), C_partial_view.cols(), - zero, Tempv_mw.data_at(idim, 0, 0), C_partial_view.rows()); - } - evaluate_vgl_mw_impl(Tempv_mw, phi_vgl_v); + const size_t requested_orb_size = phi_vgl_v.size(2); + assert(requested_orb_size <= OrbitalSetSize); + ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize); + myBasisSet->mw_evaluateVGL(P_list, iat, basis_mw); + BLAS::gemm('T', 'N', + requested_orb_size, // MOs + spo_list.size() * DIM_VGL, // walkers * DIM_VGL + BasisSetSize, // AOs + 1, C_partial_view.data(), BasisSetSize, basis_mw.data(), BasisSetSize, 0, phi_vgl_v.data(), + requested_orb_size); } } @@ -459,27 +437,18 @@ void LCAOrbitalSet::mw_evaluateVGLandDetRatioGrads(const RefVectorWithLeader