Skip to content

Commit

Permalink
Adding support for range-based for loop
Browse files Browse the repository at this point in the history
Adding support for the most simple cases of the range-based for loops.

Fixes: vgvassilev#723
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Jun 24, 2024
1 parent a0b29f6 commit 7a240c3
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class BaseForwardModeVisitor
set to 0",
args);
}
StmtDiff VisitCXXForRangeStmt(const clang::CXXForRangeStmt* FRS);
StmtDiff VisitWhileStmt(const clang::WhileStmt* WS);
StmtDiff VisitDoStmt(const clang::DoStmt* DS);
StmtDiff VisitContinueStmt(const clang::ContinueStmt* ContStmt);
Expand Down
70 changes: 70 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,76 @@ StmtDiff BaseForwardModeVisitor::VisitConditionalOperator(
return StmtDiff(condExpr, condExprDiff);
}

StmtDiff
BaseForwardModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
// Visiting for range-based ststement produces __range1, __begin1 and __end1
// variables, so for(auto i: a){
// ...
//}
//
// is equivalent to
//
// auto&& __range1 = a
// auto __begin1 = __range1;
// auto __end1 = __range1 + OUL
// for(;__begin != __end1; ++__begin){
// auto i = *__begin1;
// ...
//}
//
const VarDecl* Item = FRS->getLoopVariable();

const auto* RangeDecl = cast<VarDecl>(FRS->getRangeStmt()->getSingleDecl());
const auto* BeginDecl = cast<VarDecl>(FRS->getBeginStmt()->getSingleDecl());
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());

// Add range, begin and their's adjoints to the current block.
DeclDiff<VarDecl> RangeDeclDiff = DifferentiateVarDecl(RangeDecl);
addToCurrentBlock(BuildDeclStmt(RangeDeclDiff.getDecl_dx()));
addToCurrentBlock(BuildDeclStmt(RangeDeclDiff.getDecl()));

DeclDiff<VarDecl> BeginDeclDiff = DifferentiateVarDecl(BeginDecl);
addToCurrentBlock(BuildDeclStmt(BeginDeclDiff.getDecl_dx()));
addToCurrentBlock(BuildDeclStmt(BeginDeclDiff.getDecl()));

VarDecl* EndDeclDiff = DifferentiateVarDecl(EndDecl).getDecl();
addToCurrentBlock(BuildDeclStmt(EndDeclDiff));

// Build begin preincrementation.
DeclRefExpr* BeginExpr = BuildDeclRef(BeginDeclDiff.getDecl());
Expr* IncBegin = BuildOp(UO_PreInc, BeginExpr);

// Build begin preincrementation.
DeclRefExpr* d_BeginExpr = BuildDeclRef(BeginDeclDiff.getDecl_dx());
Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginExpr);

Expr* Inc = BuildOp(BO_Comma, d_IncBegin, IncBegin);

// Build begin != end expretion.
DeclRefExpr* EndExpr = BuildDeclRef(EndDeclDiff);
Expr* cond = BuildOp(BO_NE, BeginExpr, EndExpr);

DeclDiff<VarDecl> ItemDiff = DifferentiateVarDecl(Item);

// Differentiate body and add both Item and it's derivative.
const Stmt* body = FRS->getBody();
Stmt* bodyResult = nullptr;
bodyResult = Visit(body).getStmt();
Stmt* bodyWithItem1 = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), bodyResult, BuildDeclStmt(ItemDiff.getDecl()));
bodyResult =
utils::PrependAndCreateCompoundStmt(m_Sema.getASTContext(), bodyWithItem1,
BuildDeclStmt(ItemDiff.getDecl_dx()));

Stmt* forStmtDiff =
new (m_Context) ForStmt(m_Context, nullptr, cond, /*condVar=*/nullptr,
Inc, bodyResult, noLoc, noLoc, noLoc);

return StmtDiff(forStmtDiff);
}

StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
Expand Down
112 changes: 112 additions & 0 deletions test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,109 @@ double fn17_darg0(double x);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

double fn18(double x, double y){
int coefficients[3] = {4, 7, 3};
double res = 0;
for(auto i: coefficients){
if(i%2==0)
continue;
res+= x*y*i;
}
return res;
}

double fn18_darg0(double x, double y);
// CHECK: double fn18_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: int _d_coefficients[3] = {0, 0, 0};
// CHECK-NEXT: int coefficients[3] = {4, 7, 3};
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: int (&_d___range1)[3] = _d_coefficients;
// CHECK-NEXT: int (&__range10)[3] = coefficients;
// CHECK-NEXT: int *_d___begin1 = _d___range1;
// CHECK-NEXT: int *__begin10 = __range10;
// CHECK-NEXT: int *__end10 = __range10 + {{3|3L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) {
// CHECK-NEXT: int _d_i = *_d___begin1;
// CHECK-NEXT: int i = *__begin10;
// CHECK-NEXT: if (i % 2 == 0)
// CHECK-NEXT: continue;
// CHECK-NEXT: double _t0 = x * y;
// CHECK-NEXT: _d_res += (_d_x * y + x * _d_y) * i + _t0 * _d_i;
// CHECK-NEXT: res += _t0 * i;
// CHECK-NEXT: }
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }


double fn19(double x, double y){
double res = 0;
double f[] = {x*x, 2*x*y, y*y, x, y};
for(auto i: f){
if(i == x)
break;
res+=i;
}
return res;
}

double fn19_darg0(double x, double y);
// CHECK: double fn19_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: double _t0 = 2 * x;
// CHECK-NEXT: double _d_f[5] = {_d_x * x + x * _d_x, (0 * x + 2 * _d_x) * y + _t0 * _d_y, _d_y * y + y * _d_y, _d_x, _d_y};
// CHECK-NEXT: double f[5] = {x * x, _t0 * y, y * y, x, y};
// CHECK-NEXT: double (&_d___range1)[5] = _d_f;
// CHECK-NEXT: double (&__range10)[5] = f;
// CHECK-NEXT: double *_d___begin1 = _d___range1;
// CHECK-NEXT: double *__begin10 = __range10;
// CHECK-NEXT: double *__end10 = __range10 + {{5|5L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) {
// CHECK-NEXT: double _d_i = *_d___begin1;
// CHECK-NEXT: double i = *__begin10;
// CHECK-NEXT: if (i == x)
// CHECK-NEXT: break;
// CHECK-NEXT: _d_res += _d_i;
// CHECK-NEXT: res += i;
// CHECK-NEXT: }
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

double fn20(double x){
int a[] = {5};
for(auto i: a){
x+=i*x;
}
return x;
}

// CHECK: double fn20_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: int _d_a[1] = {0};
// CHECK-NEXT: int a[1] = {5};
// CHECK-NEXT: int (&_d___range1)[1] = _d_a;
// CHECK-NEXT: int (&__range10)[1] = a;
// CHECK-NEXT: int *_d___begin1 = _d___range1;
// CHECK-NEXT: int *__begin10 = __range10;
// CHECK-NEXT: int *__end10 = __range10 + {{1|1L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) {
// CHECK-NEXT: int _d_i = *_d___begin1;
// CHECK-NEXT: int i = *__begin10;
// CHECK-NEXT: _d_x += _d_i * x + i * _d_x;
// CHECK-NEXT: x += i * x;
// CHECK-NEXT: }
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }


double fn20_darg0(double x);


#define TEST(fn)\
auto d_##fn = clad::differentiate(fn, "i");\
printf("%.2f\n", d_##fn.execute(3, 5));
Expand Down Expand Up @@ -614,4 +717,13 @@ int main() {

clad::differentiate(fn17, 0);
printf("Result is = %.2f\n", fn17_darg0(5)); // CHECK-EXEC: Result is = 0

clad::differentiate(fn18, 0);
printf("Result is = %.2f\n", fn18_darg0(5, 1)); // CHECK-EXEC: Result is = 10.00

clad::differentiate(fn19, 0);
printf("Result is = %.2f\n", fn19_darg0(5, 2)); // CHECK-EXEC: Result is = 14.00

clad::differentiate(fn20, 0);
printf("Result is = %.2f\n", fn20_darg0(5)); // CHECK-EXEC: Result is = 6.00
}

0 comments on commit 7a240c3

Please sign in to comment.