Skip to content

Commit

Permalink
Remove the enableTBR state from reverse mode visitor. NFC
Browse files Browse the repository at this point in the history
Partially addresses #721.
  • Loading branch information
vgvassilev committed Jun 19, 2024
1 parent 8bc762c commit 2a3e9bc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
1 change: 0 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ namespace clad {
std::vector<Stmts> m_LoopBlock;
unsigned outputArrayCursor = 0;
unsigned numParams = 0;
bool enableTBR = false;
// FIXME: Should we make this an object instead of a pointer?
// Downside of making it an object: We will need to include
// 'MultiplexExternalRMVSource.h' file
Expand Down
18 changes: 6 additions & 12 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
outputArrayStr = m_DiffReq->getParamDecl(lastArgN)->getNameAsString();
}

// Check if DiffRequest asks for TBR analysis to be enabled
if (request.EnableTBRAnalysis)
enableTBR = true;

auto derivativeBaseName = request.BaseFunctionName;
std::string gradientName = derivativeBaseName + funcPostfix();
// To be consistent with older tests, nothing is appended to 'f_grad' if
Expand Down Expand Up @@ -475,10 +471,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DerivativeAndOverload
ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request) {
if (request.EnableTBRAnalysis)
enableTBR = true;
TBRAnalyzer analyzer(m_Context);
if (enableTBR) {
if (request.EnableTBRAnalysis) {
TBRAnalyzer analyzer(m_Context);
analyzer.Analyze(FD);
m_ToBeRecorded = analyzer.getResult();
}
Expand Down Expand Up @@ -602,8 +596,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

void ReverseModeVisitor::DifferentiateWithClad() {
TBRAnalyzer analyzer(m_Context);
if (enableTBR) {
if (m_DiffReq.EnableTBRAnalysis) {
TBRAnalyzer analyzer(m_Context);
analyzer.Analyze(m_DiffReq.Function);
m_ToBeRecorded = analyzer.getResult();
}
Expand Down Expand Up @@ -1695,7 +1689,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackRequest.Mode = DiffMode::experimental_pullback;
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = enableTBR;
pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (DerivedCallOutputArgs[i + isaMethod])
Expand Down Expand Up @@ -2943,7 +2937,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (isa<DeclRefExpr>(B) || isa<ArraySubscriptExpr>(B) ||
isa<MemberExpr>(B)) {
// If TBR analysis is off, assume E is useful to store.
if (!enableTBR)
if (!m_DiffReq.EnableTBRAnalysis)
return true;
// FIXME: currently, we allow all pointer operations to be stored.
// This is not correct, but we need to implement a more advanced analysis
Expand Down

0 comments on commit 2a3e9bc

Please sign in to comment.