diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 34116dc39..56a192335 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1499,6 +1499,27 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD, VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(), VD->isDirectInit(), nullptr, VD->getInitStyle()); m_Variables.emplace(VDClone, BuildDeclRef(VDDerived)); + // Check if decl's name is the same as before. The name may be changed + // if decl name collides with something in the derivative body. + // This can happen in rare cases, e.g. when the original function + // has both y and _d_y (here _d_y collides with the name produced by + // the derivation process), e.g. + // double f(double x) { + // double y = x; + // double _d_y = x; + // } + // -> + // double f_darg0(double x) { + // double _d_x = 1; + // double _d_y = _d_x; // produced as a derivative for y + // double y = x; + // double _d__d_y = _d_x; + // double _d_y = x; // copied from original funcion, collides with + // _d_y + // } + + if (VDClone->getDeclName() != VD->getDeclName()) + m_DeclReplacements[VD] = VDClone; return DeclDiff(VDClone, VDDerived); } @@ -1544,26 +1565,6 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { for (auto D : DS->decls()) { if (auto VD = dyn_cast(D)) { DeclDiff VDDiff = DifferentiateVarDecl(VD); - // Check if decl's name is the same as before. The name may be changed - // if decl name collides with something in the derivative body. - // This can happen in rare cases, e.g. when the original function - // has both y and _d_y (here _d_y collides with the name produced by - // the derivation process), e.g. - // double f(double x) { - // double y = x; - // double _d_y = x; - // } - // -> - // double f_darg0(double x) { - // double _d_x = 1; - // double _d_y = _d_x; // produced as a derivative for y - // double y = x; - // double _d__d_y = _d_x; - // double _d_y = x; // copied from original funcion, collides with - // _d_y - // } - if (VDDiff.getDecl()->getDeclName() != VD->getDeclName()) - m_DeclReplacements[VD] = VDDiff.getDecl(); decls.push_back(VDDiff.getDecl()); declsDiff.push_back(VDDiff.getDecl_dx()); } else if (auto* SAD = dyn_cast(D)) {