[Autodiff] Implement ti.ad.no_grad to skip autograd (#4751)
* [Autodiff]: implement ti.ad.no_grad to skip autograd

* [Autodiff]: update ti.ad.no_grad tests and remove comment
shwnyao authored Apr 13, 2022
1 parent 7319eb4 commit dc3e96c
Showing 2 changed files with 92 additions and 0 deletions.
38 changes: 38 additions & 0 deletions python/taichi/ad.py
@@ -79,3 +79,41 @@ def decorated(*args, **kwargs):
        return decorated

    return decorator


def no_grad(func):
    """A decorator for a Python function to skip gradient calculation within
    Taichi's autodiff system, e.g. `ti.Tape()` and `kernel.grad()`.
    This decorator forces Taichi's autodiff system to use an empty gradient
    function for the decorated function.

    Args:
        func (Callable): The Python function to be decorated.

    Returns:
        Callable: The decorated function.

    Example::

        >>> @ti.kernel
        >>> def multiply(a: ti.float32):
        >>>     for I in ti.grouped(x):
        >>>         y[I] = x[I] * a
        >>>
        >>> @ti.ad.no_grad
        >>> def foo(a):
        >>>     multiply(a)
    """

    def decorated(*args, **kwargs):
        # Tell the autodiff system that the gradient of this call is
        # replaced, so no gradient kernels are generated for it.
        impl.get_runtime().grad_replaced = True
        if impl.get_runtime().target_tape:
            impl.get_runtime().target_tape.insert(decorated, args)
        try:
            func(*args, **kwargs)
        finally:
            impl.get_runtime().grad_replaced = False

    def placeholder(*args, **kwargs):
        # Empty gradient function: the backward pass over the decorated
        # function is a no-op.
        return

    decorated.grad = placeholder
    return decorated
54 changes: 54 additions & 0 deletions tests/python/test_customized_grad.py
@@ -225,3 +225,57 @@ def backward(mul):

    with ti.Tape(loss=total):
        func(4)


@test_utils.test()
def test_customized_kernels_tape_no_grad():
    x = ti.field(ti.f32)
    total = ti.field(ti.f32)

    n = 128

    ti.root.dense(ti.i, n).place(x)
    ti.root.place(total)
    ti.root.lazy_grad()

    @ti.kernel
    def func(mul: ti.f32):
        for i in range(n):
            ti.atomic_add(total[None], x[i] * mul)

    @ti.ad.no_grad
    def forward(mul):
        func(mul)
        func(mul)

    with ti.Tape(loss=total):
        # Both calls inside `forward` are skipped by autodiff; only the
        # undecorated `func(5)` contributes to the gradient.
        forward(4)
        func(5)
    assert x.grad[0] == 5


@test_utils.test()
def test_customized_kernels_grad_no_grad():
    x = ti.field(ti.f32)
    total = ti.field(ti.f32)

    n = 128

    ti.root.dense(ti.i, n).place(x)
    ti.root.place(total)
    ti.root.lazy_grad()

    @ti.kernel
    def func(mul: ti.f32):
        for i in range(n):
            ti.atomic_add(total[None], x[i] * mul)

    @ti.ad.no_grad
    def forward(mul):
        func(mul)
        func(mul)

    total.grad[None] = 1
    forward(4)
    # `forward.grad` is the empty placeholder installed by `no_grad`, so the
    # backward pass is a no-op and the gradient stays zero.
    forward.grad(4)
    assert x.grad[0] == 0
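
Design note: judging from the diff, `no_grad` reuses the same machinery as the customized-gradient decorators earlier in `python/taichi/ad.py` (the `grad_replaced` flag and the `target_tape.insert` call), but instead of pairing the forward function with a user-supplied gradient it installs the empty `placeholder`, so both the tape's backward pass and a direct `forward.grad()` call become no-ops.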
