Skip to content

Commit

Permalink
Make cast of enum assignment static C++ cast and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 11, 2024
1 parent 1c9b121 commit fd614ef
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 4 deletions.
11 changes: 7 additions & 4 deletions lib/Differentiator/ConstantFolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,13 @@ namespace clad {
QT->isSignedIntegerOrEnumerationType());
Result = clad::synthesizeLiteral(
dyn_cast<EnumType>(QT)->getDecl()->getIntegerType(), C, APVal);
Expr* cast = ImplicitCastExpr::Create(
C, QT, clang::CastKind::CK_IntegralCast, Result, nullptr,
CLAD_COMPAT_ExprValueKind_R_or_PR_Value
CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO);
SourceLocation noLoc;
Expr* cast = CXXStaticCastExpr::Create(
C, QT, CLAD_COMPAT_ExprValueKind_R_or_PR_Value,
clang::CastKind::CK_IntegralCast, Result, nullptr,
C.getTrivialTypeSourceInfo(QT, noLoc)
CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO,
noLoc, noLoc, SourceRange());
Result = cast;
} else if (QT->isPointerType()) {
Result = clad::synthesizeLiteral(QT, C);
Expand Down
153 changes: 153 additions & 0 deletions test/Gradient/Switch.C
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,146 @@ double fn7(double u, double v) {
// CHECK-NEXT: }
// CHECK-NEXT: }

enum Op {
Add,
Sub,
Mul,
Div
};

double fn24(double x, double y, Op op) {
double res = 0;
switch (op) {
case Add:
res = x + y;
break;
case Sub:
res = x - y;
break;
case Mul:
res = x * y;
break;
case Div:
res = x / y;
break;
}
return res;
}

// CHECK: void fn24_grad_0_1(double x, double y, Op op, double *_d_x, double *_d_y) {
// CHECK-NEXT: Op _d_op = static_cast<Op>(0U);
// CHECK-NEXT: Op _cond0;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: clad::tape<unsigned long> _t1 = {};
// CHECK-NEXT: double _t2;
// CHECK-NEXT: double _t3;
// CHECK-NEXT: double _t4;
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: _cond0 = op;
// CHECK-NEXT: switch (_cond0) {
// CHECK-NEXT: {
// CHECK-NEXT: case Add:
// CHECK-NEXT: res = x + y;
// CHECK-NEXT: _t0 = res;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, 1UL);
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: case Sub:
// CHECK-NEXT: res = x - y;
// CHECK-NEXT: _t2 = res;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, 2UL);
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: case Mul:
// CHECK-NEXT: res = x * y;
// CHECK-NEXT: _t3 = res;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, 3UL);
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: case Div:
// CHECK-NEXT: res = x / y;
// CHECK-NEXT: _t4 = res;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, 4UL);
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t1, 5UL);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: {
// CHECK-NEXT: switch (clad::pop(_t1)) {
// CHECK-NEXT: case 5UL:
// CHECK-NEXT: ;
// CHECK-NEXT: case 4UL:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t4;
// CHECK-NEXT: double _r_d3 = _d_res;
// CHECK-NEXT: _d_res = 0.;
// CHECK-NEXT: *_d_x += _r_d3 / y;
// CHECK-NEXT: double _r0 = _r_d3 * -(x / (y * y));
// CHECK-NEXT: _d_y += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: if (Div == _cond0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: case 3UL:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t3;
// CHECK-NEXT: double _r_d2 = _d_res;
// CHECK-NEXT: _d_res = 0.;
// CHECK-NEXT: *_d_x += _r_d2 * y;
// CHECK-NEXT: _d_y += x * _r_d2;
// CHECK-NEXT: }
// CHECK-NEXT: if (Mul == _cond0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: case 2UL:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t2;
// CHECK-NEXT: double _r_d1 = _d_res;
// CHECK-NEXT: _d_res = 0.;
// CHECK-NEXT: *_d_x += _r_d1;
// CHECK-NEXT: _d_y += -_r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: if (Sub == _cond0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: case 1UL:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t0;
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: _d_res = 0.;
// CHECK-NEXT: *_d_x += _r_d0;
// CHECK-NEXT: _d_y += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: if (Add == _cond0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT:}


#define TEST_2(F, x, y) \
{ \
Expand All @@ -691,6 +831,14 @@ double fn7(double u, double v) {
printf("{%.2f, %.2f}\n", result[0], result[1]); \
}

#define TEST_2_Op(F, x, y, op) \
{ \
result[0] = result[1] = 0; \
auto d_##F = clad::gradient(F, "x, y"); \
d_##F.execute(x, y, op, result, result + 1); \
printf("{%.2f, %.2f}\n", result[0], result[1]); \
}

int main() {
double result[2] = {};

Expand All @@ -705,4 +853,9 @@ int main() {

TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00}
TEST_GRADIENT(fn7, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {3.00, 2.00}

TEST_2_Op(fn24, 3, 5, Add); // CHECK-EXEC: {1.00, 1.00}
TEST_2_Op(fn24, 3, 5, Sub); // CHECK-EXEC: {1.00, -1.00}
TEST_2_Op(fn24, 3, 5, Mul); // CHECK-EXEC: {5.00, 3.00}
TEST_2_Op(fn24, 3, 5, Div); // CHECK-EXEC: {0.20, -0.12}
}

0 comments on commit fd614ef

Please sign in to comment.