Skip to content

Commit

Permalink
Fix gradient computation of higher order functions
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Oct 25, 2023
1 parent ac300de commit 72a3890
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 7 deletions.
23 changes: 16 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1345,9 +1345,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
// Create the (_d_param[idx] += dfdx) statement.
if (dfdx()) {
Expr* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
// FIXME: not sure if this is generic.
// Don't update derivatives of non-record types.
if (!decl->getType()->isRecordType()) {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
}
}
return StmtDiff(clonedDRE, it->second, it->second);
}
Expand Down Expand Up @@ -1694,10 +1698,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
else
gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative);
} else {
// Declare: diffArgType _grad = 0;
gradVarDecl = BuildVarDecl(
PVD->getType(), gradVarII,
ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0));
// Declare: diffArgType _grad;
Expr* initVal = nullptr;
if (!PVD->getType()->isRecordType()) {
// If the argument is not a class type, then initialize the grad
// variable with 0.
initVal =
ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0);
}
gradVarDecl = BuildVarDecl(PVD->getType(), gradVarII, initVal);
// Pass the address of the declared variable
gradVarExpr = BuildDeclRef(gradVarDecl);
gradArgExpr =
Expand Down
89 changes: 89 additions & 0 deletions test/Gradient/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,18 @@ double CallFunctor(double i, double j) {
return E(i, j);
}

// A function taking functor as an argument.
template<typename Func>
double FunctorAsArg(Func fn, double i, double j) {
return fn(i, j);
}

// A wrapper for function taking functor as an argument.
double FunctorAsArgWrapper(double i, double j) {
Experiment E(3, 5);
return FunctorAsArg(E, i, j);
}

#define INIT(E) \
auto E##_grad = clad::gradient(&E); \
auto E##Ref_grad = clad::gradient(E);
Expand Down Expand Up @@ -332,4 +344,81 @@ int main() {
double di = 0, dj = 0;
CallFunctor_grad.execute(7, 9, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00

// CHECK: void FunctorAsArg_grad(Experiment fn, double i, double j, clad::array_ref<Experiment> _d_fn, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: Experiment _t2;
// CHECK-NEXT: _t0 = i;
// CHECK-NEXT: _t1 = j;
// CHECK-NEXT: _t2 = fn;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, 1, &(* _d_fn), &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// testing differentiating a function taking functor as an argument
auto FunctorAsArg_grad = clad::gradient(FunctorAsArg<Experiment>);
di = 0, dj = 0;
Experiment E_temp(3, 5), dE_temp;
FunctorAsArg_grad.execute(E_temp, 7, 9, &dE_temp, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00

// CHECK: void FunctorAsArg_pullback(Experiment fn, double i, double j, double _d_y, clad::array_ref<Experiment> _d_fn, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: Experiment _t2;
// CHECK-NEXT: _t0 = i;
// CHECK-NEXT: _t1 = j;
// CHECK-NEXT: _t2 = fn;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, _d_y, &(* _d_fn), &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void FunctorAsArgWrapper_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: Experiment _d_E({});
// CHECK-NEXT: Experiment _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: Experiment E(3, 5);
// CHECK-NEXT: _t0 = E
// CHECK-NEXT: _t1 = i;
// CHECK-NEXT: _t2 = j;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: Experiment _grad0;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: double _grad2 = 0.;
// CHECK-NEXT: FunctorAsArg_pullback(_t0, _t1, _t2, 1, &_grad0, &_grad1, &_grad2);
// CHECK-NEXT: Experiment _r0(_grad0);
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_i += _r1;
// CHECK-NEXT: double _r2 = _grad2;
// CHECK-NEXT: * _d_j += _r2;
// CHECK-NEXT: }
// CHECK-NEXT: }

// testing differentiating a wrapper for function taking functor as an argument
auto FunctorAsArgWrapper_grad = clad::gradient(FunctorAsArgWrapper);
di = 0, dj = 0;
FunctorAsArgWrapper_grad.execute(7, 9, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00
}

0 comments on commit 72a3890

Please sign in to comment.