Skip to content

Commit

Permalink
Differentiate for loop conditions in reverse mode (vgvassilev#746)
Browse files Browse the repository at this point in the history
This change differentiates the loop condition expression.
Additionaly if in forward pass a loop terminates pre-maturely due to break stmt.
The reverse pass should start differentiation with break statment and
not the loop condition differentiation. This change keeps track of whether the break
was called in forward pass and based on that in reverse mode it is decided
whether the loop differentiation is skipped for the first iteration or not.
  • Loading branch information
rohanjulka19 committed Jul 3, 2024
1 parent 645d2b6 commit e604ae7
Show file tree
Hide file tree
Showing 14 changed files with 1,508 additions and 511 deletions.
9 changes: 6 additions & 3 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ namespace clad {
/// Output variable of vector-valued function
std::string outputArrayStr;
std::vector<Stmts> m_LoopBlock;
clang::Expr* m_CurrentBreakFlagExpr;

unsigned outputArrayCursor = 0;
unsigned numParams = 0;
// FIXME: Should we make this an object instead of a pointer?
Expand Down Expand Up @@ -561,9 +563,6 @@ namespace clad {

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.
clang::Expr* CreateSizeTLiteralExpr(std::size_t value);

/// Initialise the `m_ControlFlowTape`.
Expand All @@ -576,6 +575,8 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
bool m_IsInvokedBySwitchStmt = false;

BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

Expand All @@ -598,6 +599,8 @@ namespace clad {
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();

clang::Expr* CreateCFTapeBackExprForCurrentCase();

/// Does final modifications on forward and reverse blocks
/// so that `break` and `continue` statements are handled
/// accurately.
Expand Down
105 changes: 92 additions & 13 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, /*Var=*/nullptr,
condDiffStored, noLoc, noLoc, thenDiff.getStmt(), noLoc,
Expand Down Expand Up @@ -992,6 +991,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Scope::ContinueScope);
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
const Stmt* init = FS->getInit();
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingLoopInitStmt();
Expand All @@ -1000,7 +1002,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Save the isInsideLoop value (we may be inside another loop).
llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;

StmtDiff condVarRes;
VarDecl* condVarClone = nullptr;
if (FS->getConditionVariable()) {
Expand All @@ -1011,11 +1012,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
StmtDiff condDiff;
StmtDiff condExprDiff;
if (FS->getCond())
cond = Visit(FS->getCond());
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());

const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

Expand Down Expand Up @@ -1063,7 +1065,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
Expand All @@ -1073,8 +1075,33 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (condVarRes.getExpr() != nullptr && isa<Expr>(condVarRes.getExpr()))
forwardCond = cast<Expr>(condVarRes.getExpr());

Stmt* breakStmt = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();

beginBlock();
if (utils::unwrapIfSingleStmt(condDiff.getStmt()))
addToCurrentBlock(condDiff.getStmt());

Stmt* IfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/
BuildOp(clang::UnaryOperatorKind::UO_LNot, BuildParens(forwardCond)),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/breakStmt,
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(IfStmt);

Stmt* forwardCondStmts = endBlock();
if (BodyDiff.getStmt()) {
BodyDiff.updateStmt(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt(), forwardCondStmts));
} else {
BodyDiff.updateStmt(utils::unwrapIfSingleStmt(forwardCondStmts));
}

Stmt* Forward = new (m_Context)
ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone,
ForStmt(m_Context, initResult.getStmt(), nullptr, condVarClone,
incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc);

// Create a condition testing counter for being zero, and its decrement.
Expand All @@ -1084,12 +1111,44 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CounterCondition = loopCounter.getCounterConditionResult().get().second;
Expr* CounterDecrement = loopCounter.getCounterDecrement();

Stmt* ReverseResult = BodyDiff.getStmt_dx();
if (!ReverseResult)
ReverseResult = new (m_Context) NullStmt(noLoc);
beginBlock(direction::reverse);
Stmt* RevIfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/BuildOp(clang::UnaryOperatorKind::UO_LNot, CounterCondition),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/Clone(breakStmt),
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);

if (condDiff.getStmt_dx()) {
if (m_CurrentBreakFlagExpr) {
Expr* loopBreakFlagCond =
BuildOp(BinaryOperatorKind::BO_LOr,
BuildOp(UnaryOperatorKind::UO_LNot, CounterCondition),
m_CurrentBreakFlagExpr);
auto* RevIfStmt = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, loopBreakFlagCond, noLoc,
noLoc, condDiff.getStmt_dx(), noLoc, nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);
} else {
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
}
}

Stmt* revPassCondStmts = endBlock(direction::reverse);
if (BodyDiff.getStmt_dx()) {
BodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt_dx(), revPassCondStmts));
} else {
BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts));
}

Stmt* Reverse = new (m_Context)
ForStmt(m_Context, nullptr, CounterCondition, nullptr, CounterDecrement,
ReverseResult, noLoc, noLoc, noLoc);
ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement,
BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc);

addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
Expand Down Expand Up @@ -2391,8 +2450,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
Ldiff = Visit(L, zero);
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else if (opCode == BO_LAnd) {
Expand Down Expand Up @@ -3546,6 +3605,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt();
Stmt* pushExprToCurrentCase = activeBreakContHandler
->CreateCFTapePushExprToCurrentCase();
if (isInsideLoop && !activeBreakContHandler->m_IsInvokedBySwitchStmt) {
Expr* tapeBackExprForCurrentCase =
activeBreakContHandler->CreateCFTapeBackExprForCurrentCase();
if (m_CurrentBreakFlagExpr) {
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LOr, m_CurrentBreakFlagExpr,
tapeBackExprForCurrentCase);

} else {
m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase;
}
}
addToCurrentBlock(pushExprToCurrentCase);
addToCurrentBlock(newBS);
return {endBlock(direction::forward), CFCaseStmt};
Expand Down Expand Up @@ -3607,6 +3678,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CS;
}

Expr* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapeBackExprForCurrentCase() {
return m_RMV.BuildOp(
BinaryOperatorKind::BO_NE, m_ControlFlowTape->Last(),
ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy,
m_RMV.m_Context, m_CaseCounter));
}

Stmt* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapePushExprToCurrentCase() {
if (!m_ControlFlowTape)
Expand Down
Loading

0 comments on commit e604ae7

Please sign in to comment.