[autodiff] Support binary operators for forward mode
erizmr committed Jul 11, 2022
1 parent 0dfd166 commit 9ff2b69
Showing 2 changed files with 56 additions and 5 deletions.
20 changes: 20 additions & 0 deletions taichi/transforms/auto_diff.cpp
@@ -1172,6 +1172,26 @@ class MakeDual : public ADTransform {
       accumulate(bin, div(dual(bin->lhs), bin->rhs));
       accumulate(bin, negate(div(mul(dual(bin->rhs), bin->lhs),
                                  mul(bin->rhs, bin->rhs))));
+    } else if (bin->op_type == BinaryOpType::atan2) {
+      auto numerator = add(sqr(bin->lhs), sqr(bin->rhs));
+      accumulate(bin, div(mul(bin->rhs, dual(bin->lhs)), numerator));
+      accumulate(bin, negate(div(mul(bin->lhs, dual(bin->rhs)), numerator)));
+    } else if (bin->op_type == BinaryOpType::pow) {
+      // d (x ^ y) = x ^ (y-1) * (y * dx + log(x) * x * dy)
+      auto common_coeff =
+          pow(bin->lhs, sub(bin->rhs, constant(1)));  // x ^ (y-1)
+      accumulate(bin, mul(dual(bin->lhs), mul(bin->rhs, common_coeff)));
+      accumulate(bin, mul(dual(bin->rhs),
+                          mul(log(bin->lhs), mul(bin->lhs, common_coeff))));
+    } else if (bin->op_type == BinaryOpType::min ||
+               bin->op_type == BinaryOpType::max) {
+      auto cmp = bin->op_type == BinaryOpType::min ? cmp_lt(bin->lhs, bin->rhs)
+                                                   : cmp_lt(bin->rhs, bin->lhs);
+      auto zero = insert<ConstStmt>(TypedConstant(bin->ret_type));
+      accumulate(bin, sel(cmp, dual(bin->lhs), zero));
+      accumulate(bin, sel(cmp, zero, dual(bin->rhs)));
+    } else if (bin->op_type == BinaryOpType::floordiv) {
+      // do nothing
     } else if (is_comparison(bin->op_type) || is_bit_op(bin->op_type)) {
       // do nothing
     } else {
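Written out with dx and dy standing for dual(bin->lhs) and dual(bin->rhs), the forward-mode rules added above are:

    d atan2(x, y) = (y*dx - x*dy) / (x^2 + y^2)
    d (x^y)       = x^(y-1) * (y*dx + x*log(x)*dy)
    d min(x, y)   = dx if x < y, else dy
    d max(x, y)   = dx if y < x, else dy
    floordiv, comparisons and bit ops propagate no dual (their outputs are piecewise constant).

(In the atan2 branch, the variable named numerator holds the shared x^2 + y^2 term, which is used as the denominator of both quotients.) Applied to the two pow cases parametrized in tests/python/test_ad_basics.py below, the rule reduces to the expected analytic derivatives:

    0.4 ** x  (base fixed):      dual = 0.4^x * log(0.4) * dx
    y ** 0.4  (exponent fixed):  dual = 0.4 * y^(-0.6) * dy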
41 changes: 36 additions & 5 deletions tests/python/test_ad_basics.py
@@ -159,9 +159,10 @@ def test_unary(tifunc, npfunc):
     (lambda x: ti.max(1, x), lambda x: np.maximum(1, x)),
 ])
 @if_has_autograd
-@test_utils.test()
+@test_utils.test(exclude=[ti.cc])
 def test_minmax(tifunc, npfunc):
     grad_test(tifunc, npfunc)
+    grad_test_fwd(tifunc, npfunc)
 
 
 @if_has_autograd
@@ -188,44 +189,74 @@ def func2():
     func2.grad()
 
 
+@if_has_autograd
+@test_utils.test()
+def test_mod_fwd():
+    x = ti.field(ti.f32)
+    y = ti.field(ti.f32)
+
+    ti.root.dense(ti.i, 1).place(x, y)
+    ti.root.lazy_dual()
+
+    @ti.kernel
+    def func():
+        y[0] = x[0] % 3
+
+    @ti.kernel
+    def func2():
+        ti.atomic_add(y[0], x[0] % 3)
+
+    with ti.ad.FwdMode(loss=y, parameters=x, seed=[1.0]):
+        func()
+        func2()
+
+
 @pytest.mark.parametrize('tifunc,npfunc', [
     (lambda x: ti.atan2(0.4, x), lambda x: np.arctan2(0.4, x)),
     (lambda y: ti.atan2(y, 0.4), lambda y: np.arctan2(y, 0.4)),
 ])
 @if_has_autograd
-@test_utils.test()
+@test_utils.test(exclude=[ti.cc])
 def test_atan2(tifunc, npfunc):
     grad_test(tifunc, npfunc)
+    grad_test_fwd(tifunc, npfunc)
 
 
 @pytest.mark.parametrize('tifunc,npfunc', [
     (lambda x: ti.atan2(0.4, x), lambda x: np.arctan2(0.4, x)),
     (lambda y: ti.atan2(y, 0.4), lambda y: np.arctan2(y, 0.4)),
 ])
 @if_has_autograd
-@test_utils.test(require=ti.extension.data64, default_fp=ti.f64)
+@test_utils.test(require=ti.extension.data64,
+                 default_fp=ti.f64,
+                 exclude=[ti.cc])
 def test_atan2_f64(tifunc, npfunc):
     grad_test(tifunc, npfunc)
+    grad_test_fwd(tifunc, npfunc)
 
 
 @pytest.mark.parametrize('tifunc,npfunc', [
     (lambda x: 0.4**x, lambda x: np.power(0.4, x)),
     (lambda y: y**0.4, lambda y: np.power(y, 0.4)),
 ])
 @if_has_autograd
-@test_utils.test()
+@test_utils.test(exclude=[ti.cc])
 def test_pow(tifunc, npfunc):
     grad_test(tifunc, npfunc)
+    grad_test_fwd(tifunc, npfunc)
 
 
 @pytest.mark.parametrize('tifunc,npfunc', [
     (lambda x: 0.4**x, lambda x: np.power(0.4, x)),
     (lambda y: y**0.4, lambda y: np.power(y, 0.4)),
 ])
 @if_has_autograd
-@test_utils.test(require=ti.extension.data64, default_fp=ti.f64)
+@test_utils.test(require=ti.extension.data64,
+                 default_fp=ti.f64,
+                 exclude=[ti.cc])
 def test_pow_f64(tifunc, npfunc):
     grad_test(tifunc, npfunc)
+    grad_test_fwd(tifunc, npfunc)
 
 
 @test_utils.test()
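grad_test_fwd, called by the updated tests above, is a helper defined earlier in test_ad_basics.py and is not part of this diff. A minimal standalone forward-mode check in the same spirit might look like the sketch below; reading the propagated dual through y.dual[0] (by analogy with y.grad[0] in reverse mode) is an assumption here, not something shown in the commit.

import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(ti.f32)
y = ti.field(ti.f32)
ti.root.dense(ti.i, 1).place(x, y)
ti.root.lazy_dual()  # allocate dual fields, as in test_mod_fwd above


@ti.kernel
def eval_atan2():
    y[0] = ti.atan2(x[0], 0.4)


x[0] = 0.7
# Seeding dx = 1.0 makes the dual of y equal to d/dx atan2(x, 0.4) at x = 0.7.
with ti.ad.FwdMode(loss=y, parameters=x, seed=[1.0]):
    eval_atan2()

# Analytic value: d/dx atan2(x, c) = c / (x^2 + c^2).
expected = 0.4 / (0.7**2 + 0.4**2)
assert abs(y.dual[0] - expected) < 1e-4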
