Skip to content

Commit

Permalink
Adding support for range-based for loop in forward mode
Browse files Browse the repository at this point in the history
Adding support for the most simple cases of the range-based for loops in forward mode.

Fixes: vgvassilev#723
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Jul 26, 2024
1 parent 69d15a3 commit fe9ad04
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class BaseForwardModeVisitor
set to 0",
args);
}
StmtDiff VisitCXXForRangeStmt(const clang::CXXForRangeStmt* FRS);
StmtDiff VisitWhileStmt(const clang::WhileStmt* WS);
StmtDiff VisitDoStmt(const clang::DoStmt* DS);
StmtDiff VisitContinueStmt(const clang::ContinueStmt* ContStmt);
Expand Down
65 changes: 65 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,71 @@ StmtDiff BaseForwardModeVisitor::VisitConditionalOperator(
return StmtDiff(condExpr, condExprDiff);
}

StmtDiff
BaseForwardModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
// Visiting for range-based ststement produces __range1, __begin1 and __end1
// variables, so for(auto i: a){
// ...
//}
//
// is equivalent to
//
// auto&& __range1 = a
// auto __begin1 = __range1;
// auto __end1 = __range1 + OUL
// for(;__begin != __end1; ++__begin){
// auto i = *__begin1;
// ...
//}
const auto* RangeDecl = FRS->getRangeStmt();
const auto* BeginDecl = FRS->getBeginStmt();
const auto* EndDecl = FRS->getEndStmt();

StmtDiff VisitRange = Visit(RangeDecl);
StmtDiff VisitBegin = Visit(BeginDecl);
StmtDiff VisitEnd = Visit(EndDecl);
addToCurrentBlock(VisitRange.getStmt_dx());
addToCurrentBlock(VisitRange.getStmt());
addToCurrentBlock(VisitBegin.getStmt_dx());
addToCurrentBlock(VisitBegin.getStmt());
addToCurrentBlock(VisitEnd.getStmt());
// Build d_begin preincrementation.

auto* BeginAdjExpr = BuildDeclRef(
cast<VarDecl>(cast<DeclStmt>(VisitBegin.getStmt_dx())->getSingleDecl()));
// Build begin preincrementation.

Expr* IncAdjBegin = BuildOp(UO_PreInc, BeginAdjExpr);
auto* BeginExpr = BuildDeclRef(
cast<VarDecl>(cast<DeclStmt>(VisitBegin.getStmt())->getSingleDecl()));
Expr* IncBegin = BuildOp(UO_PreInc, BeginExpr);
Expr* Inc = BuildOp(BO_Comma, IncAdjBegin, IncBegin);

auto* EndExpr = BuildDeclRef(
cast<VarDecl>(cast<DeclStmt>(VisitEnd.getStmt())->getSingleDecl()));
// Build begin != end condition.
Expr* cond = BuildOp(BO_NE, BeginExpr, EndExpr);

const VarDecl* Item = FRS->getLoopVariable();
DeclDiff<VarDecl> ItemDiff = DifferentiateVarDecl(Item);
// Differentiate body and add both Item and it's derivative.
const Stmt* body = FRS->getBody();
Stmt* bodyResult = nullptr;
bodyResult = Visit(Clone(body)).getStmt();
Stmt* bodyWithItem = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), bodyResult, BuildDeclStmt(ItemDiff.getDecl()));
bodyResult =
utils::PrependAndCreateCompoundStmt(m_Sema.getASTContext(), bodyWithItem,
BuildDeclStmt(ItemDiff.getDecl_dx()));

Stmt* forStmtDiff =
new (m_Context) ForStmt(m_Context, nullptr, cond, /*condVar=*/nullptr,
Inc, bodyResult, noLoc, noLoc, noLoc);
return StmtDiff(forStmtDiff);
}

StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
Expand Down
109 changes: 109 additions & 0 deletions test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,106 @@ double fn18(double u, double v) {
// CHECK-NEXT: }


double fn19(double x, double y){
double res = 0;
double f[] = {x*x, 2*x*y, y*y, x, y};
for(auto i: f){
if(i == x)
break;
res+=i;
}
return res;
}

double fn19_darg0(double x, double y);
// CHECK: double fn19_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: double _t0 = 2 * x;
// CHECK-NEXT: double _d_f[5] = {_d_x * x + x * _d_x, (0 * x + 2 * _d_x) * y + _t0 * _d_y, _d_y * y + y * _d_y, _d_x, _d_y};
// CHECK-NEXT: double f[5] = {x * x, _t0 * y, y * y, x, y};
// CHECK-NEXT: double (&_d___range1)[5] = _d_f;
// CHECK-NEXT: double (&__range10)[5] = f;
// CHECK-NEXT: double *_d___begin1 = _d___range1;
// CHECK-NEXT: double *__begin10 = __range10;
// CHECK-NEXT: double *__end10 = __range10 + {{5|5L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) {
// CHECK-NEXT: double _d_i = *_d___begin1;
// CHECK-NEXT: double i = *__begin10;
// CHECK-NEXT: if (i == x)
// CHECK-NEXT: break;
// CHECK-NEXT: _d_res += _d_i;
// CHECK-NEXT: res += i;
// CHECK-NEXT: }
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

double fn20(double x){
int a[] = {5};
for(auto i: a){
x+=i*x;
}
return x;
}

double fn20_darg0(double x);
// CHECK: double fn20_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: int _d_a[1] = {0};
// CHECK-NEXT: int a[1] = {5};
// CHECK-NEXT: int (&_d___range1)[1] = _d_a;
// CHECK-NEXT: int (&__range10)[1] = a;
// CHECK-NEXT: int *_d___begin1 = _d___range1;
// CHECK-NEXT: int *__begin10 = __range10;
// CHECK-NEXT: int *__end10 = __range10 + {{1|1L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) {
// CHECK-NEXT: int _d_i = *_d___begin1;
// CHECK-NEXT: int i = *__begin10;
// CHECK-NEXT: _d_x += _d_i * x + i * _d_x;
// CHECK-NEXT: x += i * x;
// CHECK-NEXT: }
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

double fn21(double x, double y){
int coefficients[3] = {4, 7, 3};
double res = 0;
for(auto& i: coefficients){
if(i%2==0)
continue;
res+= x*y*i;
}
return res;
}

double fn21_darg0(double x, double y);
// CHECK: double fn21_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: int _d_coefficients[3] = {0, 0, 0};
// CHECK-NEXT: int coefficients[3] = {4, 7, 3};
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: int (&_d___range1)[3] = _d_coefficients;
// CHECK-NEXT: int (&__range10)[3] = coefficients;
// CHECK-NEXT: int *_d___begin1 = _d___range1;
// CHECK-NEXT: int *__begin10 = __range10;
// CHECK-NEXT: int *__end10 = __range10 + {{3|3L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) {
// CHECK-NEXT: int &_d_i = *_d___begin1;
// CHECK-NEXT: int &i = *__begin10;
// CHECK-NEXT: if (i % 2 == 0)
// CHECK-NEXT: continue;
// CHECK-NEXT: double _t0 = x * y;
// CHECK-NEXT: _d_res += (_d_x * y + x * _d_y) * i + _t0 * _d_i;
// CHECK-NEXT: res += _t0 * i;
// CHECK-NEXT: }
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }



#define TEST(fn)\
auto d_##fn = clad::differentiate(fn, "i");\
Expand Down Expand Up @@ -652,4 +752,13 @@ int main() {

INIT_DIFFERENTIATE(fn18, "u");
TEST_DIFFERENTIATE(fn18, 3, 5); // CHECK-EXEC: {6.00}

clad::differentiate(fn19, 0);
printf("Result is = %.2f\n", fn19_darg0(5, 2)); // CHECK-EXEC: Result is = 14.00

clad::differentiate(fn20, 0);
printf("Result is = %.2f\n", fn20_darg0(5)); // CHECK-EXEC: Result is = 6.00

clad::differentiate(fn21, 0);
printf("Result is = %.2f\n", fn21_darg0(5, 1)); // CHECK-EXEC: Result is = 10.00
}

0 comments on commit fe9ad04

Please sign in to comment.