Skip to content

Commit

Permalink
Differentiate for loop conditions in reverse mode (vgvassilev#746)
Browse files Browse the repository at this point in the history
Currently loop conditions are not differentiated in reverse mode.
This change differentiates the loop condition expression.
  • Loading branch information
rohanjulka19 committed Jul 2, 2024
1 parent 645d2b6 commit 1d0b23e
Show file tree
Hide file tree
Showing 14 changed files with 1,676 additions and 608 deletions.
12 changes: 9 additions & 3 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ namespace clad {
/// Output variable of vector-valued function
std::string outputArrayStr;
std::vector<Stmts> m_LoopBlock;
clang::Expr* m_CurrentBreakFlagExpr;

unsigned outputArrayCursor = 0;
unsigned numParams = 0;
// FIXME: Should we make this an object instead of a pointer?
Expand Down Expand Up @@ -516,6 +518,7 @@ namespace clad {
LoopCounter& loopCounter,
clang::Stmt* condVarDifff = nullptr,
clang::Stmt* forLoopIncDiff = nullptr,
StmtDiff* condDiff = nullptr,
bool isForLoop = false);

/// This class modifies forward and reverse blocks of the loop/switch
Expand Down Expand Up @@ -561,9 +564,6 @@ namespace clad {

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.
clang::Expr* CreateSizeTLiteralExpr(std::size_t value);

/// Initialise the `m_ControlFlowTape`.
Expand All @@ -576,6 +576,10 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.

BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

Expand All @@ -598,6 +602,8 @@ namespace clad {
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();

clang::Expr* CreateCFTapeBackExprForCurrentCase();

/// Does final modifications on forward and reverse blocks
/// so that `break` and `continue` statements are handled
/// accurately.
Expand Down
193 changes: 170 additions & 23 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTLambda.h"
#include "clang/AST/Expr.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Basic/TargetInfo.h"
Expand Down Expand Up @@ -880,7 +881,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, /*Var=*/nullptr,
condDiffStored, noLoc, noLoc, thenDiff.getStmt(), noLoc,
Expand Down Expand Up @@ -992,6 +992,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Scope::ContinueScope);
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
const Stmt* init = FS->getInit();
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingLoopInitStmt();
Expand All @@ -1000,7 +1003,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Save the isInsideLoop value (we may be inside another loop).
llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;

StmtDiff condVarRes;
VarDecl* condVarClone = nullptr;
if (FS->getConditionVariable()) {
Expand All @@ -1011,11 +1013,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
if (FS->getCond())
cond = Visit(FS->getCond());
StmtDiff condDiff;
StmtDiff condExprDiff;
if (FS->getCond()) {
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());
}

const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

Expand Down Expand Up @@ -1054,16 +1058,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

const Stmt* body = FS->getBody();
StmtDiff BodyDiff = DifferentiateLoopBody(body, loopCounter,
condVarRes.getStmt_dx(),
incDiff.getStmt_dx(),
/*isForLoop=*/true);
StmtDiff BodyDiff =
DifferentiateLoopBody(body, loopCounter, condVarRes.getStmt_dx(),
incDiff.getStmt_dx(), &condDiff,
/*isForLoop=*/true);

/// FIXME: This part in necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
Expand All @@ -1073,8 +1077,34 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (condVarRes.getExpr() != nullptr && isa<Expr>(condVarRes.getExpr()))
forwardCond = cast<Expr>(condVarRes.getExpr());

Stmt* breakStmt = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();

beginBlock();
if (utils::unwrapIfSingleStmt(condDiff.getStmt())) {
addToCurrentBlock(condDiff.getStmt());
}

Stmt* IfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/
BuildOp(clang::UnaryOperatorKind::UO_LNot, BuildParens(forwardCond)),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/breakStmt,
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(IfStmt);

Stmt* forwardCondStmts = endBlock();
if (BodyDiff.getStmt()) {
BodyDiff.updateStmt(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt(), forwardCondStmts));
} else {
BodyDiff.updateStmt(utils::unwrapIfSingleStmt(forwardCondStmts));
}

Stmt* Forward = new (m_Context)
ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone,
ForStmt(m_Context, initResult.getStmt(), nullptr, condVarClone,
incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc);

// Create a condition testing counter for being zero, and its decrement.
Expand All @@ -1084,12 +1114,44 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CounterCondition = loopCounter.getCounterConditionResult().get().second;
Expr* CounterDecrement = loopCounter.getCounterDecrement();

Stmt* ReverseResult = BodyDiff.getStmt_dx();
if (!ReverseResult)
ReverseResult = new (m_Context) NullStmt(noLoc);
beginBlock(direction::reverse);
Stmt* RevIfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/BuildOp(clang::UnaryOperatorKind::UO_LNot, CounterCondition),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/Clone(breakStmt),
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);

if (condDiff.getStmt_dx()) {
if (m_CurrentBreakFlagExpr) {
Expr* loopBreakFlagCond =
BuildOp(BinaryOperatorKind::BO_LOr,
BuildOp(UnaryOperatorKind::UO_LNot, CounterCondition),
m_CurrentBreakFlagExpr);
auto* RevIfStmt = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, loopBreakFlagCond, noLoc,
noLoc, condDiff.getStmt_dx(), noLoc, nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);
} else {
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
}
}

Stmt* revPassCondStmts = endBlock(direction::reverse);
if (BodyDiff.getStmt_dx()) {
BodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt_dx(), revPassCondStmts));
} else {
BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts));
}

Stmt* Reverse = new (m_Context)
ForStmt(m_Context, nullptr, CounterCondition, nullptr, CounterDecrement,
ReverseResult, noLoc, noLoc, noLoc);
ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement,
BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc);

addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
Expand Down Expand Up @@ -2391,8 +2453,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
Ldiff = Visit(L, zero);
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else if (opCode == BO_LAnd) {
Expand Down Expand Up @@ -3447,11 +3509,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body,
LoopCounter& loopCounter,
Stmt* condVarDiff,
Stmt* forLoopIncDiff,
bool isForLoop) {
StmtDiff ReverseModeVisitor::DifferentiateLoopBody(
const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff,
Stmt* forLoopIncDiff, StmtDiff* condDiff, bool isForLoop) {
Expr* counterIncrement = loopCounter.getCounterIncrement();
auto* activeBreakContHandler = PushBreakContStmtHandler();
activeBreakContHandler->BeginCFSwitchStmtScope();
Expand Down Expand Up @@ -3500,6 +3560,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bodyDiff.updateStmtDx(forLoopIncDiff);
}
}
//
// if (condDiff && utils::unwrapIfSingleStmt(condDiff->getStmt())) {
// if (bodyDiff.getStmt()) {
// bodyDiff.updateStmt(utils::PrependAndCreateCompoundStmt(
// m_Context, bodyDiff.getStmt(),
// condDiff->getStmt()));
// } else {
// bodyDiff.updateStmt(utils::unwrapIfSingleStmt(condDiff->getStmt()));
// }
// }

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
Expand All @@ -3522,6 +3592,63 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(counterDecrement, direction::reverse);
addToCurrentBlock(condVarDiff, direction::reverse);
addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse);
// if (condExprDiff) {
// Stmt* breakStmt = m_Sema.ActOnBreakStmt(noLoc,
// getCurrentScope()).get();
//
// Stmt* IfStmt = clad_compat::IfStmt_Create(
// /*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
// /*Init=*/nullptr, /*Var=*/nullptr,
// /*Cond=*/BuildOp(clang::UnaryOperatorKind::UO_LNot,
// condExprDiff->getExpr()),
// /*LPL=*/noLoc, /*RPL=*/noLoc,
// /*Then=*/breakStmt,
// /*EL=*/noLoc,
// /*Else=*/nullptr);

// if (bodyDiff.getStmt()) {
// bodyDiff.updateStmt(utils::PrependAndCreateCompoundStmt(
// m_Context, bodyDiff.getStmt(),
// IfStmt));
// } else {
// bodyDiff.updateStmt(utils::unwrapIfSingleStmt(IfStmt));
// }
// Expr*
// CounterCondition =
// loopCounter.getCounterConditionResult().get().second;
//
// Stmt* RevIfStmt = clad_compat::IfStmt_Create(
// /*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
// /*Init=*/nullptr, /*Var=*/nullptr,
// /*Cond=*/BuildOp(clang::UnaryOperatorKind::UO_LNot,
// CounterCondition),
// /*LPL=*/noLoc, /*RPL=*/noLoc,
// /*Then=*/Clone(breakStmt),
// /*EL=*/noLoc,
// /*Else=*/nullptr);
// addToCurrentBlock(RevIfStmt, direction::reverse);
// }

// if (condDiff && condDiff->getStmt_dx()) {
// if(m_CurrentBreakFlagExpr) {
// Expr* loopBreakFlagCond = BuildOp(UnaryOperatorKind::UO_LNot,
// m_CurrentBreakFlagExpr);
// addToCurrentBlock(
// BuildOp(BinaryOperatorKind::BO_Assign, m_CurrentBreakFlagExpr,
// ConstantFolder::synthesizeLiteral(m_Context.IntTy,
// m_Context, 0)),
// direction::reverse);
// auto* IfStmt =
// clad_compat::IfStmt_Create(m_Context, noLoc, false, nullptr,
// nullptr,
// loopBreakFlagCond, noLoc, noLoc,
// condDiff->getStmt_dx(), noLoc,
// nullptr);
// addToCurrentBlock(IfStmt, direction::reverse);
// } else {
// addToCurrentBlock(condDiff->getStmt_dx(), direction::reverse);
// }
// }
bodyDiff = {bodyDiff.getStmt(),
utils::unwrapIfSingleStmt(endBlock(direction::reverse))};
return bodyDiff;
Expand All @@ -3546,6 +3673,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt();
Stmt* pushExprToCurrentCase = activeBreakContHandler
->CreateCFTapePushExprToCurrentCase();
if (isInsideLoop && !activeBreakContHandler->m_IsInvokedBySwitchStmt) {
Expr* tapeBackExprForCurrentCase =
activeBreakContHandler->CreateCFTapeBackExprForCurrentCase();
if (m_CurrentBreakFlagExpr) {
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LOr, m_CurrentBreakFlagExpr,
tapeBackExprForCurrentCase);

} else {
m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase;
}
}
addToCurrentBlock(pushExprToCurrentCase);
addToCurrentBlock(newBS);
return {endBlock(direction::forward), CFCaseStmt};
Expand Down Expand Up @@ -3607,6 +3746,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CS;
}

Expr* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapeBackExprForCurrentCase() {
return m_RMV.BuildOp(
BinaryOperatorKind::BO_NE, m_ControlFlowTape->Last(),
ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy,
m_RMV.m_Context, m_CaseCounter));
}

Stmt* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapePushExprToCurrentCase() {
if (!m_ControlFlowTape)
Expand Down
Loading

0 comments on commit 1d0b23e

Please sign in to comment.