Skip to content

Commit

Permalink
throw UnboundLocalError when accessing loop variables outside static …
Browse files Browse the repository at this point in the history
…for loops (#613)
  • Loading branch information
yuanming-hu authored Mar 18, 2020
1 parent 97e1b50 commit a03f70d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def static(x, *xs):
import taichi as ti
assert get_runtime(
).inside_kernel, 'ti.static can only be used inside Taichi kernels'
if isinstance(x, (bool, int, float, range, list, tuple, ti.ndrange, ti.GroupedNDRange)):
if isinstance(x, (bool, int, float, range, list, tuple, enumerate, ti.ndrange, ti.GroupedNDRange)):
return x
elif isinstance(x, ti.lang.expr.Expr) and x.ptr.is_global_var():
return x
Expand Down
11 changes: 10 additions & 1 deletion python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from .util import to_taichi_type
import copy


class TaichiSyntaxError(Exception):
Expand Down Expand Up @@ -328,7 +329,15 @@ def visit_For(self, node):
node = ast.copy_location(t, node)
return self.visit(node) # further translate as a range for
elif is_static_for:
return node
t = self.parse_stmt('if 1: pass; del a')
t.body[0] = node
target = copy.deepcopy(node.target)
target.ctx = ast.Del()
if isinstance(target, ast.Tuple):
for tar in target.elts:
tar.ctx = ast.Del()
t.body[1].targets = [target]
return t
elif is_range_for:
loop_var = node.target.id
self.check_loop_var(loop_var)
Expand Down
22 changes: 22 additions & 0 deletions tests/python/test_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,25 @@ def test():

for i in range(n):
assert val[i] == i + 45

@ti.must_throw(UnboundLocalError)
@ti.host_arch_only
def test_loop_var_life():
@ti.kernel
def test():
for i in ti.static(range(8)):
pass
print(i)

test()

@ti.must_throw(UnboundLocalError)
@ti.host_arch_only
def test_loop_var_life_double_iters():
@ti.kernel
def test():
for i, v in ti.static(enumerate(range(8))):
pass
print(i)

test()

0 comments on commit a03f70d

Please sign in to comment.