Skip to content

Commit

Permalink
Adding check for name collisions in DifferentiateVarDecl
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Jun 24, 2024
1 parent 216201a commit a0b29f6
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,27 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD,
VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
// Check if decl's name is the same as before. The name may be changed
// if decl name collides with something in the derivative body.
// This can happen in rare cases, e.g. when the original function
// has both y and _d_y (here _d_y collides with the name produced by
// the derivation process), e.g.
// double f(double x) {
// double y = x;
// double _d_y = x;
// }
// ->
// double f_darg0(double x) {
// double _d_x = 1;
// double _d_y = _d_x; // produced as a derivative for y
// double y = x;
// double _d__d_y = _d_x;
// double _d_y = x; // copied from original funcion, collides with
// _d_y
// }

if (VDClone->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDClone;
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

Expand Down Expand Up @@ -1544,26 +1565,6 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
for (auto D : DS->decls()) {
if (auto VD = dyn_cast<VarDecl>(D)) {
DeclDiff<VarDecl> VDDiff = DifferentiateVarDecl(VD);
// Check if decl's name is the same as before. The name may be changed
// if decl name collides with something in the derivative body.
// This can happen in rare cases, e.g. when the original function
// has both y and _d_y (here _d_y collides with the name produced by
// the derivation process), e.g.
// double f(double x) {
// double y = x;
// double _d_y = x;
// }
// ->
// double f_darg0(double x) {
// double _d_x = 1;
// double _d_y = _d_x; // produced as a derivative for y
// double y = x;
// double _d__d_y = _d_x;
// double _d_y = x; // copied from original funcion, collides with
// _d_y
// }
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
Expand Down

0 comments on commit a0b29f6

Please sign in to comment.