Skip to content

Commit

Permalink
Remove m_Functor from VisitorBase since it is already stored in the d…
Browse files Browse the repository at this point in the history
…iff request
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Oct 6, 2024
1 parent 118b2c9 commit 8ed2707
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
2 changes: 0 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ namespace clad {
std::vector<Stmts> m_Blocks;
/// Stores output variables for vector-valued functions
VectorOutputs m_VectorOutput;
/// The functor type that is currently being differentiated, if any.
const clang::CXXRecordDecl* m_Functor = nullptr;
/// Stores derivative expression of the implicit `this` pointer.
///
/// In the forward mode, `this` pointer derivative expression is of pointer
Expand Down
20 changes: 9 additions & 11 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ bool IsRealNonReferenceType(QualType T) {
}

DerivativeAndOverload BaseForwardModeVisitor::Derive() {
m_Functor = m_DiffReq.Functor;
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == DiffMode::forward);
assert(!m_DerivativeInFlight &&
Expand Down Expand Up @@ -138,11 +137,11 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
unsigned argIndex = ~0;
if (m_DiffReq->param_empty() && m_Functor)
argIndex =
std::distance(m_Functor->field_begin(),
std::find(m_Functor->field_begin(),
m_Functor->field_end(), m_IndependentVar));
const CXXRecordDecl* functor = m_DiffReq.Functor;
if (m_DiffReq->param_empty() && functor)
argIndex = std::distance(functor->field_begin(),
std::find(functor->field_begin(),
functor->field_end(), m_IndependentVar));
else
argIndex = std::distance(
FD->param_begin(),
Expand Down Expand Up @@ -296,8 +295,8 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {

// Create derived variable for each member variable if we are
// differentiating a call operator.
if (m_Functor) {
for (FieldDecl* fieldDecl : m_Functor->fields()) {
if (m_DiffReq.Functor) {
for (FieldDecl* fieldDecl : m_DiffReq.Functor->fields()) {
Expr* dInitializer = nullptr;
QualType fieldType = fieldDecl->getType();

Expand Down Expand Up @@ -400,7 +399,6 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {

DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
const FunctionDecl* FD = m_DiffReq.Function;
m_Functor = m_DiffReq.Functor;
assert(m_DiffReq.Mode == GetPushForwardMode());
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
Expand Down Expand Up @@ -884,7 +882,7 @@ StmtDiff BaseForwardModeVisitor::VisitMemberExpr(const MemberExpr* ME) {
auto clonedME = dyn_cast<MemberExpr>(Clone(ME));
// Currently, we only differentiate member variables if we are
// differentiating a call operator.
if (m_Functor) {
if (m_DiffReq.Functor) {
if (isa<CXXThisExpr>(ME->getBase()->IgnoreParenImpCasts())) {
// Try to find the derivative of the member variable wrt independent
// variable
Expand Down Expand Up @@ -956,7 +954,7 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
ValueDecl* VD = nullptr;
// Derived variables for member variables are also created when we are
// differentiating a call operator.
if (m_Functor) {
if (m_DiffReq.Functor) {
if (auto ME = dyn_cast<MemberExpr>(clonedBase->IgnoreParenImpCasts())) {
ValueDecl* decl = ME->getMemberDecl();
auto it = m_Variables.find(decl);
Expand Down

0 comments on commit 8ed2707

Please sign in to comment.