diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 17e6ba6a4..74428186e 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -85,6 +85,7 @@ class BaseForwardModeVisitor set to 0", args); } + StmtDiff VisitCXXForRangeStmt(const clang::CXXForRangeStmt* FRS); StmtDiff VisitWhileStmt(const clang::WhileStmt* WS); StmtDiff VisitDoStmt(const clang::DoStmt* DS); StmtDiff VisitContinueStmt(const clang::ContinueStmt* ContStmt); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 56a192335..ce23eb1fa 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -683,6 +683,91 @@ StmtDiff BaseForwardModeVisitor::VisitConditionalOperator( return StmtDiff(condExpr, condExprDiff); } +StmtDiff +BaseForwardModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) { + beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | + Scope::ContinueScope); + // Visiting for range-based ststement produces __range1, __begin1 and __end1 + // variables, so for(auto i: a){ + // ... + //} + // + // is equivalent to + // + // auto&& __range1 = a + // auto __begin1 = __range1; + // auto __end1 = __range1 + OUL + // for(;__begin != __end1; ++__begin){ + // auto i = *__begin1; + // ... + //} + // + const VarDecl* Item = FRS->getLoopVariable(); + + const auto* RangeDecl = + cast(FRS->getRangeStmt()->getSingleDecl()); + const auto* BeginDecl = + cast(FRS->getBeginStmt()->getSingleDecl()); + const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); + + // Add range, begin and their's adjoints to the current block. + DeclDiff RangeDeclDiff = DifferentiateVarDecl(RangeDecl); + addToCurrentBlock(BuildDeclStmt(RangeDeclDiff.getDecl_dx())); + addToCurrentBlock(BuildDeclStmt(RangeDeclDiff.getDecl())); + + DeclDiff BeginDeclDiff = DifferentiateVarDecl(BeginDecl); + addToCurrentBlock(BuildDeclStmt(BeginDeclDiff.getDecl_dx())); + addToCurrentBlock(BuildDeclStmt(BeginDeclDiff.getDecl())); + + VarDecl* EndDeclDiff = DifferentiateVarDecl(EndDecl).getDecl(); + addToCurrentBlock(BuildDeclStmt(EndDeclDiff)); + + // Build begin preincrementation. + DeclRefExpr* BeginExpr = BuildDeclRef(BeginDeclDiff.getDecl()); + Expr* IncBegin = BuildOp(UO_PreInc, BeginExpr); + + // Build begin preincrementation. + DeclRefExpr* d_BeginExpr = BuildDeclRef(BeginDeclDiff.getDecl_dx()); + Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginExpr); + + Expr* Inc = BuildOp(BO_Comma, d_IncBegin, IncBegin); + + // Build begin != end expretion. + DeclRefExpr* EndExpr = BuildDeclRef(EndDeclDiff); + Expr* cond = BuildOp(BO_NE, BeginExpr, EndExpr); + + DeclDiff ItemDiff = DifferentiateVarDecl(Item); + + // Differentiate body and add both Item and it's derivative. + const Stmt* body = FRS->getBody(); + Stmt* bodyResult = nullptr; + if (isa(body)) { + bodyResult = Visit(body).getStmt(); + Stmt* bodyWithItem1 = utils::PrependAndCreateCompoundStmt( + m_Sema.getASTContext(), bodyResult, BuildDeclStmt(ItemDiff.getDecl())); + bodyResult = utils::PrependAndCreateCompoundStmt( + m_Sema.getASTContext(), bodyWithItem1, + BuildDeclStmt(ItemDiff.getDecl_dx())); + } else { + beginScope(Scope::DeclScope); + beginBlock(); + addToCurrentBlock(BuildDeclStmt(ItemDiff.getDecl_dx())); + addToCurrentBlock(BuildDeclStmt(ItemDiff.getDecl())); + StmtDiff Result = Visit(body); + for (Stmt* S : Result.getBothStmts()) + addToCurrentBlock(S); + CompoundStmt* Block = endBlock(); + endScope(); + bodyResult = Block; + } + + Stmt* forStmtDiff = + new (m_Context) ForStmt(m_Context, nullptr, cond, /*condVar=*/nullptr, + Inc, bodyResult, noLoc, noLoc, noLoc); + + return StmtDiff(forStmtDiff); +} + StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | Scope::ContinueScope); diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 152984bcd..f670c5519 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -548,6 +548,109 @@ double fn17_darg0(double x); // CHECK-NEXT: return _d_x; // CHECK-NEXT: } +double fn18(double x, double y){ + int coefficients[3] = {4, 7, 3}; + double res = 0; + for(auto i: coefficients){ + if(i%2==0) + continue; + res+= x*y*i; + } + return res; +} + +double fn18_darg0(double x, double y); +// CHECK: double fn18_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: int _d_coefficients[3] = {0, 0, 0}; +// CHECK-NEXT: int coefficients[3] = {4, 7, 3}; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: int (&_d___range1)[3] = _d_coefficients; +// CHECK-NEXT: int (&__range10)[3] = coefficients; +// CHECK-NEXT: int *_d___begin1 = _d___range1; +// CHECK-NEXT: int *__begin10 = __range10; +// CHECK-NEXT: int *__end10 = __range10 + {{3|3L}}; +// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) { +// CHECK-NEXT: int _d_i = *_d___begin1; +// CHECK-NEXT: int i = *__begin10; +// CHECK-NEXT: if (i % 2 == 0) +// CHECK-NEXT: continue; +// CHECK-NEXT: double _t0 = x * y; +// CHECK-NEXT: _d_res += (_d_x * y + x * _d_y) * i + _t0 * _d_i; +// CHECK-NEXT: res += _t0 * i; +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + + +double fn19(double x, double y){ + double res = 0; + double f[] = {x*x, 2*x*y, y*y, x, y}; + for(auto i: f){ + if(i == x) + break; + res+=i; + } + return res; +} + +double fn19_darg0(double x, double y); +// CHECK: double fn19_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: double _t0 = 2 * x; +// CHECK-NEXT: double _d_f[5] = {_d_x * x + x * _d_x, (0 * x + 2 * _d_x) * y + _t0 * _d_y, _d_y * y + y * _d_y, _d_x, _d_y}; +// CHECK-NEXT: double f[5] = {x * x, _t0 * y, y * y, x, y}; +// CHECK-NEXT: double (&_d___range1)[5] = _d_f; +// CHECK-NEXT: double (&__range10)[5] = f; +// CHECK-NEXT: double *_d___begin1 = _d___range1; +// CHECK-NEXT: double *__begin10 = __range10; +// CHECK-NEXT: double *__end10 = __range10 + {{5|5L}}; +// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) { +// CHECK-NEXT: double _d_i = *_d___begin1; +// CHECK-NEXT: double i = *__begin10; +// CHECK-NEXT: if (i == x) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_res += _d_i; +// CHECK-NEXT: res += i; +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + +double fn20(double x){ + int a[] = {5}; + for(auto i: a){ + x+=i*x; + } + return x; +} + +// CHECK: double fn20_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: int _d_a[1] = {0}; +// CHECK-NEXT: int a[1] = {5}; +// CHECK-NEXT: int (&_d___range1)[1] = _d_a; +// CHECK-NEXT: int (&__range10)[1] = a; +// CHECK-NEXT: int *_d___begin1 = _d___range1; +// CHECK-NEXT: int *__begin10 = __range10; +// CHECK-NEXT: int *__end10 = __range10 + {{1|1L}}; +// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) { +// CHECK-NEXT: int _d_i = *_d___begin1; +// CHECK-NEXT: int i = *__begin10; +// CHECK-NEXT: _d_x += _d_i * x + i * _d_x; +// CHECK-NEXT: x += i * x; +// CHECK-NEXT: } +// CHECK-NEXT: return _d_x; +// CHECK-NEXT: } + + +double fn20_darg0(double x); + + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -614,4 +717,13 @@ int main() { clad::differentiate(fn17, 0); printf("Result is = %.2f\n", fn17_darg0(5)); // CHECK-EXEC: Result is = 0 + + clad::differentiate(fn18, 0); + printf("Result is = %.2f\n", fn18_darg0(5, 1)); // CHECK-EXEC: Result is = 10.00 + + clad::differentiate(fn19, 0); + printf("Result is = %.2f\n", fn19_darg0(5, 2)); // CHECK-EXEC: Result is = 14.00 + + clad::differentiate(fn20, 0); + printf("Result is = %.2f\n", fn20_darg0(5)); // CHECK-EXEC: Result is = 6.00 }