Skip to content

Commit

Permalink
Move the reliable update code in CG and PCG to a common struct called…
Browse files Browse the repository at this point in the history
… reliable_updates.
  • Loading branch information
hummingtree committed Oct 13, 2021
1 parent 35bf858 commit 3ab9b4e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 239 deletions.
172 changes: 27 additions & 145 deletions lib/inv_cg_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <eigensolve_quda.h>
#include <eigen_helper.h>

#include <reliable_updates.h>

namespace quda {

CG::CG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
Expand Down Expand Up @@ -328,23 +330,11 @@ namespace quda {
}
}

// alternative reliable updates
// alternative reliable updates - set precision - does not hurt performance here

const double u = precisionEpsilon(param.precision_sloppy);
const double uhigh = precisionEpsilon(); // solver precision

const double deps=sqrt(u);
constexpr double dfac = 1.1;
double d_new = 0;
double d = 0;
double dinit = 0;
double xNorm = 0;
double xnorm = 0;
double pnorm = 0;
double ppnorm = 0;
double Anorm = 0;
double beta = 0.0;
double beta = 0;

// for alternative reliable updates
if (advanced_feature && alternative_reliable) {
Expand Down Expand Up @@ -428,30 +418,6 @@ namespace quda {

auto alpha = std::make_unique<double[]>(Np);
double pAp;
int rUpdate = 0;

double rNorm = sqrt(r2);
double r0Norm = rNorm;
double maxrx = rNorm;
double maxrr = rNorm;
double maxr_deflate = rNorm; // The maximum residual since the last deflation
double delta = param.delta;

// this parameter determines how many consective reliable update
// residual increases we tolerate before terminating the solver,
// i.e., how long do we want to keep trying to converge
const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance
const int maxResIncreaseTotal = param.max_res_increase_total;

// this means when using heavy quarks we will switch to simple hq restarts as soon as the reliable strategy fails
const int hqmaxresIncrease = param.max_hq_res_increase;
const int hqmaxresRestartTotal
= param.max_hq_res_restart_total; // this limits the number of heavy quark restarts we can do

int resIncrease = 0;
int resIncreaseTotal = 0;
int hqresIncrease = 0;
int hqresRestartTotal = 0;

// set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG
// only used if we use the heavy_quark_res
Expand All @@ -469,14 +435,11 @@ namespace quda {

PrintStats("CG", k, r2, b2, heavy_quark_res);

int steps_since_reliable = 1;
bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq);

// alternative reliable updates
if(advanced_feature && alternative_reliable){
dinit = uhigh * (rNorm + Anorm * xNorm);
d = dinit;
}
reliable_updates ru(alternative_reliable, u, uhigh, Anorm, r2, param.delta,
param.max_res_increase, param.max_res_increase_total,
use_heavy_quark_res, param.max_hq_res_increase, param.max_hq_res_restart_total);

while ( !converged && k < param.maxiter ) {
matSloppy(Ap, *p[j], tmp, tmp2); // tmp as tmp
Expand All @@ -485,18 +448,18 @@ namespace quda {
bool breakdown = false;
if (advanced_feature && param.pipeline) {
double Ap2;
//TODO: alternative reliable updates - need r2, Ap2, pAp, p norm
if(alternative_reliable){
double4 quadruple = blas::quadrupleCGReduction(rSloppy, Ap, *p[j]);
r2 = quadruple.x; Ap2 = quadruple.y; pAp = quadruple.z; ppnorm= quadruple.w;
r2 = quadruple.x; Ap2 = quadruple.y; pAp = quadruple.z;
ru.update_ppnorm(quadruple.w);
} else {
double3 triplet = blas::tripleCGReduction(rSloppy, Ap, *p[j]);
r2 = triplet.x; Ap2 = triplet.y; pAp = triplet.z;
}
r2_old = r2;
alpha[j] = r2 / pAp;
sigma = alpha[j]*(alpha[j] * Ap2 - pAp);
if (sigma < 0.0 || steps_since_reliable == 0) { // sigma condition has broken down
if (sigma < 0.0 || ru.steps_since_reliable == 0) { // sigma condition has broken down
r2 = blas::axpyNorm(-alpha[j], Ap, rSloppy);
sigma = r2;
breakdown = true;
Expand All @@ -510,7 +473,7 @@ namespace quda {
if (advanced_feature && alternative_reliable) {
double3 pAppp = blas::cDotProductNormA(*p[j],Ap);
pAp = pAppp.x;
ppnorm = pAppp.z;
ru.update_ppnorm(pAppp.z);
} else {
pAp = blas::reDotProduct(*p[j], Ap);
}
Expand All @@ -524,32 +487,21 @@ namespace quda {
}

// reliable update conditions
rNorm = sqrt(r2);
int updateX = 0;
int updateR = 0;
ru.update_rNorm(sqrt(r2));

if (advanced_feature) {
if (alternative_reliable) {
// alternative reliable updates
updateX = ( (d <= deps*sqrt(r2_old)) or (dfac * dinit > deps * r0Norm) ) and (d_new > deps*rNorm) and (d_new > dfac * dinit);
updateR = 0;
} else {
if (rNorm > maxrx) maxrx = rNorm;
if (rNorm > maxrr) maxrr = rNorm;
updateX = (rNorm < delta * r0Norm && r0Norm <= maxrx) ? 1 : 0;
updateR = ((rNorm < delta * maxrr && r0Norm <= maxrr) || updateX) ? 1 : 0;
}
ru.evaluate(r2_old);
// force a reliable update if we are within target tolerance (only if doing reliable updates)
if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol ) updateX = 1;
if ( convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol ) ru.set_updateX();

if (use_heavy_quark_res and L2breakdown
and (convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) or (r2 / b2) < hq_res_stall_check)
and param.delta >= param.tol) {
updateX = 1;
ru.set_updateX();
}
}

if ( !(updateR || updateX )) {
if ( !ru.trigger() ) {
beta = sigma / r2_old; // use the alternative beta computation

if (advanced_feature && param.pipeline && !breakdown) {
Expand Down Expand Up @@ -588,15 +540,7 @@ namespace quda {

// alternative reliable updates
if (advanced_feature) {
if (alternative_reliable) {
d = d_new;
pnorm = pnorm + alpha[j] * alpha[j]* (ppnorm);
xnorm = sqrt(pnorm);
d_new = d + u*rNorm + uhigh*Anorm * xnorm;
if (steps_since_reliable==0 && getVerbosity() >= QUDA_DEBUG_VERBOSE)
printfQuda("New dnew: %e (r %e , y %e)\n",d_new,u*rNorm,uhigh*Anorm * sqrt(blas::norm2(y)) );
}
steps_since_reliable++;
ru.accumulate_norm(alpha[j]);
}
} else {

Expand All @@ -614,93 +558,33 @@ namespace quda {
mat(r, y, x, tmp3); // here we can use x as tmp
r2 = blas::xmyNorm(b, r);

if (param.deflate && sqrt(r2) < maxr_deflate * param.tol_restart) {
if (param.deflate && sqrt(r2) < ru.maxr_deflate * param.tol_restart) {
// Deflate and accumulate to solution vector
eig_solve->deflate(y, r, evecs, evals, true);

// Compute r_defl = RHS - A * LHS
mat(r, y, x, tmp3);
r2 = blas::xmyNorm(b, r);

maxr_deflate = sqrt(r2);
ru.update_maxr_deflate(r2);
}

blas::copy(rSloppy, r); //nop when these pointers alias
blas::zero(xSloppy);

if (advanced_feature) {
// alternative reliable updates
if (alternative_reliable) {
dinit = uhigh*(sqrt(r2) + Anorm * sqrt(blas::norm2(y)));
d = d_new;
xnorm = 0;//sqrt(norm2(x));
pnorm = 0;//pnorm + alpha * sqrt(norm2(p));
if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("New dinit: %e (r %e , y %e)\n",dinit,uhigh*sqrt(r2),uhigh*Anorm*sqrt(blas::norm2(y)));
d_new = dinit;
} else {
rNorm = sqrt(r2);
maxrr = rNorm;
maxrx = rNorm;
}
ru.update_norm(r2, y);
}

// calculate new reliable HQ resididual
if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z);

if (advanced_feature) {
// break-out check if we have reached the limit of the precision
if (sqrt(r2) > r0Norm && updateX and not L2breakdown) { // reuse r0Norm for this
resIncrease++;
resIncreaseTotal++;
warningQuda(
"CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)",
sqrt(r2), r0Norm, resIncreaseTotal);

if ((use_heavy_quark_res and sqrt(r2) < L2breakdown_eps) or resIncrease > maxResIncrease
or resIncreaseTotal > maxResIncreaseTotal or r2 < stop) {
if (use_heavy_quark_res) {
L2breakdown = true;
warningQuda("CG: L2 breakdown %e, %e", sqrt(r2), L2breakdown_eps);
} else {
if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal or r2 < stop) {
warningQuda("CG: solver exiting due to too many true residual norm increases");
break;
}
}
}
} else {
resIncrease = 0;
}
if (ru.reliable_break(r2, stop, L2breakdown, L2breakdown_eps)) { break; }
}

// if L2 broke down already we turn off reliable updates and restart the CG
if (use_heavy_quark_res and L2breakdown) {
hqresRestartTotal++; // count the number of heavy quark restarts we've done
delta = 0;
warningQuda("CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)",
hqresRestartTotal);
heavy_quark_restart = true;

if (heavy_quark_res > heavy_quark_res_old) { // check if new hq residual is greater than previous
hqresIncrease++; // count the number of consecutive increases
warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e",
heavy_quark_res, heavy_quark_res_old);
// break out if we do not improve here anymore
if (hqresIncrease > hqmaxresIncrease) {
warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)",
hqresIncrease, hqmaxresIncrease);
break;
}
} else {
hqresIncrease = 0;
}

if (hqresRestartTotal > hqmaxresRestartTotal) {
warningQuda("CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal,
hqmaxresRestartTotal);
break;
}
}
if(ru.reliable_heavy_quark_break(L2breakdown, heavy_quark_res, heavy_quark_res_old, heavy_quark_restart)) { break; }

if (use_heavy_quark_res and heavy_quark_restart) {
// perform a restart
Expand All @@ -715,9 +599,7 @@ namespace quda {
blas::xpayz(rSloppy, beta, *p[j], *p[0]);
}

steps_since_reliable = 0;
r0Norm = sqrt(r2);
rUpdate++;
ru.reset(r2);

heavy_quark_res_old = heavy_quark_res;
}
Expand All @@ -734,20 +616,20 @@ namespace quda {
// L2 is converged or precision maxed out for L2
bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq);
// HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq);
bool HQdone = (ru.steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq);
converged = L2done and HQdone;
}

// if we have converged and need to update any trailing solutions
if (converged && steps_since_reliable > 0 && (j+1)%Np != 0 ) {
if (converged && ru.steps_since_reliable > 0 && (j+1)%Np != 0 ) {
std::vector<ColorSpinorField*> x_;
x_.push_back(&xSloppy);
std::vector<ColorSpinorField*> p_;
for (int i=0; i<=j; i++) p_.push_back(p[i]);
blas::axpy(alpha.get(), p_, x_);
}

j = steps_since_reliable == 0 ? 0 : (j+1)%Np; // if just done a reliable update then reset j
j = ru.steps_since_reliable == 0 ? 0 : (j+1)%Np; // if just done a reliable update then reset j
}

blas::copy(x, xSloppy);
Expand All @@ -766,7 +648,7 @@ namespace quda {
}

if (getVerbosity() >= QUDA_VERBOSE)
printfQuda("CG: Reliable updates = %d\n", rUpdate);
printfQuda("CG: Reliable updates = %d\n", ru.rUpdate);

if (advanced_feature && param.compute_true_res) {
// compute the true residuals
Expand Down Expand Up @@ -1665,7 +1547,7 @@ void CG::solve(ColorSpinorField& x, ColorSpinorField& b) {
// L2 is concverged or precision maxed out for L2
bool L2done = L2breakdown or convergenceL2(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
// HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update
bool HQdone = (steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
bool HQdone = (ru.steps_since_reliable == 0 and param.delta > 0) and convergenceHQ(r2(i,i).real(), heavy_quark_res[i], stop[i], param.tol_hq);
converged[i] = L2done and HQdone;
}
}
Expand Down
Loading

0 comments on commit 3ab9b4e

Please sign in to comment.