Skip to content

Commit

Permalink
update the test
Browse files Browse the repository at this point in the history
  • Loading branch information
erizmr committed Jun 29, 2022
1 parent 70efd40 commit 3642f9b
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions tests/python/test_ad_atomic_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ def test_ad_reduce_fwd():
N = 16

x = ti.field(dtype=ti.f32, shape=N)
loss = ti.field(dtype=ti.f32, shape=N)
loss = ti.field(dtype=ti.f32, shape=())
ti.root.lazy_dual()

@ti.kernel
def func():
for i in x:
loss[i] += x[i]**2
loss[None] += x[i]**2

total_loss = 0
for i in range(N):
Expand All @@ -23,10 +23,9 @@ def func():
with ti.ad.FwdMode(loss=loss, parameters=x, seed=[1.0 for _ in range(N)]):
func()

total_loss_computed = 0
assert total_loss == test_utils.approx(loss[None])
sum = 0
for i in range(N):
total_loss_computed += loss[i]
sum += i * 2

assert total_loss == test_utils.approx(total_loss_computed)
for i in range(N):
assert loss.dual[i] == test_utils.approx(i * 2)
assert loss.dual[None] == test_utils.approx(sum)

0 comments on commit 3642f9b

Please sign in to comment.