diff --git a/tests/python/test_ad_atomic_fwd.py b/tests/python/test_ad_atomic_fwd.py index d8952d2f8dde2..f070820ac9f1f 100644 --- a/tests/python/test_ad_atomic_fwd.py +++ b/tests/python/test_ad_atomic_fwd.py @@ -7,13 +7,13 @@ def test_ad_reduce_fwd(): N = 16 x = ti.field(dtype=ti.f32, shape=N) - loss = ti.field(dtype=ti.f32, shape=N) + loss = ti.field(dtype=ti.f32, shape=()) ti.root.lazy_dual() @ti.kernel def func(): for i in x: - loss[i] += x[i]**2 + loss[None] += x[i]**2 total_loss = 0 for i in range(N): @@ -23,10 +23,9 @@ def func(): with ti.ad.FwdMode(loss=loss, parameters=x, seed=[1.0 for _ in range(N)]): func() - total_loss_computed = 0 + assert total_loss == test_utils.approx(loss[None]) + sum = 0 for i in range(N): - total_loss_computed += loss[i] + sum += i * 2 - assert total_loss == test_utils.approx(total_loss_computed) - for i in range(N): - assert loss.dual[i] == test_utils.approx(i * 2) + assert loss.dual[None] == test_utils.approx(sum)