Skip to content

Commit

Permalink
Merge pull request #4612 from shivupa/J1spin_indexing_fix
Browse files Browse the repository at this point in the history
J1Spin indexing issue
  • Loading branch information
ye-luo authored Jun 1, 2023
2 parents 22ce17b + 0574de9 commit 58424b9
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 19 deletions.
22 changes: 3 additions & 19 deletions src/QMCWaveFunctions/Jastrow/J1Spin.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,26 +151,10 @@ struct J1Spin : public WaveFunctionComponent
// if target type is specified J1UniqueFunctors[i*NumTargetGroups + j] is assigned
assert(target_type < NumTargetGroups);
if (target_type == -1)
{
for (int i = 0; i < Nions; i++)
for (int j = 0; j < NumTargetGroups; j++)
{
auto igroup = Ions.getGroupID(i);
if (igroup == source_type && J1UniqueFunctors[igroup * NumTargetGroups + j] == nullptr)
J1UniqueFunctors[igroup * NumTargetGroups + j] = std::move(afunc);
}
}
throw std::runtime_error(
"J1Spin::addFunc is not compatible with spin independent Jastrow factors (target_type == -1");
else
{
for (int i = 0; i < Nions; i++)
for (int j = 0; j < NumTargetGroups; j++)
{
auto igroup = Ions.getGroupID(i);
if (Ions.getGroupID(i) == source_type && j == target_type &&
J1UniqueFunctors[i * NumTargetGroups + j] == nullptr)
J1UniqueFunctors[igroup * Nelec + j] = std::move(afunc);
}
}
J1UniqueFunctors[source_type * NumTargetGroups + target_type] = std::move(afunc);
}

void recompute(const ParticleSet& P) override
Expand Down
119 changes: 119 additions & 0 deletions src/QMCWaveFunctions/tests/test_J1Spin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,123 @@ TEST_CASE("J1 spin evaluate derivatives Jastrow", "[wavefunction]")
CHECK(cloned_dhpsioverpsi[i] == ValueApprox(expected_dhpsioverpsi[i]));
}
}

TEST_CASE("J1 spin evaluate derivatives multiparticle Jastrow", "[wavefunction]")
{
Communicate* c = OHMMS::Controller;

ParticleSetPool ptcl = ParticleSetPool(c);
auto ions_uptr = std::make_unique<ParticleSet>(ptcl.getSimulationCell());
auto elec_uptr = std::make_unique<ParticleSet>(ptcl.getSimulationCell());
ParticleSet& ions_(*ions_uptr);
ParticleSet& elec_(*elec_uptr);

ions_.setName("ion0");
ptcl.addParticleSet(std::move(ions_uptr));
ions_.create({2});
ions_.R[0] = {-1.0, 0.0, 0.0};
ions_.R[1] = { 1.0, 0.0, 0.0};
SpeciesSet& ispecies = ions_.getSpeciesSet();
int BeIdx = ispecies.addSpecies("Be");
int ichargeIdx = ispecies.addAttribute("charge");
ispecies(ichargeIdx, BeIdx) = 4.0;

elec_.setName("e");
ptcl.addParticleSet(std::move(elec_uptr));
elec_.create({4, 4, 1});
elec_.R[0] = { 0.5, 0.5, 0.5};
elec_.R[1] = {-0.5, 0.5, 0.5};
elec_.R[2] = { 0.5, -0.5, 0.5};
elec_.R[3] = { 0.5, 0.5, -0.5};
elec_.R[4] = {-0.5, -0.5, 0.5};
elec_.R[5] = { 0.5, -0.5, -0.5};
elec_.R[6] = {-0.5, 0.5, -0.5};
elec_.R[7] = {-0.5, -0.5, -0.5};
elec_.R[8] = { 1.5, 1.5, 1.5};

SpeciesSet& tspecies = elec_.getSpeciesSet();
int upIdx = tspecies.addSpecies("u");
int downIdx = tspecies.addSpecies("d");
int posIdx = tspecies.addSpecies("p");
int massIdx = tspecies.addAttribute("mass");
int chargeIdx = tspecies.addAttribute("charge");
tspecies(massIdx, upIdx) = 1.0;
tspecies(massIdx, downIdx) = 1.0;
tspecies(massIdx, posIdx) = 1.0;
tspecies(chargeIdx, upIdx) = -1.0;
tspecies(massIdx, downIdx) = -1.0;
tspecies(massIdx, posIdx) = 1.0;
// Necessary to set mass
elec_.resetGroups();

ions_.update();
elec_.addTable(elec_);
elec_.addTable(ions_);
elec_.update();

const char* jasxml = R"(<wavefunction name="psi0" target="e">
<jastrow name="J1" type="One-Body" function="Bspline" print="yes" source="ion0" spin="yes">
<correlation speciesA="Be" speciesB="u" cusp="0.0" size="2" rcut="5.0">
<coefficients id="J1uH" type="Array"> 0.5 0.1 </coefficients>
</correlation>
<correlation speciesA="Be" speciesB="d" cusp="0.0" size="2" rcut="5.0">
<coefficients id="J1dH" type="Array"> 0.5 0.1 </coefficients>
</correlation>
<correlation speciesA="Be" speciesB="p" cusp="0.0" size="2" rcut="5.0">
<coefficients id="J1pH" type="Array"> 0.5 0.1 </coefficients>
</correlation>
</jastrow>
</wavefunction>
)";
Libxml2Document doc;
bool okay = doc.parseFromString(jasxml);
REQUIRE(okay);
xmlNodePtr jas1 = doc.getRoot();
WaveFunctionFactory wf_factory(elec_, ptcl.getPool(), c);
RuntimeOptions runtime_options;
auto twf_ptr = wf_factory.buildTWF(jas1, runtime_options);
auto& twf(*twf_ptr);
twf.setMassTerm(elec_);
auto& twf_component_list = twf.getOrbitals();
auto cloned_j1spin = twf_component_list[0]->makeClone(elec_);

opt_variables_type active;
twf.checkInVariables(active);
active.removeInactive();
int nparam = active.size_of_active();
REQUIRE(nparam == 6);

// check logs
//evaluateLog += into G + L so reset
elec_.G = 0.0;
elec_.L = 0.0;
LogValueType log = twf_component_list[0]->evaluateLog(elec_, elec_.G, elec_.L);
LogValueType expected_log{-3.58983, 0.0};
CHECK(log == LogComplexApprox(expected_log));
//evaluateLog += into G + L so reset
elec_.G = 0.0;
elec_.L = 0.0;
LogValueType cloned_log = cloned_j1spin->evaluateLog(elec_, elec_.G, elec_.L);
CHECK(cloned_log == LogComplexApprox(expected_log));

// check derivatives
twf.evaluateLog(elec_);
Vector<ValueType> dlogpsi(nparam);
Vector<ValueType> dhpsioverpsi(nparam);
Vector<ValueType> cloned_dlogpsi(nparam);
Vector<ValueType> cloned_dhpsioverpsi(nparam);

twf_component_list[0]->evaluateDerivatives(elec_, active, dlogpsi, dhpsioverpsi);
cloned_j1spin->evaluateDerivatives(elec_, active, cloned_dlogpsi, cloned_dhpsioverpsi);
// Numbers not validated
std::vector<ValueType> expected_dlogpsi = {-2.544, -4.70578, -2.544, -4.70578, -0.055314, -0.770138};
std::vector<ValueType> expected_dhpsioverpsi = {-2.45001, 0.0794429, -2.45001, 0.0794429, 0.0462761, -0.330801};
for (int i = 0; i < nparam; i++)
{
CHECK(dlogpsi[i] == ValueApprox(expected_dlogpsi[i]));
CHECK(cloned_dlogpsi[i] == ValueApprox(expected_dlogpsi[i]));
CHECK(dhpsioverpsi[i] == ValueApprox(expected_dhpsioverpsi[i]));
CHECK(cloned_dhpsioverpsi[i] == ValueApprox(expected_dhpsioverpsi[i]));
}
}
} // namespace qmcplusplus

0 comments on commit 58424b9

Please sign in to comment.