From 632f620c893b90dd392712886bd02e9457a929de Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 2 Oct 2024 12:27:26 +0200 Subject: [PATCH] call operators as methods --- lib/Differentiator/ReverseModeVisitor.cpp | 25 ++++++++++++++++------- test/Gradient/Lambdas.C | 2 +- test/ValidCodeGen/ValidCodeGen.C | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 39de36caa..a5e8de659 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2131,13 +2131,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. - if (const auto* OCE = dyn_cast(CE)) { - CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0))); - // OCE->getArg(0)->dump(); - call = CXXOperatorCallExpr::Create( - m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs, - FD->getCallResultType(), OCE->getValueKind(), Loc, - CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride()); + if (auto* OCE = dyn_cast(CE)) { + CXXMethodDecl* FD = const_cast( + dyn_cast(OCE->getCalleeDecl())); + + NestedNameSpecifierLoc NNS(FD->getQualifier(), + /*Data=*/nullptr); + auto DAP = DeclAccessPair::make(FD, FD->getAccess()); + auto* memberExpr = MemberExpr::Create( + m_Context, Clone(OCE->getArg(0)), /*isArrow=*/false, Loc, NNS, noLoc, + FD, DAP, FD->getNameInfo(), + /*TemplateArgs=*/nullptr, m_Context.BoundMemberTy, + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( + NOUR_None)); + call = m_Sema + .BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc, + CallArgs, Loc) + .get(); return StmtDiff(call); } diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C index 82bcd4cd2..336be0835 100644 --- a/test/Gradient/Lambdas.C +++ b/test/Gradient/Lambdas.C @@ -39,7 +39,7 @@ double f2(double i, double j) { // CHECK-NEXT: return t + k; // CHECK-NEXT: }{{;?}} // CHECK: double _d_x = 0.; -// CHECK-NEXT: double x = _f(i + j, i); +// CHECK-NEXT: double x = _f.operator()(i + j, i); // CHECK-NEXT: _d_x += 1; // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0.; diff --git a/test/ValidCodeGen/ValidCodeGen.C b/test/ValidCodeGen/ValidCodeGen.C index b789a57a5..1e638be1c 100644 --- a/test/ValidCodeGen/ValidCodeGen.C +++ b/test/ValidCodeGen/ValidCodeGen.C @@ -70,7 +70,7 @@ int main() { //CHECK-NEXT: TN::Test2 t; //CHECK-NEXT: TN::Test2 _t0 = t; //CHECK-NEXT: double _d_q = 0.; -//CHECK-NEXT: double q = t[x]; +//CHECK-NEXT: double q = t.operator[](x); //CHECK-NEXT: _d_q += 1; //CHECK-NEXT: { //CHECK-NEXT: double _r0 = 0.;