Skip to content

Commit

Permalink
Fix derivative code for comma expressions in forward mode
Browse files Browse the repository at this point in the history
Earlier for an expression like `(E1, E2)`, the derivative code produced `(dE1, dE2)`,
thus a statement like `double temp = (++x, x*x)` produced:
```cpp
double _d_temp = (_d_x, _d_x * x + x * _d_x);
double temp = (++x, x*x);    <--- original cloned statement
```
This meant that the execution order is `dE1 -> dE2 -> E1 -> E2` (where -> means followed by).

But this doesn't seem right, as `dE2` can depend on `E1`; hence `E1` must be computed before `dE2`.
So, the execution order should be `dE1 -> E1 -> dE2 -> E2`.

This PR generates the derivative code as `(dE1, E1, dE2)`; the original statement is changed to just `E2`.
Note that the result of both expressions is still the same as earlier (i.e. the derivative code will still have the result of `dE2`)

So, the example statement of computing `(x+1)^2`, will produce:
```cpp
double _d_temp = (_d_x, ++x , (_d_x * x + x * _d_x));
double temp = x*x;
```

closes #573
  • Loading branch information
vaithak authored and vgvassilev committed Jun 24, 2023
1 parent 702c90c commit 8dbe639
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
21 changes: 15 additions & 6 deletions lib/Differentiator/ForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1704,12 +1704,16 @@ namespace clad {
deriveDiv(Ldiff, Rdiff));
}
} else if (opCode == BO_Comma) {
if (!isUnusedResult(Ldiff.getExpr_dx()))
opDiff = BuildOp(BO_Comma,
BuildParens(Ldiff.getExpr_dx()),
// if expression is (E1, E2) then derivative is (E1', E1, E2')
// because E1 may change some variables that E2 depends on.
if (!isUnusedResult(Ldiff.getExpr_dx())) {
opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr_dx()),
BuildParens(Ldiff.getExpr()));
opDiff = BuildOp(BO_Comma, BuildParens(opDiff),
BuildParens(Rdiff.getExpr_dx()));
} else
opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()),
BuildParens(Rdiff.getExpr_dx()));
else
opDiff = Rdiff.getExpr_dx();
}
if (!opDiff) {
// FIXME: add support for other binary operators
Expand All @@ -1719,7 +1723,12 @@ namespace clad {
opDiff = folder.fold(opDiff);
// Recover the original operation from the Ldiff and Rdiff instead of
// cloning the tree.
Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr());
Expr* op;
if (opCode == BO_Comma)
// Ldiff.getExpr() is already included in opDiff.
op = Rdiff.getExpr();
else
op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr());
return StmtDiff(op, opDiff);
}

Expand Down
57 changes: 57 additions & 0 deletions test/FirstDerivative/BasicArithmeticMulDiv.C
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,46 @@ float m_6(int x) {
// CHECK-NEXT: return 0.F * x + 3.F * _d_x;
// CHECK-NEXT: }

double m_7(double x) {
// returns (x+1)^2
return (++x, x * x);
}
// CHECK: double m_7_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: return (++x , (_d_x * x + x * _d_x));
// CHECK-NEXT: }

double m_8(double x) {
// returns (x+1)^2
double temp = (++x, x * x);
return temp;
}
// CHECK: double m_8_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_temp = (++x , (_d_x * x + x * _d_x));
// CHECK-NEXT: double temp = (x * x);
// CHECK-NEXT: return _d_temp;
// CHECK-NEXT: }

double m_9(double x) {
// returns (2x)^2
return (x*=2, x * x);
}
// CHECK: double m_9_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: return (((_d_x = _d_x * 2 + x * 0) , (x *= 2)) , (_d_x * x + x * _d_x));
// CHECK-NEXT: }

double m_10(double x, bool flag) {
// if flag is true, return 4x^2, else return (x+1)^2
return flag ? (x*=2, x * x) : (x+=1, x * x);
}
// CHECK: double m_10_darg0(double x, bool flag) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: bool _d_flag = 0;
// CHECK-NEXT: return flag ? (((_d_x = _d_x * 2 + x * 0) , (x *= 2)) , (_d_x * x + x * _d_x)) : (((_d_x += 0) , (x += 1)) , (_d_x * x + x * _d_x));
// CHECK-NEXT: }

int d_1(int x) {
int y = 4;
return y / y; // == 0
Expand Down Expand Up @@ -146,6 +186,10 @@ int m_3_darg0(int x);
int m_4_darg0(int x);
double m_5_darg0(int x);
float m_6_darg0(int x);
double m_7_darg0(double x);
double m_8_darg0(double x);
double m_9_darg0(double x);
double m_10_darg0(double x, bool flag);
int d_1_darg0(int x);
int d_2_darg0(int x);
int d_3_darg0(int x);
Expand Down Expand Up @@ -173,6 +217,19 @@ int main () {
clad::differentiate(m_6, 0);
printf("Result is = %f\n", m_6_darg0(1)); // CHECK-EXEC: Result is = 3

clad::differentiate(m_7, 0);
printf("Result is = %f\n", m_7_darg0(1)); // CHECK-EXEC: Result is = 4

clad::differentiate(m_8, 0);
printf("Result is = %f\n", m_8_darg0(1)); // CHECK-EXEC: Result is = 4

clad::differentiate(m_9, 0);
printf("Result is = %f\n", m_9_darg0(1)); // CHECK-EXEC: Result is = 8

clad::differentiate(m_10, 0);
printf("Result is = %f\n", m_10_darg0(1, true)); // CHECK-EXEC: Result is = 8
printf("Result is = %f\n", m_10_darg0(1, false)); // CHECK-EXEC: Result is = 4

clad::differentiate(d_1, 0);
printf("Result is = %d\n", d_1_darg0(1)); // CHECK-EXEC: Result is = 0

Expand Down

0 comments on commit 8dbe639

Please sign in to comment.