Skip to content

Commit

Permalink
Differentiate for loop condition expression (vgvassilev#746)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohanjulka19 committed Mar 15, 2024
1 parent 5736df6 commit 442f39b
Show file tree
Hide file tree
Showing 4 changed files with 447 additions and 15 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ namespace clad {
LoopCounter& loopCounter,
clang::Stmt* condVarDifff = nullptr,
clang::Stmt* forLoopIncDiff = nullptr,
clang::Stmt* condDiff = nullptr,
bool isForLoop = false);

/// This class modifies forward and reverse blocks of the loop/switch
Expand Down
45 changes: 36 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1056,11 +1056,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

auto CommaJoin = [this](Expr* Acc, Stmt* S) {
Expr* E = cast<Expr>(S);
return BuildOp(BO_Comma, E, BuildParens(Acc));
};

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
StmtDiff condDiff;
StmtDiff condExprDiff;
StmtDiff condDiffOuter;
StmtDiff condExprDiffOuter;
if (FS->getCond())
cond = Visit(FS->getCond());
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());
std::tie(condDiffOuter, condExprDiffOuter) = DifferentiateSingleExpr(FS->getCond());

const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

Expand Down Expand Up @@ -1088,10 +1098,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
// Otherwise, join all exprs by comma operator.
else if (incExprDiff.getExpr()) {
auto CommaJoin = [this](Expr* Acc, Stmt* S) {
Expr* E = cast<Expr>(S);
return BuildOp(BO_Comma, E, BuildParens(Acc));
};
incResult = std::accumulate(Additional->body_rbegin(),
Additional->body_rend(),
incExprDiff.getExpr(),
Expand All @@ -1102,13 +1108,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff BodyDiff = DifferentiateLoopBody(body, loopCounter,
condVarRes.getStmt_dx(),
incDiff.getStmt_dx(),
condDiff.getStmt_dx(),
/*isForLoop=*/true);

/// FIXME: This part in necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
Expand All @@ -1118,8 +1125,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (condVarRes.getExpr() != nullptr && isa<Expr>(condVarRes.getExpr()))
forwardCond = cast<Expr>(condVarRes.getExpr());


auto* AdditionalStmts = cast<CompoundStmt>(condDiff.getStmt());
Expr* condResult = std::accumulate(AdditionalStmts->body_rbegin(),
AdditionalStmts->body_rend(),
forwardCond,
CommaJoin);


Stmt* Forward = new (m_Context)
ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone,
ForStmt(m_Context, initResult.getStmt(), condResult, condVarClone,
incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc);

// Create a condition testing counter for being zero, and its decrement.
Expand All @@ -1141,10 +1156,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
noLoc,
noLoc,
noLoc);
addToCurrentBlock(unwrapIfSingleStmt(condDiffOuter.getStmt()));
addToCurrentBlock(Forward, direction::forward);
Forward = endBlock(direction::forward);
addToCurrentBlock(loopCounter.getPop(), direction::reverse);
addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(condDiffOuter.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
endScope();
Expand Down Expand Up @@ -2580,8 +2597,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
Ldiff = Visit(L, zero);
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else {
Expand Down Expand Up @@ -3606,6 +3623,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
LoopCounter& loopCounter,
Stmt* condVarDiff,
Stmt* forLoopIncDiff,
Stmt* condDiff,
bool isForLoop) {
Expr* counterIncrement = loopCounter.getCounterIncrement();
auto* activeBreakContHandler = PushBreakContStmtHandler();
Expand Down Expand Up @@ -3656,6 +3674,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

if (condDiff) {
if (bodyDiff.getStmt_dx()) {
bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, bodyDiff.getStmt_dx(), condDiff));
} else {
bodyDiff.updateStmtDx(condDiff);
}
}

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();
Expand Down
12 changes: 6 additions & 6 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -852,12 +852,12 @@ double f21 (double x, double y) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_y += 1;
//CHECK-NEXT: {
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: double _r_d0 = * _d_y;
//CHECK-NEXT: * _d_y -= _r_d0;
//CHECK-NEXT: * _d_y += 0;
//CHECK-NEXT: y--;
//CHECK-NEXT: * _d_x += _r_d0;
// CHECK-NEXT: y = _t0;
// CHECK-NEXT: double _r_d0 = * _d_y;
// CHECK-NEXT: * _d_y -= _r_d0;
// CHECK-NEXT: * _d_x += _r_d0;
// CHECK-NEXT: * _d_y += 0;
// CHECK-NEXT: y--;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down
Loading

0 comments on commit 442f39b

Please sign in to comment.