Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Differentiate for loop condition expression #818

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ namespace clad {
/// Output variable of vector-valued function
std::string outputArrayStr;
std::vector<Stmts> m_LoopBlock;
/// This expression checks if the forward pass loop was terminted due to
/// break. It is used to determine whether to run the loop cond
/// differentiation. One additional time.
clang::Expr* m_CurrentBreakFlagExpr;
rohanjulka19 marked this conversation as resolved.
Show resolved Hide resolved

unsigned outputArrayCursor = 0;
unsigned numParams = 0;
// FIXME: Should we make this an object instead of a pointer?
Expand Down Expand Up @@ -561,7 +566,6 @@ namespace clad {

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
rohanjulka19 marked this conversation as resolved.
Show resolved Hide resolved
/// `value` as value.
clang::Expr* CreateSizeTLiteralExpr(std::size_t value);
Expand All @@ -576,6 +580,8 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
bool m_IsInvokedBySwitchStmt = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'm_IsInvokedBySwitchStmt' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

      bool m_IsInvokedBySwitchStmt = false;
           ^


BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

Expand All @@ -598,6 +604,11 @@ namespace clad {
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();

/// Builds and return `clad::back(TapeRef) != m_CaseCounter`
/// expression, where `TapeRef` and `m_CaseCounter` are replaced
/// by their actual values respectively
clang::Expr* CreateCFTapeBackExprForCurrentCase();
rohanjulka19 marked this conversation as resolved.
Show resolved Hide resolved

/// Does final modifications on forward and reverse blocks
/// so that `break` and `continue` statements are handled
/// accurately.
Expand Down
113 changes: 100 additions & 13 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,6 @@

StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, /*Var=*/nullptr,
condDiffStored, noLoc, noLoc, thenDiff.getStmt(), noLoc,
Expand Down Expand Up @@ -992,6 +991,9 @@
Scope::ContinueScope);
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
const Stmt* init = FS->getInit();
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingLoopInitStmt();
Expand All @@ -1000,7 +1002,6 @@
// Save the isInsideLoop value (we may be inside another loop).
llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;

StmtDiff condVarRes;
VarDecl* condVarClone = nullptr;
if (FS->getConditionVariable()) {
Expand All @@ -1011,11 +1012,12 @@
}
}

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
StmtDiff condDiff;
StmtDiff condExprDiff;
if (FS->getCond())
cond = Visit(FS->getCond());
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use DifferentiateSingleStmt instead of DifferentiateSingleExpr here? DifferentiateSingleExpr is more complicated, and I prefer to use it only when it is strictly required.

@vgvassilev @PetroZarytskyi, I think it would be a good idea to note the guidelines that help decide when to use Visit, DifferentiateSingleStmt, and DifferentiateSingleExpr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I guess that can be used will change this


const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

Expand Down Expand Up @@ -1063,7 +1065,7 @@
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
Expand All @@ -1073,8 +1075,36 @@
if (condVarRes.getExpr() != nullptr && isa<Expr>(condVarRes.getExpr()))
forwardCond = cast<Expr>(condVarRes.getExpr());

Stmt* breakStmt = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();

/// This part adds the forward pass of loop condition stmt in the body
/// In this first loop condition diff stmts execute then loop condition
/// is checked if and loop is terminated.
beginBlock();
if (utils::unwrapIfSingleStmt(condDiff.getStmt()))
addToCurrentBlock(condDiff.getStmt());

Stmt* IfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/
BuildOp(clang::UnaryOperatorKind::UO_LNot, BuildParens(forwardCond)),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/breakStmt,
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(IfStmt);

Stmt* forwardCondStmts = endBlock();
if (BodyDiff.getStmt()) {
BodyDiff.updateStmt(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt(), forwardCondStmts));
} else {
BodyDiff.updateStmt(utils::unwrapIfSingleStmt(forwardCondStmts));

Check warning on line 1103 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L1103

Added line #L1103 was not covered by tests
}

Stmt* Forward = new (m_Context)
ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone,
ForStmt(m_Context, initResult.getStmt(), nullptr, condVarClone,
incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc);

// Create a condition testing counter for being zero, and its decrement.
Expand All @@ -1084,12 +1114,45 @@
CounterCondition = loopCounter.getCounterConditionResult().get().second;
Expr* CounterDecrement = loopCounter.getCounterDecrement();

Stmt* ReverseResult = BodyDiff.getStmt_dx();
if (!ReverseResult)
ReverseResult = new (m_Context) NullStmt(noLoc);
/// This part adds the reverse pass of loop condition stmt in the body
beginBlock(direction::reverse);
Stmt* RevIfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/BuildOp(clang::UnaryOperatorKind::UO_LNot, CounterCondition),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/Clone(breakStmt),
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);

if (condDiff.getStmt_dx()) {
if (m_CurrentBreakFlagExpr) {
Expr* loopBreakFlagCond =
BuildOp(BinaryOperatorKind::BO_LOr,
BuildOp(UnaryOperatorKind::UO_LNot, CounterCondition),
BuildParens(m_CurrentBreakFlagExpr));
auto* RevIfStmt = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, loopBreakFlagCond, noLoc,
noLoc, condDiff.getStmt_dx(), noLoc, nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);
} else {
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
}
}

Stmt* revPassCondStmts = endBlock(direction::reverse);
if (BodyDiff.getStmt_dx()) {
BodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt_dx(), revPassCondStmts));
} else {
BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts));
}

Stmt* Reverse = new (m_Context)
ForStmt(m_Context, nullptr, CounterCondition, nullptr, CounterDecrement,
ReverseResult, noLoc, noLoc, noLoc);
ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement,
BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc);

addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
Expand Down Expand Up @@ -2391,14 +2454,18 @@
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
Ldiff = Visit(L, zero);
parth-07 marked this conversation as resolved.
Show resolved Hide resolved
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else if (opCode == BO_LAnd) {
parth-07 marked this conversation as resolved.
Show resolved Hide resolved
VarDecl* condVar = GlobalStoreImpl(m_Context.BoolTy, "_cond");
VarDecl* derivedCondVar = GlobalStoreImpl(
m_Context.DoubleTy, "_d" + condVar->getNameAsString());
addToBlock(BuildOp(BO_Assign, BuildDeclRef(derivedCondVar),
ConstantFolder::synthesizeLiteral(
m_Context.DoubleTy, m_Context, /*val=*/0)),
m_Globals);
Expr* condVarRef = BuildDeclRef(condVar);
Expr* assignExpr = BuildOp(BO_Assign, condVarRef, Clone(R));
m_Variables.emplace(condVar, BuildDeclRef(derivedCondVar));
Expand Down Expand Up @@ -3546,6 +3613,18 @@
Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt();
Stmt* pushExprToCurrentCase = activeBreakContHandler
->CreateCFTapePushExprToCurrentCase();
if (isInsideLoop && !activeBreakContHandler->m_IsInvokedBySwitchStmt) {
Expr* tapeBackExprForCurrentCase =
activeBreakContHandler->CreateCFTapeBackExprForCurrentCase();
if (m_CurrentBreakFlagExpr) {
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LAnd, m_CurrentBreakFlagExpr,
tapeBackExprForCurrentCase);

} else {
m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase;
}
}
addToCurrentBlock(pushExprToCurrentCase);
addToCurrentBlock(newBS);
return {endBlock(direction::forward), CFCaseStmt};
Expand Down Expand Up @@ -3607,6 +3686,14 @@
return CS;
}

Expr* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapeBackExprForCurrentCase() {
return m_RMV.BuildOp(
BinaryOperatorKind::BO_NE, m_ControlFlowTape->Last(),
ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy,
m_RMV.m_Context, m_CaseCounter));
}

Stmt* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapePushExprToCurrentCase() {
if (!m_ControlFlowTape)
Expand Down
Loading
Loading