Skip to content

Commit

Permalink
Reorder the statement to subtract the old derivative value in assig op
Browse files Browse the repository at this point in the history
closes #650
  • Loading branch information
vaithak authored and vgvassilev committed Jan 25, 2024
1 parent 93cc725 commit fa59708
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 67 deletions.
14 changes: 9 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2393,8 +2393,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!isPointerOp)
oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d",
/*forceDeclCreation=*/true);

if (opCode == BO_Assign) {
// Add the statement `dl -= oldValue;`
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, oldValue);
valueForRevPass = Rdiff.getRevSweepAsExpr();
} else if (opCode == BO_AddAssign) {
Expand Down Expand Up @@ -2423,6 +2425,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// double r = _ref0 *= z;
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
// Add the statement `dl -= oldValue;`
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
direction::reverse);
/// Capture all the emitted statements while visiting R
/// and insert them after `dl += dl * R`
beginBlock(direction::reverse);
Expand All @@ -2439,6 +2444,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Ldiff.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, Rdiff.getExpr());
} else if (opCode == BO_DivAssign) {
// Add the statement `dl -= oldValue;`
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expr* RStored =
Expand All @@ -2464,10 +2472,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue);

// Update the derivative.
if (opCode != BO_SubAssign && opCode != BO_AddAssign)
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
direction::reverse);
// Output statements from Visit(L).
for (auto it = Lblock_begin; it != Lblock_end; ++it)
addToCurrentBlock(*it, direction::reverse);
Expand Down
Loading

0 comments on commit fa59708

Please sign in to comment.