From 3347684a3c8707b068ba73016bf9ffa3cc9fa9b5 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 15 Feb 2024 09:02:23 +0100 Subject: [PATCH] Fix passing pointers as call arguments fixes #735, #636 --- include/clad/Differentiator/CladUtils.h | 2 + .../clad/Differentiator/ReverseModeVisitor.h | 2 + lib/Differentiator/CladUtils.cpp | 7 ++++ lib/Differentiator/ReverseModeVisitor.cpp | 38 ++++++++++++++----- test/Arrays/ArrayInputsReverseMode.C | 4 +- test/Gradient/FunctionCalls.C | 35 +++++++++++++++-- 6 files changed, 74 insertions(+), 14 deletions(-) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 715707e0d..5690c3913 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -326,6 +326,8 @@ namespace clad { bool ContainsFunctionCalls(const clang::Stmt* E); void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt); + + bool IsLiteral(const clang::Expr* E); } // namespace utils } // namespace clad diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 52473036b..1cd1c0bfa 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -380,6 +380,8 @@ namespace clad { VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); + StmtDiff + VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE); /// A helper method to differentiate a single Stmt in the reverse mode. /// Internally, calls Visit(S, expr). Its result is wrapped into a diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 526ba31e4..fbddd535b 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -634,5 +634,12 @@ namespace clad { else cast(SC)->setSubStmt(subStmt); } + + bool IsLiteral(const clang ::Expr* E) { + return isa(E) || isa(E) || + isa(E) || isa(E) || + isa(E) || isa(E) || + isa(E); + } } // namespace utils } // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index bddb9de11..6e3f1af48 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1166,6 +1166,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, SL->getType(), utils::GetValidSLoc(m_Sema))); } + StmtDiff ReverseModeVisitor::VisitCXXNullPtrLiteralExpr( + const CXXNullPtrLiteralExpr* NPE) { + return StmtDiff(Clone(NPE), Clone(NPE)); + } + StmtDiff ReverseModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { // Initially, df/df = 1. const Expr* value = RS->getRetValue(); @@ -1360,7 +1365,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) { - return StmtDiff(Clone(IL)); + auto* Constant0 = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); + return StmtDiff(Clone(IL), Constant0); } StmtDiff ReverseModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) { @@ -1461,22 +1468,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // argument by reference. passByRef = false; } + QualType argDiffType; // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. if (passByRef) { argDiff = Visit(arg); + Expr* dArg = nullptr; + argDiffType = argDiff.getExpr()->getType(); QualType argResultValueType = - utils::GetValueType(argDiff.getExpr()->getType()) - .getNonReferenceType(); + utils::GetValueType(argDiffType).getNonReferenceType(); // Create ArgResult variable for each reference argument because it is // required by error estimator. For automatic differentiation, we do not need // to create ArgResult variable for arguments passed by reference. // ``` // _r0 = _d_a; // ``` - Expr* dArg = nullptr; - if (utils::isArrayOrPointerType(argDiff.getExpr()->getType())) { + if (argDiff.getExpr_dx() && utils::IsLiteral(argDiff.getExpr_dx())) { + dArg = StoreAndRef(argDiff.getExpr_dx(), arg->getType(), + direction::reverse, "_r", + /*forceDeclCreation=*/true); + } else if (argDiffType->isArrayType()) { Expr* init = argDiff.getExpr_dx(); if (isa(argDiff.getExpr_dx()->getType())) init = utils::BuildCladArrayInitByConstArray(m_Sema, @@ -1486,6 +1498,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, direction::reverse, "_r", /*forceDeclCreation=*/true, VarDecl::InitializationStyle::CallInit); + } else if (argDiffType->isPointerType()) { + dArg = StoreAndRef(argDiff.getExpr_dx(), argDiffType, + direction::reverse, "_r", + /*forceDeclCreation=*/true); } else { dArg = StoreAndRef(argDiff.getExpr_dx(), argResultValueType, direction::reverse, "_r", @@ -1511,6 +1527,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, cast(cast(dArg)->getDecl())); // Visit using uninitialized reference. argDiff = Visit(arg, dArg); + argDiffType = argDiff.getExpr()->getType(); } // FIXME: We may use same argDiff.getExpr_dx at two places. This can @@ -1536,7 +1553,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now // arrays are not stored. StmtDiff argDiffStore; - if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) + if (passByRef && !argDiffType->isArrayType() && + !argDiff.getExpr()->isEvaluatable(m_Context)) argDiffStore = GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true); else @@ -1567,7 +1585,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // ``` // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now // arrays are not stored. - if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) { + if (passByRef && !argDiffType->isArrayType()) { if (isInsideLoop) { // Add tape push expression. We need to explicitly add it here because // we cannot add it as call expression argument -- we need to pass the @@ -1606,7 +1624,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // inside loop and outside loop cases separately. Expr* newArgE = Visit(arg).getExpr(); argDiffStore = {newArgE, argDiffLocalE}; - } else { + } else if (isa(argDiff.getExpr())) { // Restore args auto& block = getCurrentBlock(direction::reverse); auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getExpr(), @@ -1734,7 +1752,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, argDerivative = BuildDeclRef(derivativeArrayRefVD); } if ((argDerivative != nullptr) && - isCladArrayType(argDerivative->getType())) + (isCladArrayType(argDerivative->getType()) || + argDerivative->getType()->isPointerType() || + !argDerivative->isLValue())) gradArgExpr = argDerivative; else gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); diff --git a/test/Arrays/ArrayInputsReverseMode.C b/test/Arrays/ArrayInputsReverseMode.C index c8135e207..944eec1a1 100644 --- a/test/Arrays/ArrayInputsReverseMode.C +++ b/test/Arrays/ArrayInputsReverseMode.C @@ -54,7 +54,7 @@ double f(double *arr) { //CHECK-NEXT: arr = _t0; //CHECK-NEXT: int _grad1 = 0; //CHECK-NEXT: addArr_pullback(_t0, 3, 1, _d_arr, &_grad1); -//CHECK-NEXT: clad::array _r0(_d_arr); +//CHECK-NEXT: double *_r0 = _d_arr; //CHECK-NEXT: int _r1 = _grad1; //CHECK-NEXT: } //CHECK-NEXT: } @@ -473,7 +473,7 @@ double func8(double i, double *arr, int n) { //CHECK-NEXT: helper2_pullback(i, _t2, n, _r_d1, &_grad0, _d_arr, &_grad2); //CHECK-NEXT: double _r0 = _grad0; //CHECK-NEXT: * _d_i += _r0; -//CHECK-NEXT: clad::array _r1(_d_arr); +//CHECK-NEXT: double *_r1 = _d_arr; //CHECK-NEXT: int _r2 = _grad2; //CHECK-NEXT: * _d_n += _r2; //CHECK-NEXT: } diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 7420f9402..62062f813 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -305,7 +305,7 @@ double fn4(double* arr, int n) { // CHECK-NEXT: arr = _t1; // CHECK-NEXT: int _grad1 = 0; // CHECK-NEXT: sum_pullback(_t1, n, _r_d0, _d_arr, &_grad1); -// CHECK-NEXT: clad::array _r0(_d_arr); +// CHECK-NEXT: double *_r0 = _d_arr; // CHECK-NEXT: int _r1 = _grad1; // CHECK-NEXT: * _d_n += _r1; // CHECK-NEXT: } @@ -348,7 +348,7 @@ double fn5(double* arr, int n) { // CHECK-NEXT: { // CHECK-NEXT: arr = _t0; // CHECK-NEXT: modify2_pullback(_t0, _d_temp, _d_arr); -// CHECK-NEXT: clad::array _r0(_d_arr); +// CHECK-NEXT: double *_r0 = _d_arr; // CHECK-NEXT: } // CHECK-NEXT: } @@ -497,7 +497,7 @@ double fn8(double x, double y) { // CHECK-NEXT: double _r0 = _grad0; // CHECK-NEXT: * _d_x += _r0; // CHECK-NEXT: char _r1 = _grad1; -// CHECK-NEXT: clad::array _r2({"", 3UL}); +// CHECK-NEXT: const char *_r2 = ""; // CHECK-NEXT: * _d_y += _t3 * 1 * _t0 * _t1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -645,6 +645,33 @@ double fn11(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT: } +double do_nothing(double* u, double* v, double* w) { + return u[0]; +} + +// CHECK: void do_nothing_pullback(double *u, double *v, double *w, double _d_y, clad::array_ref _d_u, clad::array_ref _d_v, clad::array_ref _d_w) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_u[0] += _d_y; +// CHECK-NEXT: } + +double fn12(double x, double y) { + return do_nothing(&x, nullptr, 0); +} + +// CHECK: void fn12_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { +// CHECK-NEXT: double *_t0; +// CHECK-NEXT: _t0 = &x; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: do_nothing_pullback(_t0, nullptr, 0, 1, &* _d_x, nullptr, 0); +// CHECK-NEXT: double *_r0 = &* _d_x; +// CHECK-NEXT: {{(std::)?}}nullptr_t _r1 = nullptr; +// CHECK-NEXT: double *_r2 = 0; +// CHECK-NEXT: } +// CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i