From 1cdea66a41ca18370a27345cfbcba9017a72c9ba Mon Sep 17 00:00:00 2001 From: Rohan Julka Date: Tue, 12 Mar 2024 00:50:25 +0000 Subject: [PATCH] Differentiate for loop condition expression (#746) --- lib/Differentiator/ReverseModeVisitor.cpp | 12 +++-- test/Gradient/Loops.C | 59 +++++++++++++++++++++++ 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index a181785ca..e87a459a1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1058,9 +1058,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: for now we assume that cond has no differentiable effects, // but it is not generally true, e.g. for (...; (x = y); ...)... - StmtDiff cond; - if (FS->getCond()) - cond = Visit(FS->getCond()); + StmtDiff condDiff; + StmtDiff condExprDiff; + if (FS->getCond()) { + std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond()); + addToCurrentBlock(unwrapIfSingleStmt(condDiff.getStmt())); + } const auto* IDRE = dyn_cast(FS->getInc()); const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc(); @@ -1108,7 +1111,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /// with function globals and replace initializations with assignments. /// This is a temporary measure to avoid the bug that arises from /// overwriting local variables on different loop passes. - Expr* forwardCond = cond.getExpr(); + Expr* forwardCond = condExprDiff.getExpr(); /// If there is a declaration in the condition, `cond` will be /// a DeclRefExpr of the declared variable. There is no point in /// inserting it since condVarRes.getExpr() represents an assignment with @@ -1145,6 +1148,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Forward = endBlock(direction::forward); addToCurrentBlock(loopCounter.getPop(), direction::reverse); addToCurrentBlock(initResult.getStmt_dx(), direction::reverse); + addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse); addToCurrentBlock(Reverse, direction::reverse); Reverse = endBlock(direction::reverse); endScope(); diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 82de66836..39ceb3422 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1626,6 +1626,63 @@ double f_loop_init_var(double lower, double upper) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn20(double i, double j) { + double res = 0; + for (int c = 0; (res = i * j); ++c) { + if (c == 1) + break; + } + return res; +} + +// CHECK: void fn20_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_c = 0; +// CHECK-NEXT: int c = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, res); +// CHECK-NEXT: for (c = 0; (res = i * j); ++c) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: bool _t2 = c == 1; +// CHECK-NEXT: { +// CHECK-NEXT: if (_t2) { +// CHECK-NEXT: clad::push(_t4, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t3, _t2); +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t4, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: for (; _t0; _t0--) +// CHECK-NEXT: switch (clad::pop(_t4)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: --c; +// CHECK-NEXT: if (clad::pop(_t3)) +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res -= _r_d0; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -1692,4 +1749,6 @@ int main() { TEST_GRADIENT(fn19, 1, arr, 5, d_arr); TEST_2(f_loop_init_var, 1, 2); // CHECK-EXEC: {-1.00, 4.00} + TEST_2(fn20, 3, 5); // CHECK-EXEC: {5.00, 3.00} + }