Skip to content

Commit

Permalink
Adding support for range-based for loops in the reverse mode.
Browse files Browse the repository at this point in the history
Fixes:#723
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Jul 26, 2024
1 parent 6ba607f commit 1d085ca
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ namespace clad {
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
StmtDiff VisitCXXForRangeStmt(const clang::CXXForRangeStmt* FRS);
StmtDiff VisitForStmt(const clang::ForStmt* FS);
StmtDiff VisitIfStmt(const clang::IfStmt* If);
StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE);
Expand Down
81 changes: 81 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,87 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(condExpr, ResultRef);
}

StmtDiff
ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
beginBlock(direction::reverse);

LoopCounter loopCounter(*this);
const auto* RangeDecl = FRS->getRangeStmt();
const auto* BeginDecl = FRS->getBeginStmt();
StmtDiff VisitRange = Visit(RangeDecl);
StmtDiff VisitBegin = Visit(BeginDecl);
Expr* BeginExpr = cast<BinaryOperator>(VisitBegin.getStmt())->getLHS();

beginBlock(direction::reverse);
// Create all declarations needed.
Expr* d_BeginDeclRef = m_Variables[cast<DeclRefExpr>(BeginExpr)->getDecl()];

auto* RangeExpr =
cast<DeclRefExpr>(cast<BinaryOperator>(VisitRange.getStmt())->getLHS());
auto* BeginDeclRef = cast<DeclRefExpr>(BeginExpr);

auto* RangeInit = Clone(FRS->getRangeInit());
Expr* AssignRange =
BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit));
Expr* AssignBegin =
BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr));
addToCurrentBlock(AssignRange);
addToCurrentBlock(AssignBegin);
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());

Expr* EndInit = cast<BinaryOperator>(EndDecl->getInit())->getRHS();
QualType EndType = CloneType(EndDecl->getType());
std::string EndName = EndDecl->getNameAsString();
Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit);
VarDecl* EndVarDecl =
BuildGlobalVarDecl(EndType, EndName, EndAssign, false);
DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl);

addToCurrentBlock(AssignEnd);
DeclRefExpr* EndExpr =
BuildDeclRef(cast<VarDecl>(cast<DeclStmt>(AssignEnd)->getSingleDecl()));

llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;
const VarDecl* LoopVD = FRS->getLoopVariable();
Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef);
Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef);

Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef);
Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr);
// Add item assignment statement to the body.
Stmt* bodyClone = Clone(FRS->getBody());
CompoundStmt* body = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), bodyClone,
BuildDeclStmt(const_cast<VarDecl*>(LoopVD)));
StmtDiff BodyDiff = DifferentiateLoopBody(
body, loopCounter, /*condVarDiff*/ nullptr, d_DecBegin,
/*isForLoop=*/true);

Expr* Inc = BuildOp(BO_Comma, IncBegin, d_IncBegin);
Stmt* Forward = new (m_Context)
ForStmt(m_Context, /*Init=*/nullptr, ForwardCond, /*CondVar=*/nullptr,
Inc, BodyDiff.getStmt(), FRS->getForLoc(), FRS->getBeginLoc(),
FRS->getEndLoc());
Expr* CounterCondition =
loopCounter.getCounterConditionResult().get().second;
Expr* CounterDecrement = loopCounter.getCounterDecrement();
Stmt* Reverse = BodyDiff.getStmt_dx();
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
Reverse = new (m_Context)
ForStmt(m_Context, /*Init=*/nullptr, CounterCondition,
/*CondVar=*/nullptr, CounterDecrement, Reverse,
FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc());
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
endScope();
return {utils::unwrapIfSingleStmt(Forward),
utils::unwrapIfSingleStmt(Reverse)};
}

StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
Expand Down
63 changes: 63 additions & 0 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -2732,6 +2732,68 @@ double fn33(double i, double j) {
//CHECK-NEXT: }
//CHECK-NEXT:}

double fn34(double x, double y){
double r = 0;
double a[] = {y, x*y, x*x + y};
for(auto& i: a){
r+=i;
}
return r;
}

// CHECK: void fn34_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: double _d_r = 0;
// CHECK-NEXT: double _d_a[3] = {0};
// CHECK-NEXT: unsigned {{int|long}} _t0;
// CHECK-NEXT: double (*_d___range1)[3] = 0;
// CHECK-NEXT: double (*__range10)[3] = {};
// CHECK-NEXT: double *_d___begin1 = 0;
// CHECK-NEXT: double *__begin10 = 0;
// CHECK-NEXT: clad::tape<double *> _t1 = {};
// CHECK-NEXT: clad::tape<double *> _t2 = {};
// CHECK-NEXT: double *_d_i = 0;
// CHECK-NEXT: double *i = {};
// CHECK-NEXT: clad::tape<double> _t3 = {};
// CHECK-NEXT: double r = 0;
// CHECK-NEXT: double a[3] = {y, x * y, x * x + y};
// CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: _d___range1 = &_d_a;
// CHECK-NEXT: _d___begin1 = *_d___range1;
// CHECK-NEXT: __range10 = &a;
// CHECK-NEXT: __begin10 = *__range10;
// CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: _d_i = &*_d___begin1;
// CHECK-NEXT: clad::push(_t1, _d_i);
// CHECK-NEXT: clad::push(_t2, i) , i = &*__begin10;
// CHECK-NEXT: clad::push(_t3, r);
// CHECK-NEXT: r += *i;
// CHECK-NEXT: }
// CHECK-NEXT: _d_r += 1;
// CHECK-NEXT: for (; _t0; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: _d___begin1--;
// CHECK-NEXT: _d_i = clad::pop(_t1);
// CHECK-NEXT: {
// CHECK-NEXT: r = clad::pop(_t3);
// CHECK-NEXT: double _r_d0 = _d_r;
// CHECK-NEXT: *_d_i += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: i = clad::pop(_t2);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: *_d_y += _d_a[0];
// CHECK-NEXT: *_d_x += _d_a[1] * y;
// CHECK-NEXT: *_d_y += x * _d_a[1];
// CHECK-NEXT: *_d_x += _d_a[2] * x;
// CHECK-NEXT: *_d_x += x * _d_a[2];
// CHECK-NEXT: *_d_y += _d_a[2];
// CHECK-NEXT: }
// CHECK-NEXT: }


#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient(F);\
Expand Down Expand Up @@ -2816,6 +2878,7 @@ int main() {
TEST_2(fn32, 3, 5); // CHECK-EXEC: {45.00, 27.00}
TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00}

TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00}
}

//CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {
Expand Down

0 comments on commit 1d085ca

Please sign in to comment.