Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a reference to store the LHS of assignments #981

Merged
merged 3 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ namespace clad {
clang::Expr* GlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t",
bool force = false);
StmtDiff StoreAndRestore(clang::Expr* E, llvm::StringRef prefix = "_t",
bool force = false);
StmtDiff StoreAndRestore(clang::Expr* E, llvm::StringRef prefix = "_t");

//// A type returned by DelayedGlobalStoreAndRef
/// .Result is a reference to the created (yet uninitialized) global
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,8 @@ namespace clad {
}

bool DiffRequest::shouldBeRecorded(Expr* E) const {
assert(EnableTBRAnalysis && "TBR not enabled!");
if (!EnableTBRAnalysis)
return true;

if (!isa<DeclRefExpr>(E) && !isa<ArraySubscriptExpr>(E) &&
!isa<MemberExpr>(E))
Expand Down
96 changes: 44 additions & 52 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2033,7 +2033,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* diff_dx = diff.getExpr_dx();
if (isPointerOp)
addToCurrentBlock(BuildOp(opCode, diff_dx), direction::forward);
if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) {
if (m_DiffReq.shouldBeRecorded(E)) {
auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
direction::reverse);
Expand All @@ -2050,7 +2050,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* diff_dx = diff.getExpr_dx();
if (isPointerOp)
addToCurrentBlock(BuildOp(opCode, diff_dx), direction::forward);
if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) {
if (m_DiffReq.shouldBeRecorded(E)) {
auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
direction::reverse);
Expand Down Expand Up @@ -2332,42 +2332,46 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// in Lblock
beginBlock(direction::reverse);
Ldiff = Visit(L, dfdx());
auto* Lblock = endBlock(direction::reverse);
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore);

// We need to store values of derivative pointer variables in forward pass
// and restore them in reverse pass.
if (isPointerOp) {
Expr* Edx = Ldiff.getExpr_dx();
ExprsToStore.push_back(Edx);
}

if (L->HasSideEffects(m_Context)) {
Expr* E = Ldiff.getExpr();
auto* storeE =
StoreAndRef(E, m_Context.getLValueReferenceType(E->getType()));
Ldiff.updateStmt(storeE);
llvm::SmallVector<Expr*, 4> returnExprs;
utils::GetInnermostReturnExpr(E, returnExprs);
if (returnExprs.size() == 1) {
addToCurrentBlock(E, direction::forward);
Ldiff.updateStmt(returnExprs[0]);
} else {
auto* storeE = GlobalStoreAndRef(BuildOp(UO_AddrOf, E));
Ldiff.updateStmt(BuildOp(UO_Deref, storeE));
}
}

Stmts Lblock = EndBlockWithoutCreatingCS(direction::reverse);

Expr* LCloned = Ldiff.getExpr();
// For x, AssignedDiff is _d_x, for x[i] its _d_x[i], for reference exprs
// For x, ResultRef is _d_x, for x[i] its _d_x[i], for reference exprs
// like (x = y) it propagates recursively, so _d_x is also returned.
Expr* AssignedDiff = Ldiff.getExpr_dx();
if (!AssignedDiff)
ResultRef = Ldiff.getExpr_dx();
if (!ResultRef)
return Clone(BinOp);
ResultRef = AssignedDiff;
// If assigned expr is dependent, first update its derivative;
auto Lblock_begin = Lblock->body_rbegin();
auto Lblock_end = Lblock->body_rend();
if (dfdx() && !Lblock.empty()) {
addToCurrentBlock(*Lblock.begin(), direction::reverse);
Lblock.erase(Lblock.begin());
}

if (dfdx() && Lblock_begin != Lblock_end) {
addToCurrentBlock(*Lblock_begin, direction::reverse);
Lblock_begin = std::next(Lblock_begin);
// Store the value of the LHS of the assignment in the forward pass
// and restore it in the reverse pass
if (m_DiffReq.shouldBeRecorded(L)) {
StmtDiff pushPop = StoreAndRestore(LCloned);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
}

for (auto& E : ExprsToStore) {
auto pushPop = StoreAndRestore(E);
// We need to store values of derivative pointer variables in forward pass
// and restore them in reverse pass.
if (isPointerOp) {
StmtDiff pushPop = StoreAndRestore(Ldiff.getExpr_dx());
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
}
Expand All @@ -2381,13 +2385,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// For pointer types, no need to store old derivatives.
if (!isPointerOp)
oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d",
oldValue = StoreAndRef(ResultRef, direction::reverse, "_r_d",
/*forceDeclCreation=*/true);
if (opCode == BO_Assign) {
if (!isPointerOp) {
// Add the statement `dl = 0;`
Expr* zero = getZeroInit(AssignedDiff->getType());
addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero),
Expr* zero = getZeroInit(ResultRef->getType());
addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero),
direction::reverse);
}
Rdiff = Visit(R, oldValue);
Expand Down Expand Up @@ -2419,8 +2423,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
// Add the statement `dl = 0;`
Expr* zero = getZeroInit(AssignedDiff->getType());
addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero),
Expr* zero = getZeroInit(ResultRef->getType());
addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero),
direction::reverse);
/// Capture all the emitted statements while visiting R
/// and insert them after `dl += dl * R`
Expand All @@ -2429,7 +2433,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, dr);
Stmts RBlock = EndBlockWithoutCreatingCS(direction::reverse);
addToCurrentBlock(
BuildOp(BO_AddAssign, AssignedDiff,
BuildOp(BO_AddAssign, ResultRef,
BuildOp(BO_Mul, oldValue, Rdiff.getRevSweepAsExpr())),
direction::reverse);
for (auto& S : RBlock)
Expand All @@ -2439,14 +2443,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, Rdiff.getExpr());
} else if (opCode == BO_DivAssign) {
// Add the statement `dl = 0;`
Expr* zero = getZeroInit(AssignedDiff->getType());
addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero),
Expr* zero = getZeroInit(ResultRef->getType());
addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero),
direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff,
addToCurrentBlock(BuildOp(BO_AddAssign, ResultRef,
BuildOp(BO_Div, oldValue, RStored)),
direction::reverse);
if (!RDelayed.isConstant) {
Expand All @@ -2469,8 +2473,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
opCode);

// Output statements from Visit(L).
for (auto it = Lblock_begin; it != Lblock_end; ++it)
addToCurrentBlock(*it, direction::reverse);
for (Stmt* S : Lblock)
addToCurrentBlock(S, direction::reverse);
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down Expand Up @@ -2736,8 +2740,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
initDiff.getForwSweepExpr_dx()));
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop =
StoreAndRestore(derivedVDE, /*prefix=*/"_t", /*force=*/true);
StmtDiff pushPop = StoreAndRestore(derivedVDE);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getExpr_dx());
}
Expand Down Expand Up @@ -2922,8 +2925,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* declRef = BuildDeclRef(decl);
auto* assignment = BuildOp(BO_Assign, declRef, decl->getInit());
if (isInsideLoop) {
auto pushPop =
StoreAndRestore(declRef, /*prefix=*/"_t", /*force=*/true);
auto pushPop = StoreAndRestore(declRef);
if (pushPop.getExpr() != declRef)
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
assignment = BuildOp(BO_Comma, pushPop.getExpr(), assignment);
Expand Down Expand Up @@ -3104,12 +3106,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (isa<CallExpr>(B))
return false;

// FIXME: Here will be the entry point of the advanced activity analysis.

// Check if the expression was marked as to-be recorded by an analysis.
if (m_DiffReq.EnableTBRAnalysis)
return m_DiffReq.shouldBeRecorded(B);

// Assume E is useful to store.
return true;
}
Expand Down Expand Up @@ -3171,13 +3167,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::StoreAndRestore(clang::Expr* E,
llvm::StringRef prefix,
bool force) {
llvm::StringRef prefix) {
auto Type = getNonConstType(E->getType(), m_Context, m_Sema);

if (!force && !UsefulToStoreGlobal(E))
return {};

if (isInsideLoop) {
auto CladTape = MakeCladTapeFor(Clone(E), prefix);
Expr* Push = CladTape.Push;
Expand Down
32 changes: 15 additions & 17 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,7 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD, bool forceNonRefType) {
}

void TBRAnalyzer::markLocation(const clang::Expr* E) {
VarData* data = getExprVarData(E);
if (!data || findReq(*data)) {
// FIXME: If any of the data's child nodes are required to store then data
// itself is stored. We might add an option to store separate fields.
// FIXME: Sometimes one location might correspond to multiple stores. For
// example, in ``(x*=y)=u`` x's location will first be marked as required to
// be stored (when passing *= operator) but then marked as not required to
// be stored (when passing = operator). Current method of marking locations
// does not allow to differentiate between these two.
m_TBRLocs.insert(E->getBeginLoc());
}
m_TBRLocs.insert(E->getBeginLoc());
}

void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) {
Expand Down Expand Up @@ -703,15 +693,19 @@ bool TBRAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) {
}
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(L, ExprsToStore);
bool hasToBeSetReq = false;
for (const auto* innerExpr : ExprsToStore) {
// Mark corresponding SourceLocation as required/not required to be
// stored for all expressions that could be used changed.
markLocation(innerExpr);
// If at least one of ExprsToStore has to be stored,
// mark L as useful to store.
if (VarData* data = getExprVarData(innerExpr))
hasToBeSetReq = hasToBeSetReq || findReq(*data);
// Set them to not required to store because the values were changed.
// (if some value was not changed, this could only happen if it was
// already not required to store).
setIsRequired(innerExpr, /*isReq=*/false);
}
if (hasToBeSetReq)
markLocation(L);
} else if (opCode == BO_Comma) {
setMode(0);
TraverseStmt(L);
Expand All @@ -737,9 +731,13 @@ bool TBRAnalyzer::VisitUnaryOperator(clang::UnaryOperator* UnOp) {
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(E, ExprsToStore);
for (const auto* innerExpr : ExprsToStore) {
// Mark corresponding SourceLocation as required/not required to be
// stored for all expressions that could be changed.
markLocation(innerExpr);
// If at least one of ExprsToStore has to be stored,
// mark L as useful to store.
if (VarData* data = getExprVarData(innerExpr))
if (findReq(*data)) {
markLocation(E);
break;
}
}
}
// FIXME: Ideally, `__real` and `__imag` operators should be treated as member
Expand Down
36 changes: 17 additions & 19 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,19 @@ double f9(double x, double y) {
//CHECK: void f9_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: double _d_t = 0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x;
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: double &_t1 = (t *= x);
//CHECK-NEXT: _t2 = t;
//CHECK-NEXT: _t1 *= y;
//CHECK-NEXT: (t *= x);
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t *= y;
//CHECK-NEXT: _d_t += 1;
//CHECK-NEXT: {
//CHECK-NEXT: t = _t2;
//CHECK-NEXT: t = _t1;
//CHECK-NEXT: double _r_d1 = _d_t;
//CHECK-NEXT: _d_t = 0;
//CHECK-NEXT: _d_t += _r_d1 * y;
//CHECK-NEXT: *_d_y += _t1 * _r_d1;
//CHECK-NEXT: *_d_y += t * _r_d1;
//CHECK-NEXT: t = _t0;
//CHECK-NEXT: double _r_d0 = _d_t;
//CHECK-NEXT: _d_t = 0;
Expand Down Expand Up @@ -473,15 +473,15 @@ double f11(double x, double y) {
//CHECK: void f11_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: double _d_t = 0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x;
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: double &_t1 = (t = x);
//CHECK-NEXT: _t2 = t;
//CHECK-NEXT: _t1 = y;
//CHECK-NEXT: (t = x);
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = y;
//CHECK-NEXT: _d_t += 1;
//CHECK-NEXT: {
//CHECK-NEXT: t = _t2;
//CHECK-NEXT: t = _t1;
//CHECK-NEXT: double _r_d1 = _d_t;
//CHECK-NEXT: _d_t = 0;
//CHECK-NEXT: *_d_y += _r_d1;
Expand All @@ -504,26 +504,24 @@ double f12(double x, double y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double *_t2;
//CHECK-NEXT: double _t3;
//CHECK-NEXT: double _t4;
//CHECK-NEXT: double t;
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: else
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: double &_t2 = (_cond0 ? (t = x) : (t = y));
//CHECK-NEXT: _t3 = t;
//CHECK-NEXT: _t4 = t;
//CHECK-NEXT: _t2 *= y;
//CHECK-NEXT: _t2 = &(_cond0 ? (t = x) : (t = y));
//CHECK-NEXT: _t3 = *_t2;
//CHECK-NEXT: *_t2 *= y;
//CHECK-NEXT: _d_t += 1;
//CHECK-NEXT: {
//CHECK-NEXT: t = _t3;
//CHECK-NEXT: t = _t4;
//CHECK-NEXT: *_t2 = _t3;
//CHECK-NEXT: double _r_d2 = (_cond0 ? _d_t : _d_t);
//CHECK-NEXT: (_cond0 ? _d_t : _d_t) = 0;
//CHECK-NEXT: (_cond0 ? _d_t : _d_t) += _r_d2 * y;
//CHECK-NEXT: *_d_y += _t2 * _r_d2;
//CHECK-NEXT: *_d_y += *_t2 * _r_d2;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: t = _t0;
//CHECK-NEXT: double _r_d0 = _d_t;
Expand Down
Loading