Skip to content

Commit

Permalink
Directly set the adjoint of the LHS to 0 in assignments.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Jul 21, 2024
1 parent 7b71e25 commit f4dcf5c
Show file tree
Hide file tree
Showing 19 changed files with 151 additions and 146 deletions.
2 changes: 1 addition & 1 deletion demos/ErrorEstimation/CustomModel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ The code is: void func_grad(float x, float y, float *_d_x, float *_d_y, double &
_final_error += _d_z * z;
z = _t0;
float _r_d0 = _d_z;
_d_z -= _r_d0;
_d_z = 0;
*_d_x += _r_d0;
*_d_y += _r_d0;
}
Expand Down
2 changes: 1 addition & 1 deletion demos/ErrorEstimation/PrintModel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ The code is: void func_grad(float x, float y, float *_d_x, float *_d_y, double &
_final_error += _d_z * z;
z = _t0;
float _r_d0 = _d_z;
_d_z -= _r_d0;
_d_z = 0;
*_d_x += _r_d0;
*_d_y += _r_d0;
}
Expand Down
19 changes: 12 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2357,9 +2357,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d",
/*forceDeclCreation=*/true);
if (opCode == BO_Assign) {
// Add the statement `dl -= oldValue;`
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
direction::reverse);
if (!isPointerOp) {
// Add the statement `dl = 0;`
Expr* zero = getZeroInit(AssignedDiff->getType());
addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero),
direction::reverse);
}
Rdiff = Visit(R, oldValue);
valueForRevPass = Rdiff.getRevSweepAsExpr();
} else if (opCode == BO_AddAssign) {
Expand Down Expand Up @@ -2388,8 +2391,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// double r = _ref0 *= z;
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
// Add the statement `dl -= oldValue;`
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
// Add the statement `dl = 0;`
Expr* zero = getZeroInit(AssignedDiff->getType());
addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero),
direction::reverse);
/// Capture all the emitted statements while visiting R
/// and insert them after `dl += dl * R`
Expand All @@ -2407,8 +2411,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Ldiff.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, Rdiff.getExpr());
} else if (opCode == BO_DivAssign) {
// Add the statement `dl -= oldValue;`
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
// Add the statement `dl = 0;`
Expr* zero = getZeroInit(AssignedDiff->getType());
addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero),
direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expand Down
16 changes: 8 additions & 8 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ float func(float* a, float* b) {
//CHECK-NEXT: {
//CHECK-NEXT: a[i] = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_a[i];
//CHECK-NEXT: _d_a[i] -= _r_d0;
//CHECK-NEXT: _d_a[i] = 0;
//CHECK-NEXT: _d_a[i] += _r_d0 * b[i];
//CHECK-NEXT: _d_b[i] += a[i] * _r_d0;
//CHECK-NEXT: }
Expand Down Expand Up @@ -293,7 +293,7 @@ double func5(int k) {
//CHECK-NEXT: {
//CHECK-NEXT: arr[i] = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_arr[i];
//CHECK-NEXT: _d_arr[i] -= _r_d0;
//CHECK-NEXT: _d_arr[i] = 0;
//CHECK-NEXT: *_d_k += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -401,7 +401,7 @@ double func7(double *params) {
//CHECK-NEXT: {
//CHECK-NEXT: out = clad::pop(_t2);
//CHECK-NEXT: double _r_d0 = _d_out;
//CHECK-NEXT: _d_out -= _r_d0;
//CHECK-NEXT: _d_out = 0;
//CHECK-NEXT: _d_out += _r_d0;
//CHECK-NEXT: inv_square_pullback(paramsPrime, _r_d0, _d_paramsPrime);
//CHECK-NEXT: }
Expand Down Expand Up @@ -443,12 +443,12 @@ double func8(double i, double *arr, int n) {
//CHECK-NEXT: {
//CHECK-NEXT: arr[0] = _t2;
//CHECK-NEXT: double _r_d2 = _d_arr[0];
//CHECK-NEXT: _d_arr[0] -= _r_d2;
//CHECK-NEXT: _d_arr[0] = 0;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: res = _t1;
//CHECK-NEXT: double _r_d1 = _d_res;
//CHECK-NEXT: _d_res -= _r_d1;
//CHECK-NEXT: _d_res = 0;
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: int _r1 = 0;
//CHECK-NEXT: helper2_pullback(i, arr, n, _r_d1, &_r0, _d_arr, &_r1);
Expand All @@ -458,7 +458,7 @@ double func8(double i, double *arr, int n) {
//CHECK-NEXT: {
//CHECK-NEXT: arr[0] = _t0;
//CHECK-NEXT: double _r_d0 = _d_arr[0];
//CHECK-NEXT: _d_arr[0] -= _r_d0;
//CHECK-NEXT: _d_arr[0] = 0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -695,7 +695,7 @@ int main() {
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = *_d_elem;
//CHECK-NEXT: *_d_elem -= _r_d0;
//CHECK-NEXT: *_d_elem = 0;
//CHECK-NEXT: *_d_val += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand All @@ -708,7 +708,7 @@ int main() {
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = *_d_elem;
//CHECK-NEXT: *_d_elem -= _r_d0;
//CHECK-NEXT: *_d_elem = 0;
//CHECK-NEXT: *_d_elem += _r_d0 * elem;
//CHECK-NEXT: *_d_elem += elem * _r_d0;
//CHECK-NEXT: }
Expand Down
2 changes: 1 addition & 1 deletion test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ __device__ __host__ double gauss(double* x, double* p, double sigma, int dim) {
//CHECK-NEXT: {
//CHECK-NEXT: t = _t2;
//CHECK-NEXT: double _r_d1 = _d_t;
//CHECK-NEXT: _d_t -= _r_d1;
//CHECK-NEXT: _d_t = 0;
//CHECK-NEXT: _d_t += -_r_d1 / _t3;
//CHECK-NEXT: double _r0 = _r_d1 * -(-t / (_t3 * _t3));
//CHECK-NEXT: _d_sigma += 2 * _r0 * sigma;
Expand Down
14 changes: 7 additions & 7 deletions test/ErrorEstimation/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ float func(float x, float y) {
//CHECK-NEXT: {
//CHECK-NEXT: y = _t1;
//CHECK-NEXT: float _r_d1 = *_d_y;
//CHECK-NEXT: *_d_y -= _r_d1;
//CHECK-NEXT: *_d_y = 0;
//CHECK-NEXT: *_d_x += _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_x += _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
Expand All @@ -52,7 +52,7 @@ float func2(float x, int y) {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_y += _r_d0 * x;
//CHECK-NEXT: *_d_x += y * _r_d0;
//CHECK-NEXT: *_d_x += _r_d0 * x;
Expand All @@ -74,7 +74,7 @@ float func3(int x, int y) {
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: int _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand All @@ -96,7 +96,7 @@ float func4(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: _d_z += _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
Expand All @@ -122,7 +122,7 @@ float func5(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: _d_z += _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
Expand Down Expand Up @@ -164,7 +164,7 @@ float func8(int x, int y) {
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: int _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_y += _r_d0 * y;
//CHECK-NEXT: *_d_y += y * _r_d0;
//CHECK-NEXT: }
Expand Down
14 changes: 7 additions & 7 deletions test/ErrorEstimation/BasicOps.C
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ float func(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: y = _t1;
//CHECK-NEXT: float _r_d1 = *_d_y;
//CHECK-NEXT: *_d_y -= _r_d1;
//CHECK-NEXT: *_d_y = 0;
//CHECK-NEXT: *_d_y += _r_d1;
//CHECK-NEXT: *_d_y += _r_d1;
//CHECK-NEXT: y--;
Expand All @@ -43,7 +43,7 @@ float func(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_x += _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
Expand Down Expand Up @@ -76,7 +76,7 @@ float func2(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_x += _r_d0;
//CHECK-NEXT: *_d_y += -_r_d0;
//CHECK-NEXT: *_d_y += -_r_d0 * y;
Expand Down Expand Up @@ -116,7 +116,7 @@ float func3(float x, float y) {
//CHECK-NEXT: *_d_y += x * z * _d_t;
//CHECK-NEXT: y = _t2;
//CHECK-NEXT: float _r_d1 = *_d_y;
//CHECK-NEXT: *_d_y -= _r_d1;
//CHECK-NEXT: *_d_y = 0;
//CHECK-NEXT: *_d_x += _r_d1;
//CHECK-NEXT: *_d_x += _r_d1;
//CHECK-NEXT: }
Expand All @@ -125,7 +125,7 @@ float func3(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_x += _r_d0;
//CHECK-NEXT: *_d_y += -_r_d0;
//CHECK-NEXT: *_d_y += -_r_d0 * y;
Expand Down Expand Up @@ -173,7 +173,7 @@ float func5(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: float _r_d0 = *_d_y;
//CHECK-NEXT: *_d_y -= _r_d0;
//CHECK-NEXT: *_d_y = 0;
//CHECK-NEXT: float _r0 = 0;
//CHECK-NEXT: _r0 += _r_d0 * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward;
//CHECK-NEXT: *_d_x += _r0;
Expand Down Expand Up @@ -273,7 +273,7 @@ float func8(float x, float y) {
//CHECK-NEXT: {
//CHECK-NEXT: z = _t0;
//CHECK-NEXT: float _r_d0 = _d_z;
//CHECK-NEXT: _d_z -= _r_d0;
//CHECK-NEXT: _d_z = 0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: x = _t1;
//CHECK-NEXT: double _t2 = 0;
Expand Down
8 changes: 4 additions & 4 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,22 @@ float func(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: float _r_d0 = *_d_y;
//CHECK-NEXT: *_d_y -= _r_d0;
//CHECK-NEXT: *_d_y = 0;
//CHECK-NEXT: *_d_y += _r_d0 * x;
//CHECK-NEXT: *_d_x += y * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: } else {
//CHECK-NEXT: {
//CHECK-NEXT: x = _t2;
//CHECK-NEXT: float _r_d2 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d2;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_y += _r_d2;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(_d_temp * temp * {{.+}});
//CHECK-NEXT: temp = _t1;
//CHECK-NEXT: float _r_d1 = _d_temp;
//CHECK-NEXT: _d_temp -= _r_d1;
//CHECK-NEXT: _d_temp = 0;
//CHECK-NEXT: *_d_y += _r_d1 * y;
//CHECK-NEXT: *_d_y += y * _r_d1;
//CHECK-NEXT: }
Expand Down Expand Up @@ -171,7 +171,7 @@ float func4(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t1;
//CHECK-NEXT: float _r_d1 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d1;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: *_d_x += _r_d1 * x;
//CHECK-NEXT: *_d_x += x * _r_d1;
//CHECK-NEXT: }
Expand Down
14 changes: 7 additions & 7 deletions test/ErrorEstimation/LoopsAndArrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ float func2(float x) {
//CHECK-NEXT: _final_error += std::abs(_d_z * z * {{.+}});
//CHECK-NEXT: z = clad::pop(_t2);
//CHECK-NEXT: float _r_d0 = _d_z;
//CHECK-NEXT: _d_z -= _r_d0;
//CHECK-NEXT: _d_z = 0;
//CHECK-NEXT: _d_m += _r_d0;
//CHECK-NEXT: _d_m += _r_d0;
//CHECK-NEXT: }
Expand Down Expand Up @@ -135,23 +135,23 @@ float func3(float x, float y) {
//CHECK-NEXT: _final_error += std::abs(_d_arr[2] * arr[2] * {{.+}});
//CHECK-NEXT: arr[2] = _t2;
//CHECK-NEXT: double _r_d2 = _d_arr[2];
//CHECK-NEXT: _d_arr[2] -= _r_d2;
//CHECK-NEXT: _d_arr[2] = 0;
//CHECK-NEXT: _d_arr[0] += _r_d2;
//CHECK-NEXT: _d_arr[1] += _r_d2;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(_d_arr[1] * arr[1] * {{.+}});
//CHECK-NEXT: arr[1] = _t1;
//CHECK-NEXT: double _r_d1 = _d_arr[1];
//CHECK-NEXT: _d_arr[1] -= _r_d1;
//CHECK-NEXT: _d_arr[1] = 0;
//CHECK-NEXT: *_d_x += _r_d1 * x;
//CHECK-NEXT: *_d_x += x * _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(_d_arr[0] * arr[0] * {{.+}});
//CHECK-NEXT: arr[0] = _t0;
//CHECK-NEXT: double _r_d0 = _d_arr[0];
//CHECK-NEXT: _d_arr[0] -= _r_d0;
//CHECK-NEXT: _d_arr[0] = 0;
//CHECK-NEXT: *_d_x += _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
Expand Down Expand Up @@ -257,7 +257,7 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: _final_error += std::abs(_d_output[2] * output[2] * {{.+}});
//CHECK-NEXT: output[2] = _t2;
//CHECK-NEXT: double _r_d2 = _d_output[2];
//CHECK-NEXT: _d_output[2] -= _r_d2;
//CHECK-NEXT: _d_output[2] = 0;
//CHECK-NEXT: _d_x[0] += _r_d2 * y[1];
//CHECK-NEXT: x_size = std::max(x_size, 0);
//CHECK-NEXT: _d_y[1] += x[0] * _r_d2;
Expand All @@ -272,7 +272,7 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: _final_error += std::abs(_d_output[1] * output[1] * {{.+}});
//CHECK-NEXT: output[1] = _t1;
//CHECK-NEXT: double _r_d1 = _d_output[1];
//CHECK-NEXT: _d_output[1] -= _r_d1;
//CHECK-NEXT: _d_output[1] = 0;
//CHECK-NEXT: _d_x[2] += _r_d1 * y[0];
//CHECK-NEXT: x_size = std::max(x_size, 2);
//CHECK-NEXT: _d_y[0] += x[2] * _r_d1;
Expand All @@ -287,7 +287,7 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: _final_error += std::abs(_d_output[0] * output[0] * {{.+}});
//CHECK-NEXT: output[0] = _t0;
//CHECK-NEXT: double _r_d0 = _d_output[0];
//CHECK-NEXT: _d_output[0] -= _r_d0;
//CHECK-NEXT: _d_output[0] = 0;
//CHECK-NEXT: _d_x[1] += _r_d0 * y[2];
//CHECK-NEXT: x_size = std::max(x_size, 1);
//CHECK-NEXT: _d_y[2] += x[1] * _r_d0;
Expand Down
Loading

0 comments on commit f4dcf5c

Please sign in to comment.