Skip to content

Commit

Permalink
[Lang] Support template arguments for ti.func (#1043)
Browse files Browse the repository at this point in the history
* support template arg in ti.func and slight refactor

* add non-working test

* [skip ci] enforce code format

* working tests

* [skip ci] enforce code format

* [skip ci] new test

* [skip ci] enforce code format

* [skip ci] add comments in transformer

Co-authored-by: Kenneth Lozes <klozes@system76-pc.localdomain>
Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
3 people authored May 24, 2020
1 parent 2707648 commit df26aa3
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 11 deletions.
35 changes: 34 additions & 1 deletion python/taichi/lang/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __init__(self, func, classfunc=False):
self.func = func
self.compiled = None
self.classfunc = classfunc
self.arguments = []
self.argument_names = []
self.extract_arguments()

def __call__(self, *args):
if self.compiled is None:
Expand All @@ -65,7 +68,7 @@ def do_compile(self):
print('Before preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))

visitor = ASTTransformer(is_kernel=False, is_classfunc=self.classfunc)
visitor = ASTTransformer(is_kernel=False, func=self)
visitor.visit(tree)
ast.fix_missing_locations(tree)

Expand All @@ -87,6 +90,36 @@ def do_compile(self):
mode='exec'), global_vars, local_vars)
self.compiled = local_vars[self.func.__name__]

def extract_arguments(self):
sig = inspect.signature(self.func)
if sig.return_annotation not in (inspect._empty, None):
self.return_type = sig.return_annotation
params = sig.parameters
arg_names = params.keys()
for i, arg_name in enumerate(arg_names):
param = params[arg_name]
if param.kind == inspect.Parameter.VAR_KEYWORD:
raise KernelDefError(
'Taichi funcs do not support variable keyword parameters (i.e., **kwargs)'
)
if param.kind == inspect.Parameter.VAR_POSITIONAL:
raise KernelDefError(
'Taichi funcs do not support variable positional parameters (i.e., *args)'
)
if param.kind == inspect.Parameter.KEYWORD_ONLY:
raise KernelDefError(
'Taichi funcs do not support keyword parameters')
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
raise KernelDefError(
'Taichi funcs only support "positional or keyword" parameters'
)
annotation = param.annotation
if param.annotation is inspect.Parameter.empty:
if i == 0 and self.classfunc:
annotation = template()
self.arguments.append(annotation)
self.argument_names.append(param.name)


def classfunc(foo):
import warnings
Expand Down
27 changes: 17 additions & 10 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(self,
self.local_scopes = []
self.excluded_parameters = excluded_paremeters
self.is_kernel = is_kernel
self.is_classfunc = is_classfunc
self.func = func
self.arg_features = arg_features
self.returns = None
Expand Down Expand Up @@ -211,7 +210,7 @@ def visit_While(self, node):
raise TaichiSyntaxError(
"'else' clause for 'while' not supported in Taichi kernels")

template = '''
template = '''
if 1:
ti.core.begin_frontend_while(ti.Expr(1).ptr)
__while_cond = 0
Expand Down Expand Up @@ -249,7 +248,7 @@ def visit_If(self, node):
# Do nothing
return node

template = '''
template = '''
if 1:
__cond = 0
ti.core.begin_frontend_if(ti.Expr(__cond).ptr)
Expand Down Expand Up @@ -297,7 +296,7 @@ def visit_static_for(self, node, is_grouped):
self.generic_visit(node, ['body'])
if is_grouped:
assert len(node.iter.args[0].args) == 1
template = '''
template = '''
if 1:
__ndrange_arg = 0
from taichi.lang.exception import TaichiSyntaxError
Expand Down Expand Up @@ -326,10 +325,10 @@ def visit_range_for(self, node):
self.generic_visit(node, ['body'])
loop_var = node.target.id
self.check_loop_var(loop_var)
template = '''
template = '''
if 1:
{} = ti.Expr(ti.core.make_id_expr(''))
___begin = ti.Expr(0)
___begin = ti.Expr(0)
___end = ti.Expr(0)
___begin = ti.cast(___begin, ti.i32)
___end = ti.cast(___end, ti.i32)
Expand Down Expand Up @@ -445,7 +444,7 @@ def visit_struct_for(self, node, is_grouped):
t.body[0].value = node.iter
t.body = t.body[:cut] + node.body + t.body[cut:]
else:
template = '''
template = '''
if 1:
{}
___loop_var = 0
Expand Down Expand Up @@ -605,6 +604,8 @@ def visit_FunctionDef(self, node):
node.returns = None

for i, arg in enumerate(args.args):
# Directly pass in template arguments,
# such as class instances ("self"), tensors, SNodes, etc.
if isinstance(self.func.arguments[i], ti.template):
continue
import taichi as ti
Expand Down Expand Up @@ -642,8 +643,12 @@ def visit_FunctionDef(self, node):
# Transform as func (all parameters passed by value)
arg_decls = []
for i, arg in enumerate(args.args):
if i == 0 and self.is_classfunc:
# Directly pass in template arguments,
# such as class instances ("self"), tensors, SNodes, etc.
if isinstance(self.func.arguments[i], ti.template):
continue
# Create a copy for non-template arguments,
# so that they are passed by value.
arg_init = self.parse_stmt('x = ti.expr_init_func(0)')
arg_init.targets[0].id = arg.arg
self.create_variable(arg.arg)
Expand Down Expand Up @@ -752,8 +757,10 @@ def visit_Return(self, node):
# TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not
if node.value is not None:
if self.returns is None:
raise TaichiSyntaxError('kernel with return value must be '
'annotated with a return type, e.g. def func() -> ti.f32')
raise TaichiSyntaxError(
'kernel with return value must be '
'annotated with a return type, e.g. def func() -> ti.f32'
)
ret_expr = self.parse_expr('ti.cast(ti.Expr(0), 0)')
ret_expr.args[0].args[0] = node.value
ret_expr.args[1] = self.returns
Expand Down
62 changes: 62 additions & 0 deletions tests/python/test_kernel_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,65 @@ def compute_loss():
for i in range(16):
assert z[i] == i * 4 + 3
assert x.grad[i] == 4


@ti.all_archs
def test_func_template():
a = [ti.var(dt=ti.f32) for _ in range(2)]
b = [ti.var(dt=ti.f32) for _ in range(2)]

for l in range(2):
ti.root.dense(ti.ij, 16).place(a[l], b[l])

@ti.func
def sample(x: ti.template(), l: ti.template(), I):
return x[l][I]

@ti.kernel
def fill(l: ti.template()):
for I in ti.grouped(a[l]):
a[l][I] = l

@ti.kernel
def aTob(l: ti.template()):
for I in ti.grouped(b[l]):
b[l][I] = sample(a, l, I)

for l in range(2):
fill(l)
aTob(l)

for l in range(2):
for i in range(16):
for j in range(16):
assert b[l][i, j] == l


@ti.all_archs
def test_func_template2():
a = ti.var(dt=ti.f32)
b = ti.var(dt=ti.f32)

ti.root.dense(ti.ij, 16).place(a, b)

@ti.func
def sample(x: ti.template(), I):
return x[I]

@ti.kernel
def fill():
for I in ti.grouped(a):
a[I] = 1.0

@ti.kernel
def aTob():
for I in ti.grouped(b):
b[I] = sample(a, I)

for l in range(2):
fill()
aTob()

for i in range(16):
for j in range(16):
assert b[i, j] == 1.0

0 comments on commit df26aa3

Please sign in to comment.