Skip to content

Commit

Permalink
[autodiff] Store if condition in adstack (#6207)
Browse files Browse the repository at this point in the history
Issue: #6204 

### Brief Summary
Autodiff failed on cases where the condition of a if depends on the for
loop index. This PR makes the if condition stored in adstack to handle
these cases.
  • Loading branch information
erizmr authored Sep 30, 2022
1 parent a55ea42 commit ea4e258
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
16 changes: 16 additions & 0 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,22 @@ class AdStackAllocaJudger : public BasicStmtVisitor {
}
}

// Check whether the target serves as the condition of a if stmt
void visit(IfStmt *stmt) override {
if (is_stack_needed_)
return;

if (stmt->cond == target_alloca_) {
is_stack_needed_ = true;
return;
}

if (stmt->true_statements)
stmt->true_statements->accept(this);
if (stmt->false_statements)
stmt->false_statements->accept(this);
}

static bool run(AllocaStmt *target_alloca) {
AdStackAllocaJudger judger;
judger.target_alloca_ = target_alloca;
Expand Down
33 changes: 33 additions & 0 deletions tests/python/test_ad_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,36 @@ def func():
impl.call_internal("test_stack")

func()


# FIXME: There is no tensor constant (brought by dynamic index) until the whole mat/vec refactor is done, which will potentially break the adstack.
# Temporially disable the dynamic index, will make workaround to handle tensor constant in other PRs
@test_utils.test(dynamic_index=False)
def test_if_condition_depend_on_for_loop_index():
scalar = lambda: ti.field(dtype=ti.f32)
vec = lambda: ti.Vector.field(3, dtype=ti.f32)

pos = vec()
F = vec()
f_bend = scalar()
loss_n = scalar()
ti.root.dense(ti.ij, (10, 10)).place(pos, F)
ti.root.dense(ti.i, 1).place(f_bend)
ti.root.place(loss_n)
ti.root.lazy_grad()

@ti.kernel
def simulation(t: ti.i32):
for i, j in pos:
coord = ti.Vector([i, j])
for n in range(12):
f = ti.Vector([0.0, 0.0, 0.0])
if n < 4:
f = ti.Vector([1.0, 1.0, 1.0])
else:
f = f_bend[0] * pos[coord]
F[coord] += f
pos[coord] += 1.0 * t

with ti.ad.Tape(loss=loss_n):
simulation(5)

0 comments on commit ea4e258

Please sign in to comment.