Skip to content

Commit

Permalink
call operators as methods
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Oct 2, 2024
1 parent f16011f commit 632f620
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
25 changes: 18 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2131,13 +2131,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(resValue, resAdjoint, resAdjoint);
} // Recreate the original call expression.

if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0)));
// OCE->getArg(0)->dump();
call = CXXOperatorCallExpr::Create(
m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs,
FD->getCallResultType(), OCE->getValueKind(), Loc,
CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride());
if (auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
CXXMethodDecl* FD = const_cast<CXXMethodDecl*>(
dyn_cast<CXXMethodDecl>(OCE->getCalleeDecl()));

NestedNameSpecifierLoc NNS(FD->getQualifier(),
/*Data=*/nullptr);
auto DAP = DeclAccessPair::make(FD, FD->getAccess());
auto* memberExpr = MemberExpr::Create(
m_Context, Clone(OCE->getArg(0)), /*isArrow=*/false, Loc, NNS, noLoc,
FD, DAP, FD->getNameInfo(),
/*TemplateArgs=*/nullptr, m_Context.BoundMemberTy,
CLAD_COMPAT_ExprValueKind_R_or_PR_Value,
ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams(
NOUR_None));
call = m_Sema
.BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc,
CallArgs, Loc)
.get();
return StmtDiff(call);
}

Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/Lambdas.C
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ double f2(double i, double j) {
// CHECK-NEXT: return t + k;
// CHECK-NEXT: }{{;?}}
// CHECK: double _d_x = 0.;
// CHECK-NEXT: double x = _f(i + j, i);
// CHECK-NEXT: double x = _f.operator()(i + j, i);
// CHECK-NEXT: _d_x += 1;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
Expand Down
2 changes: 1 addition & 1 deletion test/ValidCodeGen/ValidCodeGen.C
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ int main() {
//CHECK-NEXT: TN::Test2<double> t;
//CHECK-NEXT: TN::Test2<double> _t0 = t;
//CHECK-NEXT: double _d_q = 0.;
//CHECK-NEXT: double q = t[x];
//CHECK-NEXT: double q = t.operator[](x);
//CHECK-NEXT: _d_q += 1;
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0.;
Expand Down

0 comments on commit 632f620

Please sign in to comment.