Skip to content

Commit

Permalink
address comments related to const pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Dec 22, 2023
1 parent a2c1a70 commit 3028b7c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
24 changes: 20 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2125,7 +2125,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {cloneE, derivedE};
} else {
if (opCode != UO_LNot)
// We should not output any warning on visiting boolean conditions
// We should only output warnings on visiting boolean conditions
// when it is related to some indepdendent variable and causes
// discontinuity in the function space.
// FIXME: We should support boolean differentiation or ignore it
// completely
unsupportedOpWarn(UnOp->getEndLoc());
Expand Down Expand Up @@ -2329,7 +2331,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore);

// We need to store values of derivative pointer variables in forward pass
// and restore them in reverese pass.
// and restore them in reverse pass.
if (isPointerOp) {
Expr* Edx = Ldiff.getExpr_dx();
ExprsToStore.push_back(Edx);
Expand Down Expand Up @@ -2595,8 +2597,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// expression of the corresponding pointer type.
else if (isPointerType && VD->getInit()) {
initDiff = Visit(VD->getInit());
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (VD->getType()->getPointeeType().isConstQualified()) {
// first extract the pointee type
auto pointeeType = VD->getType()->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
}
// Here separate behaviour for record and non-record types is only
// necessary to preserve the old tests.
Expand Down Expand Up @@ -2681,6 +2692,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE);
}
}
if (isPointerType) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
}
m_Variables.emplace(VDClone, derivedVDE);

return VarDeclDiff(VDClone, VDDerived);
Expand Down
29 changes: 10 additions & 19 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,34 @@ double nonMemFn(double i) {
// CHECK-NEXT: }

double minimalPointer(double x) {
double *p;
p = &x;
double* const p = &x;
*p = (*p)*(*p);
return *p; // x*x
}

// CHECK: void minimalPointer_grad(double x, clad::array_ref<double> _d_x) {
// CHECK-NEXT: double *_d_p = 0;
// CHECK-NEXT: double *_t0;
// CHECK-NEXT: double *_t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: double *p;
// CHECK-NEXT: _t0 = p;
// CHECK-NEXT: _t1 = _d_p;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: _d_p = &* _d_x;
// CHECK-NEXT: p = &x;
// CHECK-NEXT: _t2 = *p;
// CHECK-NEXT: double *const p = &x;
// CHECK-NEXT: _t0 = *p;
// CHECK-NEXT: *p = *p * (*p);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: *_d_p += 1;
// CHECK-NEXT: {
// CHECK-NEXT: *p = _t2;
// CHECK-NEXT: *p = _t0;
// CHECK-NEXT: double _r_d0 = *_d_p;
// CHECK-NEXT: double _r0 = _r_d0 * (*p);
// CHECK-NEXT: *_d_p += _r0;
// CHECK-NEXT: double _r1 = *p * _r_d0;
// CHECK-NEXT: *_d_p += _r1;
// CHECK-NEXT: *_d_p -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: p = _t0;
// CHECK-NEXT: _d_p = _t1;
// CHECK-NEXT: }
// CHECK-NEXT: }

double arrayPointer(double* arr) {
double *p = arr;
double arrayPointer(const double* arr) {
const double *p = arr;
p = p + 1;
double sum = *p;
p++;
Expand All @@ -73,8 +63,8 @@ double arrayPointer(double* arr) {
return sum; // 5*arr[0] + arr[1] + 2*arr[2] + 4*arr[3] + 3*arr[4]
}

// CHECK: void arrayPointer_grad(double *arr, clad::array_ref<double> _d_arr) {
// CHECK-NEXT: double *_d_p = _d_arr;
// CHECK: void arrayPointer_grad(const double *arr, clad::array_ref<double> _d_arr) {
// CHECK-NEXT: double *_d_p = 0;
// CHECK-NEXT: double *_t0;
// CHECK-NEXT: double *_t1;
// CHECK-NEXT: double _d_sum = 0;
Expand All @@ -88,6 +78,7 @@ double arrayPointer(double* arr) {
// CHECK-NEXT: double *_t9;
// CHECK-NEXT: double *_t10;
// CHECK-NEXT: double _t11;
// CHECK-NEXT: _d_p = _d_arr;
// CHECK-NEXT: double *p = arr;
// CHECK-NEXT: _t0 = p;
// CHECK-NEXT: _t1 = _d_p;
Expand Down

0 comments on commit 3028b7c

Please sign in to comment.