Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass the DiffRequest down to the visitors. NFC. #933

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions .github/workflows/arch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ jobs:
arch: ${{ matrix.target }}
branch: edge
packages: >
llvm-dev
clang-dev
clang-static
llvm17-dev
clang17-dev
clang17-static
llvm17-static
llvm17-gtest
cmake
llvm
clang
make
git
- name: "Setup"
Expand Down
5 changes: 3 additions & 2 deletions demos/ErrorEstimation/CustomModel/CustomModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
// FPErrorEstimationModel class.
class CustomModel : public clad::FPErrorEstimationModel {
public:
CustomModel(clad::DerivativeBuilder& builder)
: FPErrorEstimationModel(builder) {}
CustomModel(clad::DerivativeBuilder& builder,
const clad::DiffRequest& request)
: FPErrorEstimationModel(builder, request) {}
/// Return an expression of the following kind:
/// dfdx * delta_x
clang::Expr* AssignError(clad::StmtDiff refExpr,
Expand Down
4 changes: 2 additions & 2 deletions demos/ErrorEstimation/PrintModel/PrintModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// FPErrorEstimationModel class.
class PrintModel : public clad::FPErrorEstimationModel {
public:
PrintModel(clad::DerivativeBuilder& builder)
: FPErrorEstimationModel(builder) {}
PrintModel(clad::DerivativeBuilder& builder, const clad::DiffRequest& request)
: FPErrorEstimationModel(builder, request) {}
// Return an expression of the following kind:
// dfdx * delta_x
clang::Expr* AssignError(clad::StmtDiff refExpr, const std::string& name) override;
Expand Down
6 changes: 3 additions & 3 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class BaseForwardModeVisitor
unsigned m_ArgIndex = ~0;

public:
BaseForwardModeVisitor(DerivativeBuilder& builder);
BaseForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
virtual ~BaseForwardModeVisitor();

///\brief Produces the first derivative of a given function.
Expand All @@ -41,8 +42,7 @@ class BaseForwardModeVisitor
const DiffRequest& request);

/// Returns the return type for the pushforward function of the function
/// `m_Function`.
/// \note `m_Function` field should be set before using this function.
/// `m_DiffReq->Function`.
clang::QualType ComputePushforwardFnReturnType();

virtual void ExecuteInsidePushforwardFunctionBlock();
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ struct DiffRequest {
DeclarationOnly == other.DeclarationOnly;
}

const clang::FunctionDecl* operator->() const { return Function; }

// String operator for printing the node.
operator std::string() const {
std::string res = BaseFunctionName + "__order_" +
Expand Down
20 changes: 14 additions & 6 deletions include/clad/Differentiator/EstimationModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ namespace clad {
std::unordered_map<const clang::VarDecl*, clang::Expr*> m_EstimateVar;

public:
FPErrorEstimationModel(DerivativeBuilder& builder) : VisitorBase(builder) {}
// FIXME: Add a proper parameter for the DiffRequest here.
FPErrorEstimationModel(DerivativeBuilder& builder,
const DiffRequest& request)
: VisitorBase(builder, request) {}
virtual ~FPErrorEstimationModel();

/// Clear the variable estimate map so that we can start afresh.
Expand Down Expand Up @@ -83,10 +86,13 @@ namespace clad {
/// custom model.
/// \param[in] builder A build instance to pass to the custom model
/// constructor.
/// \param[in] request The differentiation configuration passed to the
/// custom model
/// \returns A reference to the custom class wrapped in the
/// FPErrorEstimationModel class.
virtual std::unique_ptr<FPErrorEstimationModel>
InstantiateCustomModel(DerivativeBuilder& builder) = 0;
InstantiateCustomModel(DerivativeBuilder& builder,
const DiffRequest& request) = 0;
};

/// A class used to register custom plugins.
Expand All @@ -99,16 +105,18 @@ namespace clad {
///
/// \param[in] builder The current instance of derivative builder.
std::unique_ptr<FPErrorEstimationModel>
InstantiateCustomModel(DerivativeBuilder& builder) override {
return std::unique_ptr<FPErrorEstimationModel>(new CustomClass(builder));
InstantiateCustomModel(DerivativeBuilder& builder,
const DiffRequest& request) override {
return std::unique_ptr<FPErrorEstimationModel>(
new CustomClass(builder, request));
}
};

/// Example class for taylor series approximation based error estimation.
class TaylorApprox : public FPErrorEstimationModel {
public:
TaylorApprox(DerivativeBuilder& builder)
: FPErrorEstimationModel(builder) {}
TaylorApprox(DerivativeBuilder& builder, const DiffRequest& request)
: FPErrorEstimationModel(builder, request) {}
// Return an expression of the following kind:
// std::abs(dfdx * delta_x * Em)
clang::Expr* AssignError(StmtDiff refExpr,
Expand Down
6 changes: 3 additions & 3 deletions include/clad/Differentiator/HessianModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ namespace clad {
size_t TotalIndependentArgsSize, std::string hessianFuncName);

public:
HessianModeVisitor(DerivativeBuilder& builder);
~HessianModeVisitor();
HessianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
~HessianModeVisitor() = default;

///\brief Produces the hessian second derivative columns of a given
/// function.
Expand All @@ -53,4 +53,4 @@ namespace clad {
};
} // end namespace clad

#endif // CLAD_HESSIAN_MODE_VISITOR_H
#endif // CLAD_HESSIAN_MODE_VISITOR_H
4 changes: 3 additions & 1 deletion include/clad/Differentiator/PushForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
#include "BaseForwardModeVisitor.h"

namespace clad {

/// A visitor for processing the function code in forward mode.
/// Used to compute derivatives by clad::differentiate.
class PushForwardModeVisitor : public BaseForwardModeVisitor {

public:
PushForwardModeVisitor(DerivativeBuilder& builder);
PushForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
~PushForwardModeVisitor() override;

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
Expand Down
5 changes: 3 additions & 2 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
clang::QualType xType);

public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
ReverseModeForwPassVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

Expand All @@ -34,4 +35,4 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
};
} // namespace clad

#endif
#endif
10 changes: 3 additions & 7 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ namespace clad {
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);

public:
ReverseModeVisitor(DerivativeBuilder& builder);
ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
virtual ~ReverseModeVisitor();

///\brief Produces the gradient of a given function.
Expand Down Expand Up @@ -629,18 +629,14 @@ namespace clad {
/// Computes and returns the sequence of derived function parameter types.
///
/// Information about the original function and the differentiation mode
/// are taken from the data member variables. In particular, `m_Function`,
/// `m_Mode` data members should be correctly set before using this
/// function.
/// are taken from the data member variables.
llvm::SmallVector<clang::QualType, 8> ComputeParamTypes(const DiffParams& diffParams);

/// Builds and returns the sequence of derived function parameters.
///
/// Information about the original function, derived function, derived
/// function parameter types and the differentiation mode are implicitly
/// taken from the data member variables. In particular, `m_Function`,
/// `m_Mode` and `m_Derivative` should be correctly set before using this
/// function.
/// taken from the data member variables.
llvm::SmallVector<clang::ParmVarDecl*, 8>
BuildParams(DiffParams& diffParams);

Expand Down
7 changes: 3 additions & 4 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
clang::Expr* m_IndVarCountExpr;

public:
VectorForwardModeVisitor(DerivativeBuilder& builder);
VectorForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
~VectorForwardModeVisitor();

///\brief Produces the first derivative of a given function with
Expand Down Expand Up @@ -53,9 +54,7 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
///
/// Information about the original function, derived function, derived
/// function parameter types and the differentiation mode are implicitly
/// taken from the data member variables. In particular, `m_Function`,
/// `m_Mode` and `m_Derivative` should be correctly set before using this
/// function.
/// taken from the data member variables.
llvm::SmallVector<clang::ParmVarDecl*, 8>
BuildVectorModeParams(DiffParams& diffParams);

Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/VectorPushForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace clad {
class VectorPushForwardModeVisitor : public VectorForwardModeVisitor {

public:
VectorPushForwardModeVisitor(DerivativeBuilder& builder);
VectorPushForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
~VectorPushForwardModeVisitor() override;

void ExecuteInsidePushforwardFunctionBlock() override;
Expand Down
8 changes: 4 additions & 4 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ namespace clad {
/// A base class for all common functionality for visitors
class VisitorBase {
protected:
VisitorBase(DerivativeBuilder& builder)
VisitorBase(DerivativeBuilder& builder, const DiffRequest& request)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: constructor does not initialize these fields: m_Mode [cppcoreguidelines-pro-type-member-init]

    VisitorBase(DerivativeBuilder& builder, const DiffRequest& request)
    ^

: m_Builder(builder), m_Sema(builder.m_Sema),
m_CladPlugin(builder.m_CladPlugin), m_Context(builder.m_Context),
m_DerivativeFnScope(nullptr), m_DerivativeInFlight(false),
m_Derivative(nullptr), m_Function(nullptr) {}
m_Derivative(nullptr), m_DiffReq(request) {}

using Stmts = llvm::SmallVector<clang::Stmt*, 16>;

Expand All @@ -117,8 +117,8 @@ namespace clad {
bool m_DerivativeInFlight;
/// The Derivative function that is being generated.
clang::FunctionDecl* m_Derivative;
/// The function that is currently differentiated.
const clang::FunctionDecl* m_Function;
/// The differentiation request that is being currently processed.
const DiffRequest& m_DiffReq;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member 'm_DiffReq' of type 'const DiffRequest &' is a reference [cppcoreguidelines-avoid-const-or-ref-data-members]

    const DiffRequest& m_DiffReq;
                       ^

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'm_DiffReq' has protected visibility [cppcoreguidelines-non-private-member-variables-in-classes]

    const DiffRequest& m_DiffReq;
                       ^

DiffMode m_Mode;
/// Map used to keep track of variable declarations and match them
/// with their derivatives.
Expand Down
Loading
Loading