Skip to content

Commit

Permalink
Added timings report if -ftime-report flag is enabled, fixes issue #769
Browse files Browse the repository at this point in the history
Clad will now print a timings report for all clad function calls
Adds CladTimerGroup class to time the clad functions
Repurposes LIBCLAD_TIMING flag to print only the timings report for clad
  • Loading branch information
DeadSpheroid authored and vgvassilev committed Feb 29, 2024
1 parent 2e7d50a commit f242077
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 34 deletions.
31 changes: 31 additions & 0 deletions test/Misc/TimingsReport.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %cladclang %s -I%S/../../include -oTimingsReport.out -ftime-report 2>&1 | FileCheck %s

#include "clad/Differentiator/Differentiator.h"
// CHECK-NOT: {{.*error|warning|note:.*}}
// CHECK: Timers for Clad Funcs

double nested1(double c){
return c*3*c;
}

double nested2(double z){
return 4*z*z;
}

double test1(double x, double y) {
return 2*y*nested1(y) * 3 * x * nested1(x);
}

double test2(double a, double b) {
return 3*a*a + b * nested2(a) + a * b;
}

int main() {
auto d_fn_1 = clad::differentiate(test1, "x");
double dp = -1, dq = -1;
auto f_grad = clad::gradient(test2);
f_grad.execute(3, 4, &dp, &dq);
printf("Result is = %f\n", d_fn_1.execute(3,4));
printf("Result is = %f %f\n", dp, dq);
return 0;
}
65 changes: 32 additions & 33 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,6 @@

using namespace clang;

namespace {
class SimpleTimer {
bool WantTiming;
llvm::TimeRecord Start;
std::string Output;

public:
explicit SimpleTimer(bool WantTiming) : WantTiming(WantTiming) {
if (WantTiming)
Start = llvm::TimeRecord::getCurrentTime();
}

void setOutput(const Twine &Output) {
if (WantTiming)
this->Output = Output.str();
}

~SimpleTimer() {
if (WantTiming) {
llvm::TimeRecord Elapsed = llvm::TimeRecord::getCurrentTime();
Elapsed -= Start;
llvm::errs() << Output << ": user | system | process | all :";
Elapsed.print(Elapsed, llvm::errs());
llvm::errs() << '\n';
}
}
};
}

namespace clad {
namespace plugin {
/// Keeps track if we encountered #pragma clad on/off.
Expand Down Expand Up @@ -103,7 +74,6 @@ namespace clad {
CladPlugin::CladPlugin(CompilerInstance& CI, DifferentiationOptions& DO)
: m_CI(CI), m_DO(DO), m_HasRuntime(false) {
#if CLANG_VERSION_MAJOR > 8

FrontendOptions& Opts = CI.getFrontendOpts();
// Find the path to clad.
llvm::StringRef CladSoPath;
Expand Down Expand Up @@ -228,19 +198,31 @@ namespace clad {
// FIXME: Move the timing inside the DerivativeBuilder. This would
// require to pass in the DifferentiationOptions in the DiffPlan.
// derive the collected functions
bool WantTiming = getenv("LIBCLAD_TIMING");
SimpleTimer Timer(WantTiming);
Timer.setOutput("Generation time for " + FD->getNameAsString());

#if CLANG_VERSION_MAJOR > 11
bool WantTiming =
getenv("LIBCLAD_TIMING") || m_CI.getCodeGenOpts().TimePasses;
#else
bool WantTiming =
getenv("LIBCLAD_TIMING") || m_CI.getFrontendOpts().ShowTimers;
#endif

auto DFI = m_DFC.Find(request);
if (DFI.IsValid()) {
DerivativeDecl = DFI.DerivedFn();
OverloadedDerivativeDecl = DFI.OverloadedDerivedFn();
alreadyDerived = true;
} else {
// Only time the function when it is first encountered
if (WantTiming)
m_CTG.StartNewTimer("Timer for clad func",
request.BaseFunctionName);

auto deriveResult = m_DerivativeBuilder->Derive(request);
DerivativeDecl = deriveResult.derivative;
OverloadedDerivativeDecl = deriveResult.overload;
if (WantTiming)
m_CTG.StopTimer();
}
}

Expand Down Expand Up @@ -338,6 +320,23 @@ namespace clad {
}
} // end namespace plugin

clad::CladTimerGroup::CladTimerGroup()
: m_Tg("Timers for Clad Funcs", "Timers for Clad Funcs") {}

void clad::CladTimerGroup::StartNewTimer(llvm::StringRef TimerName,
llvm::StringRef TimerDesc) {
std::unique_ptr<llvm::Timer> tm(
new llvm::Timer(TimerName, TimerDesc, m_Tg));
m_Timers.push_back(std::move(tm));
m_Timers.back()->startTimer();
}

void clad::CladTimerGroup::StopTimer() {
m_Timers.back()->stopTimer();
if (m_Timers.size() != 1)
m_Timers.pop_back();
}

// Routine to check clang version at runtime against the clang version for
// which clad was built.
bool checkClangVersion() {
Expand Down
13 changes: 12 additions & 1 deletion tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
#include "clang/Basic/Version.h"
#include "clang/Frontend/FrontendPluginRegistry.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Timer.h"

namespace clang {
class ASTContext;
Expand Down Expand Up @@ -61,6 +62,15 @@ namespace clad {
/// argument `DFI`.
bool AlreadyExists(const DerivedFnInfo& DFI) const;
};
class CladTimerGroup {
llvm::TimerGroup m_Tg;
std::vector<std::unique_ptr<llvm::Timer>> m_Timers;

public:
CladTimerGroup();
void StartNewTimer(llvm::StringRef TimerName, llvm::StringRef TimerDesc);
void StopTimer();
};

namespace plugin {
struct DifferentiationOptions {
Expand Down Expand Up @@ -89,6 +99,7 @@ namespace clad {
bool m_HasRuntime = false;
bool m_PendingInstantiationsInFlight = false;
bool m_HandleTopLevelDeclInternal = false;
CladTimerGroup m_CTG;
DerivedFnCollector m_DFC;
public:
CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO);
Expand Down

0 comments on commit f242077

Please sign in to comment.