Skip to content

Commit

Permalink
add support for operator overload in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 authored and vgvassilev committed Oct 15, 2023
1 parent c468fa8 commit 9948b21
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 31 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp) override;
};
} // namespace clad

Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ namespace clad {
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
virtual StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
/// Decl is not Stmt, so it cannot be visited directly.
StmtDiff VisitWhileStmt(const clang::WhileStmt* WS);
Expand Down
30 changes: 27 additions & 3 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
DiffParams args{};
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

auto fnName = m_Function->getNameAsString() + "_forw";
auto fnName = clad::utils::ComputeEffectiveFnName(m_Function) + "_forw";
auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -86,8 +86,6 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
QualType
ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
Expand Down Expand Up @@ -240,4 +238,30 @@ ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get();
return {newRS};
}

StmtDiff
ReverseModeForwPassVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
auto opCode = UnOp->getOpcode();
StmtDiff diff{};
// If it is a post-increment/decrement operator, its result is a reference
// and we should return it.
Expr* ResultRef = nullptr;
if (opCode == UnaryOperatorKind::UO_Deref) {
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (MD->isInstance()) {
diff = Visit(UnOp->getSubExpr());
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr());
Expr* derivedE = diff.getExpr_dx();
return {cloneE, derivedE};
}
}
} else if (opCode == UO_Plus)
diff = Visit(UnOp->getSubExpr(), dfdx());
else if (opCode == UO_Minus) {
auto d = BuildOp(UO_Minus, dfdx());
diff = Visit(UnOp->getSubExpr(), d);
}
Expr* op = BuildOp(opCode, diff.getExpr());
return StmtDiff(op, ResultRef);
}
} // namespace clad
63 changes: 39 additions & 24 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnScope();

auto params = BuildParams(args);

if (m_ExternalSource)
Expand Down Expand Up @@ -416,7 +416,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_IndependentVars.push_back(arg);
}
}

if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope();

Expand Down Expand Up @@ -748,7 +748,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff SDiff = DifferentiateSingleStmt(S);
addToCurrentBlock(SDiff.getStmt(), direction::forward);
addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse);

if (m_ExternalSource)
m_ExternalSource->ActAfterProcessingStmtInVisitCompoundStmt();
}
Expand Down Expand Up @@ -866,7 +866,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt();
StmtDiff BranchDiff = DifferentiateSingleStmt(Branch, /*dfdS=*/nullptr);
addToCurrentBlock(BranchDiff.getStmt(), direction::forward);

if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt();

Expand Down Expand Up @@ -1377,7 +1377,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
// contribute to gradient.
if (!NArgs && !isa<CXXMemberCallExpr>(CE))
if ((NArgs == 0U) && !isa<CXXMemberCallExpr>(CE) &&
!isa<CXXOperatorCallExpr>(CE))
return StmtDiff(Clone(CE));

// Stores the call arguments for the function to be derived
Expand All @@ -1397,7 +1398,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derived function. In the case of member functions, `implicit`
// this object is always passed by reference.
if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMemberCallExpr>(CE)) {
!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
Expand Down Expand Up @@ -1429,9 +1430,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
for (std::size_t i = skipFirstArg, e = CE->getNumArgs(); i != e; ++i) {
bool isCXXOperatorCall = isa<CXXOperatorCallExpr>(CE);

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
const auto* PVD = FD->getParamDecl(i - skipFirstArg);
const auto* PVD =
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff{};
bool passByRef = utils::IsReferenceOrPointerType(PVD->getType());
// We do not need to create result arg for arguments passed by reference
Expand Down Expand Up @@ -1719,11 +1725,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullback);

// Try to find it in builtin derivatives
std::string customPullback = FD->getNameAsString() + "_pullback";
if (baseDiff.getExpr())
pullbackCallArgs.insert(
pullbackCallArgs.begin(),
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr()));
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
}

// should be true if we are using numerical differentiation to differentiate
Expand Down Expand Up @@ -1754,7 +1767,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;
pullbackRequest.BaseFunctionName = FD->getNameAsString();
pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
Expand Down Expand Up @@ -1887,7 +1901,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString();
calleeFnForwPassReq.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
calleeFnForwPassReq.VerboseDiags = true;
FunctionDecl* calleeFnForwPassFD =
plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq);
Expand All @@ -1911,13 +1926,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// (isCladArrayType(derivedBase->getType()))
// CallArgs.push_back(derivedBase);
// else
// Currently derivedBase `*d_this` can never be CladArrayType
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc));
}

for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
const ParmVarDecl* PVD = FD->getParamDecl(i);
const ParmVarDecl* PVD =
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff = Visit(arg);
if ((argDiff.getExpr_dx() != nullptr) &&
PVD->getType()->isReferenceType()) {
Expand Down Expand Up @@ -1993,8 +2012,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
}
}
else {
} else {
// FIXME: This is not adding 'address-of' operator support.
// This is just making this special case differentiable that is required
// for computing hessian:
Expand Down Expand Up @@ -2387,13 +2405,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDDerivedInit = getZeroInit(VD->getType());

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
// the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
// differentiated.
bool specialThisDiffCase = false;
if (auto MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (VDDerivedType->isPointerType() && MD->isInstance()) {
Expand Down Expand Up @@ -2512,10 +2530,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

return VarDeclDiff(VDClone, VDDerived);
}

// TODO: 'shouldEmit' parameter should be removed after converting
// Error estimation framework to callback style. Some more research
// need to be done to
// need to be done to
StmtDiff
ReverseModeVisitor::DifferentiateSingleStmt(const Stmt* S, Expr* dfdS) {
if (m_ExternalSource)
Expand Down Expand Up @@ -3126,7 +3144,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

bodyDiff = {bodyDiff.getStmt(), CFSS};
}

void ReverseModeVisitor::AddExternalSource(ExternalRMVSource& source) {
if (!m_ExternalSource)
m_ExternalSource = new MultiplexExternalRMVSource();
Expand Down Expand Up @@ -3182,13 +3200,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {

if (m_Mode == DiffMode::reverse)
assert(yType->isRealType() &&
"yType should be a non-reference builtin-numerical scalar type!!");
else if (m_Mode == DiffMode::experimental_pullback)
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
Expand Down
Loading

0 comments on commit 9948b21

Please sign in to comment.