Skip to content

Commit

Permalink
The SR-ERI estimator underestimated the upper bound when two basis fu…
Browse files Browse the repository at this point in the history
…nctions were located at the same center. This fix adjusts the upper bound estimator.
  • Loading branch information
sunqm committed Oct 2, 2024
1 parent 410d960 commit 6472403
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
9 changes: 2 additions & 7 deletions pyscf/lib/pbc/nr_direct.c
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,7 @@ void PBCVHFnr_sindex(int16_t *sindex, int *atm, int natm,
{
float fac_guess = .5f - logf(omega2)/4;
int ijb, ib, jb, i0, j0, i1, j1, i, j, li, lj;
float dx, dy, dz, ai, aj, ci, cj, aij, a1, fi, fj, rr, rij, dri, drj;
float dx, dy, dz, ai, aj, ci, cj, aij, a1, rr;
float log_fac, theta, theta_r, r_guess, v;
#pragma omp for schedule(dynamic, 1)
for (ijb = 0; ijb < ngroups*(ngroups+1)/2; ijb++) {
Expand All @@ -1193,8 +1193,6 @@ void PBCVHFnr_sindex(int16_t *sindex, int *atm, int natm,
cj = cs[jb];

aij = ai + aj;
fi = ai / aij;
fj = aj / aij;
a1 = ai * aj / aij;
theta = omega2/(omega2+aij);
r_guess = sqrtf(-logf(1e-9f) / (aij * theta));
Expand All @@ -1209,10 +1207,7 @@ void PBCVHFnr_sindex(int16_t *sindex, int *atm, int natm,
dy = ry[i] - ry[j];
dz = rz[i] - rz[j];
rr = dx * dx + dy * dy + dz * dz;
rij = sqrtf(rr);
dri = fj * rij + theta_r;
drj = fi * rij + theta_r;
v = li*logf(dri) + lj*logf(drj) - a1*rr + log_fac;
v = (li+lj)*logf(MAX(theta_r, 1.f)) - a1*rr + log_fac;
sindex[i*Nbas+j] = v * LOG_ADJUST;
} }
if (ib > jb) {
Expand Down
20 changes: 8 additions & 12 deletions pyscf/lib/vhf/nr_sr_vhf.c
Original file line number Diff line number Diff line change
Expand Up @@ -888,9 +888,9 @@ void CVHFnr_sr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, float *q_cond,
float fac_guess = .5f - logf(omega2)/4;
int ish, jsh, li, lj;
int ij, i, j, di, dj, dij, di2, dj2;
float ai, aj, fi, fj, aij, a1, ci, cj;
float ai, aj, aij, ai_aij, a1, ci, cj;
float xi, yi, zi, xj, yj, zj, xij, yij, zij;
float dx, dy, dz, r2, r, dri, drj, v, log_fac, r_guess, theta, theta_r;
float dx, dy, dz, r2, v, log_fac, r_guess, theta, theta_r;
double qtmp, tmp;
float log_qmax;
int shls[4];
Expand Down Expand Up @@ -924,12 +924,11 @@ void CVHFnr_sr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, float *q_cond,
dy = yj - yi;
dz = zj - zi;
aij = ai + aj;
fi = ai / aij;
fj = aj / aij;
a1 = fi * aj;
xij = xi + fj * dx;
yij = yi + fj * dy;
zij = zi + fj * dz;
ai_aij = ai / aij;
a1 = ai_aij * aj;
xij = xi + ai_aij * dx;
yij = yi + ai_aij * dy;
zij = zi + ai_aij * dz;

theta = omega2/(omega2+aij);
r_guess = R_GUESS_FAC / sqrtf(aij * theta);
Expand All @@ -938,10 +937,7 @@ void CVHFnr_sr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, float *q_cond,
log_fac = logf(ci*cj * sqrtf((2*li+1.f)*(2*lj+1.f))/(4*M_PI))
+ 1.5f*logf(M_PI/aij) + fac_guess;
r2 = dx * dx + dy * dy + dz * dz;
r = sqrtf(r2);
dri = fj * r + theta_r;
drj = fi * r + theta_r;
v = li*logf(dri) + lj*logf(drj) - a1*r2 + log_fac;
v = (li+lj)*logf(MAX(theta_r, 1.f)) - a1*r2 + log_fac;
s_index[ish*Nbas+jsh] = v;
s_index[jsh*Nbas+ish] = v;
xij_cond[ish*Nbas+jsh] = xij;
Expand Down
34 changes: 34 additions & 0 deletions pyscf/lib/vhf/test/test_nr_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,40 @@ def test_direct_jk_s2(self):
self.assertTrue(numpy.allclose(vj0,vj1))
self.assertTrue(numpy.allclose(vk0,vk1))

def test_sr_vhf_q_cond(self):
for omega in [.1, .2, .3]:
for l in [0, 1, 2, 3]:
for a in [.15, .5, 2.5]:
rs = numpy.arange(1, 10) * 2.
mol = gto.M(atom=['H 0 0 0'] + [f'H 0 0 {r}' for r in rs],
basis=[[l,[a,1]]], unit='B')
nbas = mol.nbas
q_cond = numpy.empty((6,nbas,nbas), dtype=numpy.float32)
ao_loc = mol.ao_loc
cintopt = lib.c_null_ptr()
with mol.with_short_range_coulomb(omega):
with mol.with_integral_screen(1e-26):
libcvhf2.CVHFnr_sr_int2e_q_cond(
libcvhf2.int2e_sph, cintopt,
q_cond.ctypes.data_as(ctypes.c_void_p),
ao_loc.ctypes.data_as(ctypes.c_void_p),
mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.natm),
mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.nbas),
mol._env.ctypes.data_as(ctypes.c_void_p))

s_index = q_cond[2]
si_0 = s_index[0,0]
si_others = s_index.diagonal()[1:]
with mol.with_short_range_coulomb(omega):
ints = [abs(mol.intor_by_shell('int2e', (0,0,i,i))).max()
for i in range(1, mol.nbas)]

aij = akl = a * 2
omega2 = mol.omega**2
theta = 1/(2/aij+1/omega2)
rr = rs**2
estimator = rr * numpy.exp(si_0 + si_others - theta*rr)
assert all(estimator / ints > 1)


if __name__ == '__main__':
Expand Down

0 comments on commit 6472403

Please sign in to comment.