diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index a526a6d58..f926488e2 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -190,6 +190,20 @@ namespace clad { FunctionDecl* replacementFD = OverloadedFD ? OverloadedFD : FD; + auto codeArgIdx = -1; + auto derivedFnArgIdx = -1; + auto idx = 0; + for (auto* arg : call->arguments()) { + if (auto* default_arg_expr = dyn_cast(arg)) { + std::string argName = default_arg_expr->getParam()->getNameAsString(); + if (argName == "derivedFn") + derivedFnArgIdx = idx; + else if (argName == "code") + codeArgIdx = idx; + } + ++idx; + } + // Index of "CUDAkernel" parameter: int numArgs = static_cast(call->getNumArgs()); if (numArgs > 4) { @@ -204,8 +218,6 @@ namespace clad { call->setArg(kernelArgIdx, cudaKernelFlag); numArgs--; } - auto codeArgIdx = numArgs - 1; - auto derivedFnArgIdx = numArgs - 2; // Create ref to generated FD. DeclRefExpr* DRE = @@ -221,31 +233,35 @@ namespace clad { if (isa(DRE->getDecl())) DRE->setValueKind(CLAD_COMPAT_ExprValueKind_R_or_PR_Value); - // Add the "&" operator - auto newUnOp = - SemaRef.BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE) - .get(); - call->setArg(derivedFnArgIdx, newUnOp); - - // Update the code parameter. - if (CXXDefaultArgExpr* Arg - = dyn_cast(call->getArg(codeArgIdx))) { - clang::LangOptions LangOpts; - LangOpts.CPlusPlus = true; - clang::PrintingPolicy Policy(LangOpts); - Policy.Bool = true; - - std::string s; - llvm::raw_string_ostream Out(s); - FD->print(Out, Policy); - Out.flush(); - - StringLiteral* SL = utils::CreateStringLiteral(C, Out.str()); - Expr* newArg = - SemaRef.ImpCastExprToType(SL, - Arg->getType(), - CK_ArrayToPointerDecay).get(); - call->setArg(codeArgIdx, newArg); + if (derivedFnArgIdx != -1) { + // Add the "&" operator + auto newUnOp = + SemaRef + .BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE) + .get(); + call->setArg(derivedFnArgIdx, newUnOp); + } + + // Update the code parameter if it was found. + if (codeArgIdx != -1) { + if (auto* Arg = dyn_cast(call->getArg(codeArgIdx))) { + clang::LangOptions LangOpts; + LangOpts.CPlusPlus = true; + clang::PrintingPolicy Policy(LangOpts); + Policy.Bool = true; + + std::string s; + llvm::raw_string_ostream Out(s); + FD->print(Out, Policy); + Out.flush(); + + StringLiteral* SL = utils::CreateStringLiteral(C, Out.str()); + Expr* newArg = + SemaRef + .ImpCastExprToType(SL, Arg->getType(), CK_ArrayToPointerDecay) + .get(); + call->setArg(codeArgIdx, newArg); + } } }