From df26aa3cd35d87dd4c7777d6b9c5625fddb2f523 Mon Sep 17 00:00:00 2001 From: Kenneth Lozes <39635884+KLozes@users.noreply.github.com> Date: Sun, 24 May 2020 14:06:00 -0700 Subject: [PATCH] [Lang] Support template arguments for ti.func (#1043) * support template arg in ti.func and slight refactor * add non-working test * [skip ci] enforce code format * working tests * [skip ci] enforce code format * [skip ci] new test * [skip ci] enforce code format * [skip ci] add comments in transformer Co-authored-by: Kenneth Lozes Co-authored-by: Taichi Gardener --- python/taichi/lang/kernel.py | 35 ++++++++++++++- python/taichi/lang/transformer.py | 27 +++++++----- tests/python/test_kernel_templates.py | 62 +++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 11 deletions(-) diff --git a/python/taichi/lang/kernel.py b/python/taichi/lang/kernel.py index fd72364409479..b3375670e63d5 100644 --- a/python/taichi/lang/kernel.py +++ b/python/taichi/lang/kernel.py @@ -45,6 +45,9 @@ def __init__(self, func, classfunc=False): self.func = func self.compiled = None self.classfunc = classfunc + self.arguments = [] + self.argument_names = [] + self.extract_arguments() def __call__(self, *args): if self.compiled is None: @@ -65,7 +68,7 @@ def do_compile(self): print('Before preprocessing:') print(astor.to_source(tree.body[0], indent_with=' ')) - visitor = ASTTransformer(is_kernel=False, is_classfunc=self.classfunc) + visitor = ASTTransformer(is_kernel=False, func=self) visitor.visit(tree) ast.fix_missing_locations(tree) @@ -87,6 +90,36 @@ def do_compile(self): mode='exec'), global_vars, local_vars) self.compiled = local_vars[self.func.__name__] + def extract_arguments(self): + sig = inspect.signature(self.func) + if sig.return_annotation not in (inspect._empty, None): + self.return_type = sig.return_annotation + params = sig.parameters + arg_names = params.keys() + for i, arg_name in enumerate(arg_names): + param = params[arg_name] + if param.kind == inspect.Parameter.VAR_KEYWORD: + raise KernelDefError( + 'Taichi funcs do not support variable keyword parameters (i.e., **kwargs)' + ) + if param.kind == inspect.Parameter.VAR_POSITIONAL: + raise KernelDefError( + 'Taichi funcs do not support variable positional parameters (i.e., *args)' + ) + if param.kind == inspect.Parameter.KEYWORD_ONLY: + raise KernelDefError( + 'Taichi funcs do not support keyword parameters') + if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: + raise KernelDefError( + 'Taichi funcs only support "positional or keyword" parameters' + ) + annotation = param.annotation + if param.annotation is inspect.Parameter.empty: + if i == 0 and self.classfunc: + annotation = template() + self.arguments.append(annotation) + self.argument_names.append(param.name) + def classfunc(foo): import warnings diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 7e51650e242f3..5cf54268d0b42 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -34,7 +34,6 @@ def __init__(self, self.local_scopes = [] self.excluded_parameters = excluded_paremeters self.is_kernel = is_kernel - self.is_classfunc = is_classfunc self.func = func self.arg_features = arg_features self.returns = None @@ -211,7 +210,7 @@ def visit_While(self, node): raise TaichiSyntaxError( "'else' clause for 'while' not supported in Taichi kernels") - template = ''' + template = ''' if 1: ti.core.begin_frontend_while(ti.Expr(1).ptr) __while_cond = 0 @@ -249,7 +248,7 @@ def visit_If(self, node): # Do nothing return node - template = ''' + template = ''' if 1: __cond = 0 ti.core.begin_frontend_if(ti.Expr(__cond).ptr) @@ -297,7 +296,7 @@ def visit_static_for(self, node, is_grouped): self.generic_visit(node, ['body']) if is_grouped: assert len(node.iter.args[0].args) == 1 - template = ''' + template = ''' if 1: __ndrange_arg = 0 from taichi.lang.exception import TaichiSyntaxError @@ -326,10 +325,10 @@ def visit_range_for(self, node): self.generic_visit(node, ['body']) loop_var = node.target.id self.check_loop_var(loop_var) - template = ''' + template = ''' if 1: {} = ti.Expr(ti.core.make_id_expr('')) - ___begin = ti.Expr(0) + ___begin = ti.Expr(0) ___end = ti.Expr(0) ___begin = ti.cast(___begin, ti.i32) ___end = ti.cast(___end, ti.i32) @@ -445,7 +444,7 @@ def visit_struct_for(self, node, is_grouped): t.body[0].value = node.iter t.body = t.body[:cut] + node.body + t.body[cut:] else: - template = ''' + template = ''' if 1: {} ___loop_var = 0 @@ -605,6 +604,8 @@ def visit_FunctionDef(self, node): node.returns = None for i, arg in enumerate(args.args): + # Directly pass in template arguments, + # such as class instances ("self"), tensors, SNodes, etc. if isinstance(self.func.arguments[i], ti.template): continue import taichi as ti @@ -642,8 +643,12 @@ def visit_FunctionDef(self, node): # Transform as func (all parameters passed by value) arg_decls = [] for i, arg in enumerate(args.args): - if i == 0 and self.is_classfunc: + # Directly pass in template arguments, + # such as class instances ("self"), tensors, SNodes, etc. + if isinstance(self.func.arguments[i], ti.template): continue + # Create a copy for non-template arguments, + # so that they are passed by value. arg_init = self.parse_stmt('x = ti.expr_init_func(0)') arg_init.targets[0].id = arg.arg self.create_variable(arg.arg) @@ -752,8 +757,10 @@ def visit_Return(self, node): # TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not if node.value is not None: if self.returns is None: - raise TaichiSyntaxError('kernel with return value must be ' - 'annotated with a return type, e.g. def func() -> ti.f32') + raise TaichiSyntaxError( + 'kernel with return value must be ' + 'annotated with a return type, e.g. def func() -> ti.f32' + ) ret_expr = self.parse_expr('ti.cast(ti.Expr(0), 0)') ret_expr.args[0].args[0] = node.value ret_expr.args[1] = self.returns diff --git a/tests/python/test_kernel_templates.py b/tests/python/test_kernel_templates.py index d6d046dbef84d..4dda891cdf5e3 100644 --- a/tests/python/test_kernel_templates.py +++ b/tests/python/test_kernel_templates.py @@ -68,3 +68,65 @@ def compute_loss(): for i in range(16): assert z[i] == i * 4 + 3 assert x.grad[i] == 4 + + +@ti.all_archs +def test_func_template(): + a = [ti.var(dt=ti.f32) for _ in range(2)] + b = [ti.var(dt=ti.f32) for _ in range(2)] + + for l in range(2): + ti.root.dense(ti.ij, 16).place(a[l], b[l]) + + @ti.func + def sample(x: ti.template(), l: ti.template(), I): + return x[l][I] + + @ti.kernel + def fill(l: ti.template()): + for I in ti.grouped(a[l]): + a[l][I] = l + + @ti.kernel + def aTob(l: ti.template()): + for I in ti.grouped(b[l]): + b[l][I] = sample(a, l, I) + + for l in range(2): + fill(l) + aTob(l) + + for l in range(2): + for i in range(16): + for j in range(16): + assert b[l][i, j] == l + + +@ti.all_archs +def test_func_template2(): + a = ti.var(dt=ti.f32) + b = ti.var(dt=ti.f32) + + ti.root.dense(ti.ij, 16).place(a, b) + + @ti.func + def sample(x: ti.template(), I): + return x[I] + + @ti.kernel + def fill(): + for I in ti.grouped(a): + a[I] = 1.0 + + @ti.kernel + def aTob(): + for I in ti.grouped(b): + b[I] = sample(a, I) + + for l in range(2): + fill() + aTob() + + for i in range(16): + for j in range(16): + assert b[i, j] == 1.0