Skip to content

Commit

Permalink
Add support for simple lambda expressions in reverse mode
Browse files Browse the repository at this point in the history
This commit provides support for primitive lambda expressions
with no captures in reverse mode in the same way they are
currently supported in the forward mode (vgvassilev#937). That is, the
lambda expressions are not visited yet. Instead, the lambda
functions are treated as a special case of functors.

Fixes: vgvassilev#789
  • Loading branch information
gojakuch authored and vgvassilev committed Jul 1, 2024
1 parent 1d56ef8 commit 22b2590
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 6 deletions.
51 changes: 45 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "clad/Differentiator/StmtClone.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTLambda.h"
#include "clang/AST/Expr.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
Expand Down Expand Up @@ -1596,13 +1597,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff baseDiff;
// If it has more args or f_darg0 was not found, we look for its pullback
// function.
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
if (!OverloadedDerivedFn) {
size_t idx = 0;

/// Add base derivative expression in the derived call output args list if
/// `CE` is a call to an instance member function.
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
if (MD) {
if (isLambdaCallOperator(MD)) {
QualType ptrType = m_Context.getPointerType(m_Context.getRecordType(
FD->getDeclContext()->getOuterLexicalRecordContext()));
baseDiff =
StmtDiff(Clone(dyn_cast<CXXOperatorCallExpr>(CE)->getArg(0)),
new (m_Context) CXXNullPtrLiteralExpr(ptrType, Loc));
} else if (MD->isInstance()) {
const Expr* baseOriginalE = nullptr;
if (const auto* MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseOriginalE = MCE->getImplicitObjectArgument();
Expand Down Expand Up @@ -1700,7 +1708,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (DerivedCallOutputArgs[i + isaMethod])
if (MD && isLambdaCallOperator(MD)) {
if (const auto* paramDecl = FD->getParamDecl(i))
pullbackRequest.DVI.push_back(paramDecl);
} else if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));

FunctionDecl* pullbackFD = nullptr;
Expand Down Expand Up @@ -2735,14 +2746,41 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool promoteToFnScope =
!getCurrentScope()->isFunctionScope() &&
m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass;

// If the DeclStmt is not empty, check the first declaration in case it is a
// lambda function. In this case, it is treated separately for now and we don't
// create a variable for its derivative.
bool isLambda = false;
const auto* declsBegin = DS->decls().begin();
if (declsBegin != DS->decls().end() && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
QualType QT = VD->getType();
if (!QT->isPointerType()) {
auto* typeDecl = QT->getAsCXXRecordDecl();
// We should also simply copy the original lambda. The differentiation
// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
}
}

// For each variable declaration v, create another declaration _d_v to
// store derivatives for potential reassignments. E.g.
// double y = x;
// ->
// double _d_y = _d_x; double y = x;
for (auto* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D)) {
DeclDiff<VarDecl> VDDiff = DifferentiateVarDecl(VD);
DeclDiff<VarDecl> VDDiff;
if (!isLambda)
VDDiff = DifferentiateVarDecl(VD);

// Check if decl's name is the same as before. The name may be changed
// if decl name collides with something in the derivative body.
Expand All @@ -2762,8 +2800,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// double _d_y = x; // copied from original function, collides with
// _d_y
// }
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName() ||
VD->getType() != VDDiff.getDecl()->getType())
if (!isLambda &&
(VDDiff.getDecl()->getDeclName() != VD->getDeclName() ||
VD->getType() != VDDiff.getDecl()->getType()))
m_DeclReplacements[VD] = VDDiff.getDecl();

// Here, we move the declaration to the function global scope.
Expand Down
79 changes: 79 additions & 0 deletions test/Gradient/Lambdas.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// RUN: %cladclang %s -I%S/../../include -oLambdas.out 2>&1 | %filecheck %s
// RUN: ./Lambdas.out | %filecheck_exec %s
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oLambdas.out
// RUN: ./Lambdas.out | %filecheck_exec %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"

// Test case: a capture-free lambda called once from the enclosing function.
// f1(i, j) = i + (j*j + 1), so the expected gradient is d/di = 1, d/dj = 2*j.
// NOTE: the lambda body below is cloned verbatim into the generated f1_grad,
// which the FileCheck lines that follow match textually — do not reformat.
double f1(double i, double j) {
auto _f = [] (double t) {
return t*t + 1.0;
};
return i + _f(j);
}

// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const;
// CHECK-NEXT: void f1_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: auto _f = []{{ ?}}(double t) {
// CHECK-NEXT: return t * t + 1.;
// CHECK-NEXT: }{{;?}}
// CHECK: {
// CHECK-NEXT: *_d_i += 1;
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: _f.operator_call_pullback(j, 1, &_r0);
// CHECK-NEXT: *_d_j += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

// Test case: a capture-free two-argument lambda whose result flows through a
// local variable. f2(i, j) = (i + j) + i = 2*i + j, so the expected gradient
// is d/di = 2, d/dj = 1. As with f1, the source below is cloned into the
// generated f2_grad and matched textually by the FileCheck lines — do not
// reformat.
double f2(double i, double j) {
auto _f = [] (double t, double k) {
return t + k;
};
double x = _f(i + j, i);
return x;
}

// CHECK: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const;
// CHECK-NEXT: void f2_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _d_x = 0;
// CHECK-NEXT: auto _f = []{{ ?}}(double t, double k) {
// CHECK-NEXT: return t + k;
// CHECK-NEXT: }{{;?}}
// CHECK: double x = operator()(i + j, i);
// CHECK-NEXT: _d_x += 1;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _f.operator_call_pullback(i + j, i, _d_x, &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r0;
// CHECK-NEXT: *_d_i += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }


// Driver: differentiates f1 and f2 in reverse mode and checks the computed
// gradients at (i, j) = (3, 4) against the CHECK-EXEC expectations.
int main() {
// f1: d/di = 1, d/dj = 2*j = 8 at j = 4.
auto df1 = clad::gradient(f1);
double di = 0, dj = 0;
df1.execute(3, 4, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 1.00 8.00

// f2: d/di = 2, d/dj = 1 (f2 = 2*i + j). Reset accumulators first,
// since execute() adds into the provided derivative slots.
auto df2 = clad::gradient(f2);
di = 0, dj = 0;
df2.execute(3, 4, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 2.00 1.00
}

// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const {
// CHECK-NEXT: {
// CHECK-NEXT: *_d_t += _d_y * t;
// CHECK-NEXT: *_d_t += t * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const {
// CHECK-NEXT: {
// CHECK-NEXT: *_d_t += _d_y;
// CHECK-NEXT: *_d_k += _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 22b2590

Please sign in to comment.