diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 889c0532e..3d23720d4 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -375,6 +375,7 @@ namespace clad { virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL); + StmtDiff VisitCXXForRangeStmt(const clang::CXXForRangeStmt* FRS); StmtDiff VisitForStmt(const clang::ForStmt* FS); StmtDiff VisitIfStmt(const clang::IfStmt* If); StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 32727b0db..d9ddd7b37 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -974,6 +974,87 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(condExpr, ResultRef); } + StmtDiff + ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) { + beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | + Scope::ContinueScope); + beginBlock(direction::reverse); + + LoopCounter loopCounter(*this); + const auto* RangeDecl = FRS->getRangeStmt(); + const auto* BeginDecl = FRS->getBeginStmt(); + StmtDiff VisitRange = Visit(RangeDecl); + StmtDiff VisitBegin = Visit(BeginDecl); + Expr* BeginExpr = cast(VisitBegin.getStmt())->getLHS(); + + beginBlock(direction::reverse); + // Create all declarations needed. + Expr* d_BeginDeclRef = m_Variables[cast(BeginExpr)->getDecl()]; + + auto* RangeExpr = + cast(cast(VisitRange.getStmt())->getLHS()); + auto* BeginDeclRef = cast(BeginExpr); + + auto* RangeInit = Clone(FRS->getRangeInit()); + Expr* AssignRange = + BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit)); + Expr* AssignBegin = + BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr)); + addToCurrentBlock(AssignRange); + addToCurrentBlock(AssignBegin); + const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); + + Expr* EndInit = cast(EndDecl->getInit())->getRHS(); + QualType EndType = CloneType(EndDecl->getType()); + std::string EndName = EndDecl->getNameAsString(); + Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit); + VarDecl* EndVarDecl = + BuildGlobalVarDecl(EndType, EndName, EndAssign, false); + DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl); + + addToCurrentBlock(AssignEnd); + DeclRefExpr* EndExpr = + BuildDeclRef(cast(cast(AssignEnd)->getSingleDecl())); + + llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); + isInsideLoop = true; + const VarDecl* LoopVD = FRS->getLoopVariable(); + Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef); + Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef); + + Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef); + Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr); + // Add item assignment statement to the body. + Stmt* bodyClone = Clone(FRS->getBody()); + CompoundStmt* body = utils::PrependAndCreateCompoundStmt( + m_Sema.getASTContext(), bodyClone, + BuildDeclStmt(const_cast(LoopVD))); + StmtDiff BodyDiff = DifferentiateLoopBody( + body, loopCounter, /*condVarDiff*/ nullptr, d_DecBegin, + /*isForLoop=*/true); + + Expr* Inc = BuildOp(BO_Comma, IncBegin, d_IncBegin); + Stmt* Forward = new (m_Context) + ForStmt(m_Context, /*Init*/ nullptr, ForwardCond, /*CondVar*/ nullptr, + Inc, BodyDiff.getStmt(), FRS->getForLoc(), FRS->getBeginLoc(), + FRS->getEndLoc()); + Expr* CounterCondition = + loopCounter.getCounterConditionResult().get().second; + Expr* CounterDecrement = loopCounter.getCounterDecrement(); + Stmt* Reverse = BodyDiff.getStmt_dx(); + addToCurrentBlock(Reverse, direction::reverse); + Reverse = endBlock(direction::reverse); + Reverse = new (m_Context) + ForStmt(m_Context, /*Init*/ nullptr, CounterCondition, + /*CondVar*/ nullptr, CounterDecrement, Reverse, + FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); + addToCurrentBlock(Reverse, direction::reverse); + Reverse = endBlock(direction::reverse); + endScope(); + return {utils::unwrapIfSingleStmt(Forward), + utils::unwrapIfSingleStmt(Reverse)}; + } + StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) { beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | Scope::ContinueScope); diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index b61cf53b5..a536384d6 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2732,6 +2732,68 @@ double fn33(double i, double j) { //CHECK-NEXT: } //CHECK-NEXT:} +double fn34(double x, double y){ + double r = 0; + double a[] = {y, x*y, x*x + y}; + for(auto& i: a){ + r+=i; + } + return r; +} + +// CHECK: void fn34_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double _d_a[3] = {0}; +// CHECK-NEXT: unsigned {{int|long}} _t0; +// CHECK-NEXT: double (*_d___range1)[3] = 0; +// CHECK-NEXT: double (*__range10)[3] = {}; +// CHECK-NEXT: double *_d___begin1 = 0; +// CHECK-NEXT: double *__begin10 = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: double *_d_i = 0; +// CHECK-NEXT: double *i = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: double a[3] = {y, x * y, x * x + y}; +// CHECK-NEXT: _t0 = {{0U|0UL}}; +// CHECK-NEXT: _d___range1 = &_d_a; +// CHECK-NEXT: _d___begin1 = *_d___range1; +// CHECK-NEXT: __range10 = &a; +// CHECK-NEXT: __begin10 = *__range10; +// CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +// CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: _d_i = &*_d___begin1; +// CHECK-NEXT: clad::push(_t1, _d_i); +// CHECK-NEXT: clad::push(_t2, i) , i = &*__begin10; +// CHECK-NEXT: clad::push(_t3, r); +// CHECK-NEXT: r += *i; +// CHECK-NEXT: } +// CHECK-NEXT: _d_r += 1; +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: { +// CHECK-NEXT: _d___begin1--; +// CHECK-NEXT: _d_i = clad::pop(_t1); +// CHECK-NEXT: { +// CHECK-NEXT: r = clad::pop(_t3); +// CHECK-NEXT: double _r_d0 = _d_r; +// CHECK-NEXT: *_d_i += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: i = clad::pop(_t2); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: *_d_y += _d_a[0]; +// CHECK-NEXT: *_d_x += _d_a[1] * y; +// CHECK-NEXT: *_d_y += x * _d_a[1]; +// CHECK-NEXT: *_d_x += _d_a[2] * x; +// CHECK-NEXT: *_d_x += x * _d_a[2]; +// CHECK-NEXT: *_d_y += _d_a[2]; +// CHECK-NEXT: } +// CHECK-NEXT: } + + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -2816,6 +2878,7 @@ int main() { TEST_2(fn32, 3, 5); // CHECK-EXEC: {45.00, 27.00} TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00} + TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {