From 709bc623b8b9697634bf8cad3b9edd474220bdef Mon Sep 17 00:00:00 2001 From: Rohan Julka Date: Tue, 12 Mar 2024 00:50:25 +0000 Subject: [PATCH] Differentiate for loop condition expression (#746) --- .../clad/Differentiator/ReverseModeVisitor.h | 1 + lib/Differentiator/ReverseModeVisitor.cpp | 60 ++- test/Gradient/Assignments.C | 12 +- test/Gradient/Loops.C | 460 ++++++++++++++++++ 4 files changed, 508 insertions(+), 25 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 10564dbbf..00c54cbc6 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -495,6 +495,7 @@ namespace clad { LoopCounter& loopCounter, clang::Stmt* condVarDifff = nullptr, clang::Stmt* forLoopIncDiff = nullptr, + clang::Stmt* condDiff = nullptr, bool isForLoop = false); /// This class modifies forward and reverse blocks of the loop/switch diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index a181785ca..30656542f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1056,11 +1056,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } + auto CommaJoin = [this](Expr* Acc, Stmt* S) { + Expr* E = cast(S); + return BuildOp(BO_Comma, E, BuildParens(Acc)); + }; + // FIXME: for now we assume that cond has no differentiable effects, // but it is not generally true, e.g. for (...; (x = y); ...)... - StmtDiff cond; - if (FS->getCond()) - cond = Visit(FS->getCond()); + StmtDiff condDiff; + StmtDiff condExprDiff; + StmtDiff condDiffOuter; + StmtDiff condExprDiffOuter; + if (FS->getCond()) { + std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond()); + std::tie(condDiffOuter, condExprDiffOuter) = + DifferentiateSingleExpr(FS->getCond()); + } + const auto* IDRE = dyn_cast(FS->getInc()); const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc(); @@ -1088,10 +1100,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } // Otherwise, join all exprs by comma operator. else if (incExprDiff.getExpr()) { - auto CommaJoin = [this](Expr* Acc, Stmt* S) { - Expr* E = cast(S); - return BuildOp(BO_Comma, E, BuildParens(Acc)); - }; incResult = std::accumulate(Additional->body_rbegin(), Additional->body_rend(), incExprDiff.getExpr(), @@ -1099,16 +1107,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } const Stmt* body = FS->getBody(); - StmtDiff BodyDiff = DifferentiateLoopBody(body, loopCounter, - condVarRes.getStmt_dx(), - incDiff.getStmt_dx(), - /*isForLoop=*/true); + StmtDiff BodyDiff = + DifferentiateLoopBody(body, loopCounter, condVarRes.getStmt_dx(), + incDiff.getStmt_dx(), condDiff.getStmt_dx(), + /*isForLoop=*/true); /// FIXME: This part in necessary to replace local variables inside loops /// with function globals and replace initializations with assignments. /// This is a temporary measure to avoid the bug that arises from /// overwriting local variables on different loop passes. - Expr* forwardCond = cond.getExpr(); + Expr* forwardCond = condExprDiff.getExpr(); /// If there is a declaration in the condition, `cond` will be /// a DeclRefExpr of the declared variable. There is no point in /// inserting it since condVarRes.getExpr() represents an assignment with @@ -1118,8 +1126,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (condVarRes.getExpr() != nullptr && isa(condVarRes.getExpr())) forwardCond = cast(condVarRes.getExpr()); + auto* AdditionalStmts = cast(condDiff.getStmt()); + Expr* condResult = + std::accumulate(AdditionalStmts->body_rbegin(), + AdditionalStmts->body_rend(), forwardCond, CommaJoin); + Stmt* Forward = new (m_Context) - ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone, + ForStmt(m_Context, initResult.getStmt(), condResult, condVarClone, incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc); // Create a condition testing counter for being zero, and its decrement. @@ -1141,10 +1154,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, noLoc, noLoc, noLoc); + addToCurrentBlock(unwrapIfSingleStmt(condDiffOuter.getStmt())); addToCurrentBlock(Forward, direction::forward); Forward = endBlock(direction::forward); addToCurrentBlock(loopCounter.getPop(), direction::reverse); addToCurrentBlock(initResult.getStmt_dx(), direction::reverse); + addToCurrentBlock(condDiffOuter.getStmt_dx(), direction::reverse); addToCurrentBlock(Reverse, direction::reverse); Reverse = endBlock(direction::reverse); endScope(); @@ -2580,8 +2595,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else if (opCode == BO_Comma) { auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); - Ldiff = Visit(L, zero); Rdiff = Visit(R, dfdx()); + Ldiff = Visit(L, zero); valueForRevPass = Ldiff.getRevSweepAsExpr(); ResultRef = Ldiff.getExpr(); } else { @@ -3602,11 +3617,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {endBlock(direction::forward), endBlock(direction::reverse)}; } - StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, - LoopCounter& loopCounter, - Stmt* condVarDiff, - Stmt* forLoopIncDiff, - bool isForLoop) { + StmtDiff ReverseModeVisitor::DifferentiateLoopBody( + const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff, + Stmt* forLoopIncDiff, Stmt* condDiff, bool isForLoop) { Expr* counterIncrement = loopCounter.getCounterIncrement(); auto* activeBreakContHandler = PushBreakContStmtHandler(); activeBreakContHandler->BeginCFSwitchStmtScope(); @@ -3656,6 +3669,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } + if (condDiff) { + if (bodyDiff.getStmt_dx()) { + bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( + m_Context, bodyDiff.getStmt_dx(), condDiff)); + } else { + bodyDiff.updateStmtDx(condDiff); + } + } + activeBreakContHandler->EndCFSwitchStmtScope(); activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); PopBreakContStmtHandler(); diff --git a/test/Gradient/Assignments.C b/test/Gradient/Assignments.C index f209372ab..eb1b0522e 100644 --- a/test/Gradient/Assignments.C +++ b/test/Gradient/Assignments.C @@ -852,12 +852,12 @@ double f21 (double x, double y) { //CHECK-NEXT: _label0: //CHECK-NEXT: * _d_y += 1; //CHECK-NEXT: { -//CHECK-NEXT: y = _t0; -//CHECK-NEXT: double _r_d0 = * _d_y; -//CHECK-NEXT: * _d_y -= _r_d0; -//CHECK-NEXT: * _d_y += 0; -//CHECK-NEXT: y--; -//CHECK-NEXT: * _d_x += _r_d0; +// CHECK-NEXT: y = _t0; +// CHECK-NEXT: double _r_d0 = * _d_y; +// CHECK-NEXT: * _d_y -= _r_d0; +// CHECK-NEXT: * _d_x += _r_d0; +// CHECK-NEXT: * _d_y += 0; +// CHECK-NEXT: y--; //CHECK-NEXT: } //CHECK-NEXT: } diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 82de66836..455d8d0b1 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1626,6 +1626,459 @@ double f_loop_init_var(double lower, double upper) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn20(double i, double j) { + double res = 0; + for (int c = 0; (res = i * j); ++c) { + if (c == 1) + break; + } + return res; +} + +// CHECK: void fn20_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_c = 0; +// CHECK-NEXT: int c = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: clad::tape _t5 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: for (c = 0; clad::push(_t1, res) , (res = i * j); ++c) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: bool _t3 = c == 1; +// CHECK-NEXT: { +// CHECK-NEXT: if (_t3) { +// CHECK-NEXT: clad::push(_t5, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t4, _t3); +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t5, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: for (; _t0; _t0--) +// CHECK-NEXT: switch (clad::pop(_t5)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res -= _r_d0; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: --c; +// CHECK-NEXT: if (clad::pop(_t4)) +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res -= _r_d1; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn21(double i, double j) { + double res = 0; + for (int c = 0; (res += i * j), c<3; ++c) { + } + return res; +} + +// CHECK: void fn21_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_c = 0; +// CHECK-NEXT: int c = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: for (c = 0; clad::push(_t1, res) , ((res += i * j) , c < 3); ++c) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: { +// CHECK-NEXT: _d_res += 0; +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: --c; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _d_res += 0; +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn22(double i, double j) { + double res = 0; + for (int c = 0; (res += i * j), c<3; ++c) { + if (c == 1) { + res = 8 * i * j; + break; + } + } + return res; +} + +// CHECK: void fn22_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_c = 0; +// CHECK-NEXT: int c = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: clad::tape _t5 = {}; +// CHECK-NEXT: clad::tape _t6 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: for (c = 0; clad::push(_t1, res) , ((res += i * j) , c < 3); ++c) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: bool _t3 = c == 1; +// CHECK-NEXT: { +// CHECK-NEXT: if (_t3) { +// CHECK-NEXT: clad::push(_t5, res); +// CHECK-NEXT: res = 8 * i * j; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t6, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t4, _t3); +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t6, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: for (; _t0; _t0--) +// CHECK-NEXT: switch (clad::pop(_t6)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: _d_res += 0; +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: --c; +// CHECK-NEXT: if (clad::pop(_t4)) { +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t5); +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res -= _r_d2; +// CHECK-NEXT: * _d_i += 8 * _r_d2 * j; +// CHECK-NEXT: * _d_j += 8 * i * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _d_res += 0; +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn23(double i, double j) { + double res = 0; + for (int c = 0; (res += i * j); ++c, res=7*i*j) { + if(c == 0) + break; + } + return res; +} + +// CHECK: void fn23_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_c = 0; +// CHECK-NEXT: int c = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: clad::tape _t5 = {}; +// CHECK-NEXT: clad::tape _t6 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: for (c = 0; clad::push(_t1, res) , (res += i * j); clad::push(_t3, res) , (++c , res = 7 * i * j)) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: bool _t4 = c == 0; +// CHECK-NEXT: { +// CHECK-NEXT: if (_t4) { +// CHECK-NEXT: clad::push(_t6, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t5, _t4); +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t6, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: for (; _t0; _t0--) +// CHECK-NEXT: switch (clad::pop(_t6)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t3); +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res -= _r_d2; +// CHECK-NEXT: * _d_i += 7 * _r_d2 * j; +// CHECK-NEXT: * _d_j += 7 * i * _r_d2; +// CHECK-NEXT: _d_c += 0; +// CHECK-NEXT: --c; +// CHECK-NEXT: } +// CHECK-NEXT: if (clad::pop(_t5)) +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn24(double i, double j) { + double res = 0; + for (int c = 0; (res += i * j); ++c) { + if(c == 2) + break; + + res = c * i * j; + } + return res; +} + +// CHECK: void fn24_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_c = 0; +// CHECK-NEXT: int c = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: clad::tape _t5 = {}; +// CHECK-NEXT: clad::tape _t6 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: for (c = 0; clad::push(_t1, res) , (res += i * j); ++c) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: bool _t3 = c == 2; +// CHECK-NEXT: { +// CHECK-NEXT: if (_t3) { +// CHECK-NEXT: clad::push(_t5, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t4, _t3); +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t6, res); +// CHECK-NEXT: res = c * i * j; +// CHECK-NEXT: clad::push(_t5, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: for (; _t0; _t0--) +// CHECK-NEXT: switch (clad::pop(_t5)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: --c; +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t6); +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res -= _r_d2; +// CHECK-NEXT: _d_c += _r_d2 * j * i; +// CHECK-NEXT: * _d_i += c * _r_d2 * j; +// CHECK-NEXT: * _d_j += c * i * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (clad::pop(_t4)) +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn25(double i, double j) { + double res = 0; + for (int c = 0; (res = i * j), c < 2; ++c) { + res = 3 * i * j; + } + return res; +} + +// CHECK: void fn25_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_c = 0; +// CHECK-NEXT: int c = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: for (c = 0; clad::push(_t1, res) , ((res = i * j) , c < 2); ++c) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: clad::push(_t3, res); +// CHECK-NEXT: res = 3 * i * j; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: { +// CHECK-NEXT: _d_res += 0; +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res -= _r_d0; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: --c; +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t3); +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res -= _r_d2; +// CHECK-NEXT: * _d_i += 3 * _r_d2 * j; +// CHECK-NEXT: * _d_j += 3 * i * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _d_res += 0; +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res -= _r_d1; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn26(double i, double j) { + double res = 0; + for (int c = 0; res = i * j, c<1; ++c, res=3*i*j) {} + return res; +} + +// CHECK: void fn26_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_c = 0; +// CHECK-NEXT: int c = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: for (c = 0; clad::push(_t1, res) , (res = i * j , c < 1); clad::push(_t3, res) , (++c , res = 3 * i * j)) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: { +// CHECK-NEXT: _d_res += 0; +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res -= _r_d0; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t3); +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res -= _r_d2; +// CHECK-NEXT: * _d_i += 3 * _r_d2 * j; +// CHECK-NEXT: * _d_j += 3 * i * _r_d2; +// CHECK-NEXT: _d_c += 0; +// CHECK-NEXT: --c; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _d_res += 0; +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res -= _r_d1; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -1692,4 +2145,11 @@ int main() { TEST_GRADIENT(fn19, 1, arr, 5, d_arr); TEST_2(f_loop_init_var, 1, 2); // CHECK-EXEC: {-1.00, 4.00} + TEST_2(fn20, 3, 5); // CHECK-EXEC: {5.00, 3.00} + TEST_2(fn21, 3, 5); // CHECK-EXEC: {20.00, 12.00} + TEST_2(fn22, 3, 5); // CHECK-EXEC: {40.00, 24.00} + TEST_2(fn23, 3, 5); // CHECK-EXEC: {5.00, 3.00} + TEST_2(fn24, 3, 5); // CHECK-EXEC: {10.00, 6.00} + TEST_2(fn25, 3, 5); // CHECK-EXEC: {5.00, 3.00} + TEST_2(fn26, 3, 5); // CHECK-EXEC: {5.00, 3.00} }