diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 507698135..2c956f857 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -87,6 +87,7 @@ class BaseForwardModeVisitor set to 0", args); } + StmtDiff VisitCXXForRangeStmt(const clang::CXXForRangeStmt* FRS); StmtDiff VisitWhileStmt(const clang::WhileStmt* WS); StmtDiff VisitDoStmt(const clang::DoStmt* DS); StmtDiff VisitContinueStmt(const clang::ContinueStmt* ContStmt); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index b6eebb8b2..b331c7d99 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -683,6 +683,71 @@ StmtDiff BaseForwardModeVisitor::VisitConditionalOperator( return StmtDiff(condExpr, condExprDiff); } +StmtDiff +BaseForwardModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) { + beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | + Scope::ContinueScope); + // Visiting for range-based ststement produces __range1, __begin1 and __end1 + // variables, so for(auto i: a){ + // ... + //} + // + // is equivalent to + // + // auto&& __range1 = a + // auto __begin1 = __range1; + // auto __end1 = __range1 + OUL + // for(;__begin != __end1; ++__begin){ + // auto i = *__begin1; + // ... + //} + const auto* RangeDecl = FRS->getRangeStmt(); + const auto* BeginDecl = FRS->getBeginStmt(); + const auto* EndDecl = FRS->getEndStmt(); + + StmtDiff VisitRange = Visit(RangeDecl); + StmtDiff VisitBegin = Visit(BeginDecl); + StmtDiff VisitEnd = Visit(EndDecl); + addToCurrentBlock(VisitRange.getStmt_dx()); + addToCurrentBlock(VisitRange.getStmt()); + addToCurrentBlock(VisitBegin.getStmt_dx()); + addToCurrentBlock(VisitBegin.getStmt()); + addToCurrentBlock(VisitEnd.getStmt()); + // Build d_begin preincrementation. + + auto* BeginAdjExpr = BuildDeclRef( + cast(cast(VisitBegin.getStmt_dx())->getSingleDecl())); + // Build begin preincrementation. + + Expr* IncAdjBegin = BuildOp(UO_PreInc, BeginAdjExpr); + auto* BeginExpr = BuildDeclRef( + cast(cast(VisitBegin.getStmt())->getSingleDecl())); + Expr* IncBegin = BuildOp(UO_PreInc, BeginExpr); + Expr* Inc = BuildOp(BO_Comma, IncAdjBegin, IncBegin); + + auto* EndExpr = BuildDeclRef( + cast(cast(VisitEnd.getStmt())->getSingleDecl())); + // Build begin != end condition. + Expr* cond = BuildOp(BO_NE, BeginExpr, EndExpr); + + const VarDecl* Item = FRS->getLoopVariable(); + DeclDiff ItemDiff = DifferentiateVarDecl(Item); + // Differentiate body and add both Item and it's derivative. + const Stmt* body = FRS->getBody(); + Stmt* bodyResult = nullptr; + bodyResult = Visit(Clone(body)).getStmt(); + Stmt* bodyWithItem = utils::PrependAndCreateCompoundStmt( + m_Sema.getASTContext(), bodyResult, BuildDeclStmt(ItemDiff.getDecl())); + bodyResult = + utils::PrependAndCreateCompoundStmt(m_Sema.getASTContext(), bodyWithItem, + BuildDeclStmt(ItemDiff.getDecl_dx())); + + Stmt* forStmtDiff = + new (m_Context) ForStmt(m_Context, nullptr, cond, /*condVar=*/nullptr, + Inc, bodyResult, noLoc, noLoc, noLoc); + return StmtDiff(forStmtDiff); +} + StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | Scope::ContinueScope); diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 5534291b5..8605e53f4 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -582,6 +582,106 @@ double fn18(double u, double v) { // CHECK-NEXT: } +double fn19(double x, double y){ + double res = 0; + double f[] = {x*x, 2*x*y, y*y, x, y}; + for(auto i: f){ + if(i == x) + break; + res+=i; + } + return res; +} + +double fn19_darg0(double x, double y); +// CHECK: double fn19_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: double _t0 = 2 * x; +// CHECK-NEXT: double _d_f[5] = {_d_x * x + x * _d_x, (0 * x + 2 * _d_x) * y + _t0 * _d_y, _d_y * y + y * _d_y, _d_x, _d_y}; +// CHECK-NEXT: double f[5] = {x * x, _t0 * y, y * y, x, y}; +// CHECK-NEXT: double (&_d___range1)[5] = _d_f; +// CHECK-NEXT: double (&__range10)[5] = f; +// CHECK-NEXT: double *_d___begin1 = _d___range1; +// CHECK-NEXT: double *__begin10 = __range10; +// CHECK-NEXT: double *__end10 = __range10 + {{5|5L}}; +// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) { +// CHECK-NEXT: double _d_i = *_d___begin1; +// CHECK-NEXT: double i = *__begin10; +// CHECK-NEXT: if (i == x) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_res += _d_i; +// CHECK-NEXT: res += i; +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + +double fn20(double x){ + int a[] = {5}; + for(auto i: a){ + x+=i*x; + } + return x; +} + +double fn20_darg0(double x); +// CHECK: double fn20_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: int _d_a[1] = {0}; +// CHECK-NEXT: int a[1] = {5}; +// CHECK-NEXT: int (&_d___range1)[1] = _d_a; +// CHECK-NEXT: int (&__range10)[1] = a; +// CHECK-NEXT: int *_d___begin1 = _d___range1; +// CHECK-NEXT: int *__begin10 = __range10; +// CHECK-NEXT: int *__end10 = __range10 + {{1|1L}}; +// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) { +// CHECK-NEXT: int _d_i = *_d___begin1; +// CHECK-NEXT: int i = *__begin10; +// CHECK-NEXT: _d_x += _d_i * x + i * _d_x; +// CHECK-NEXT: x += i * x; +// CHECK-NEXT: } +// CHECK-NEXT: return _d_x; +// CHECK-NEXT: } + +double fn21(double x, double y){ + int coefficients[3] = {4, 7, 3}; + double res = 0; + for(auto& i: coefficients){ + if(i%2==0) + continue; + res+= x*y*i; + } + return res; +} + +double fn21_darg0(double x, double y); +// CHECK: double fn21_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: int _d_coefficients[3] = {0, 0, 0}; +// CHECK-NEXT: int coefficients[3] = {4, 7, 3}; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: int (&_d___range1)[3] = _d_coefficients; +// CHECK-NEXT: int (&__range10)[3] = coefficients; +// CHECK-NEXT: int *_d___begin1 = _d___range1; +// CHECK-NEXT: int *__begin10 = __range10; +// CHECK-NEXT: int *__end10 = __range10 + {{3|3L}}; +// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) { +// CHECK-NEXT: int &_d_i = *_d___begin1; +// CHECK-NEXT: int &i = *__begin10; +// CHECK-NEXT: if (i % 2 == 0) +// CHECK-NEXT: continue; +// CHECK-NEXT: double _t0 = x * y; +// CHECK-NEXT: _d_res += (_d_x * y + x * _d_y) * i + _t0 * _d_i; +// CHECK-NEXT: res += _t0 * i; +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ @@ -652,4 +752,13 @@ int main() { INIT_DIFFERENTIATE(fn18, "u"); TEST_DIFFERENTIATE(fn18, 3, 5); // CHECK-EXEC: {6.00} + + clad::differentiate(fn19, 0); + printf("Result is = %.2f\n", fn19_darg0(5, 2)); // CHECK-EXEC: Result is = 14.00 + + clad::differentiate(fn20, 0); + printf("Result is = %.2f\n", fn20_darg0(5)); // CHECK-EXEC: Result is = 6.00 + + clad::differentiate(fn21, 0); + printf("Result is = %.2f\n", fn21_darg0(5, 1)); // CHECK-EXEC: Result is = 10.00 }