Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Support template arguments for ti.func #1043

Merged
merged 9 commits into from
May 24, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion python/taichi/lang/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __init__(self, func, classfunc=False):
self.func = func
self.compiled = None
self.classfunc = classfunc
self.arguments = []
self.argument_names = []
self.extract_arguments()

def __call__(self, *args):
if self.compiled is None:
Expand All @@ -65,7 +68,7 @@ def do_compile(self):
print('Before preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))

visitor = ASTTransformer(is_kernel=False, is_classfunc=self.classfunc)
visitor = ASTTransformer(is_kernel=False, func=self)
visitor.visit(tree)
ast.fix_missing_locations(tree)

Expand All @@ -87,6 +90,36 @@ def do_compile(self):
mode='exec'), global_vars, local_vars)
self.compiled = local_vars[self.func.__name__]

def extract_arguments(self):
sig = inspect.signature(self.func)
if sig.return_annotation not in (inspect._empty, None):
self.return_type = sig.return_annotation
params = sig.parameters
arg_names = params.keys()
for i, arg_name in enumerate(arg_names):
param = params[arg_name]
if param.kind == inspect.Parameter.VAR_KEYWORD:
raise KernelDefError(
'Taichi funcs do not support variable keyword parameters (i.e., **kwargs)'
)
if param.kind == inspect.Parameter.VAR_POSITIONAL:
raise KernelDefError(
'Taichi funcs do not support variable positional parameters (i.e., *args)'
)
if param.kind == inspect.Parameter.KEYWORD_ONLY:
raise KernelDefError(
'Taichi funcs do not support keyword parameters')
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
raise KernelDefError(
'Taichi funcs only support "positional or keyword" parameters'
)
annotation = param.annotation
if param.annotation is inspect.Parameter.empty:
if i == 0 and self.classfunc:
annotation = template()
self.arguments.append(annotation)
self.argument_names.append(param.name)


def classfunc(foo):
import warnings
Expand Down
21 changes: 11 additions & 10 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(self,
self.local_scopes = []
self.excluded_parameters = excluded_paremeters
self.is_kernel = is_kernel
self.is_classfunc = is_classfunc
self.func = func
self.arg_features = arg_features
self.returns = None
Expand Down Expand Up @@ -211,7 +210,7 @@ def visit_While(self, node):
raise TaichiSyntaxError(
"'else' clause for 'while' not supported in Taichi kernels")

template = '''
template = '''
if 1:
ti.core.begin_frontend_while(ti.Expr(1).ptr)
__while_cond = 0
Expand Down Expand Up @@ -249,7 +248,7 @@ def visit_If(self, node):
# Do nothing
return node

template = '''
template = '''
if 1:
__cond = 0
ti.core.begin_frontend_if(ti.Expr(__cond).ptr)
Expand Down Expand Up @@ -297,7 +296,7 @@ def visit_static_for(self, node, is_grouped):
self.generic_visit(node, ['body'])
if is_grouped:
assert len(node.iter.args[0].args) == 1
template = '''
template = '''
if 1:
__ndrange_arg = 0
from taichi.lang.exception import TaichiSyntaxError
Expand Down Expand Up @@ -326,10 +325,10 @@ def visit_range_for(self, node):
self.generic_visit(node, ['body'])
loop_var = node.target.id
self.check_loop_var(loop_var)
template = '''
template = '''
if 1:
{} = ti.Expr(ti.core.make_id_expr(''))
___begin = ti.Expr(0)
___begin = ti.Expr(0)
___end = ti.Expr(0)
___begin = ti.cast(___begin, ti.i32)
___end = ti.cast(___end, ti.i32)
Expand Down Expand Up @@ -445,7 +444,7 @@ def visit_struct_for(self, node, is_grouped):
t.body[0].value = node.iter
t.body = t.body[:cut] + node.body + t.body[cut:]
else:
template = '''
template = '''
if 1:
{}
___loop_var = 0
Expand Down Expand Up @@ -642,7 +641,7 @@ def visit_FunctionDef(self, node):
# Transform as func (all parameters passed by value)
arg_decls = []
for i, arg in enumerate(args.args):
if i == 0 and self.is_classfunc:
if isinstance(self.func.arguments[i], ti.template):
continue
Comment on lines +648 to 649
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(self.func.arguments[i], ti.template):
continue
# Directly pass in template arguments,
# such as class instances ("self"), tensors, SNodes, etc.
if isinstance(self.func.arguments[i], ti.template):
continue
# Create a copy for non-template arguments,
# so that they are passed by value.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aliright thanks! Yes, both tests are passing. Although the second one passes even without making x a template. I assume its related to the hack 0139195#r39384288 @archibate was talking about, which I think we can get rid of at some point.

arg_init = self.parse_stmt('x = ti.expr_init_func(0)')
arg_init.targets[0].id = arg.arg
Expand Down Expand Up @@ -752,8 +751,10 @@ def visit_Return(self, node):
# TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not
if node.value is not None:
if self.returns is None:
raise TaichiSyntaxError('kernel with return value must be '
'annotated with a return type, e.g. def func() -> ti.f32')
raise TaichiSyntaxError(
'kernel with return value must be '
'annotated with a return type, e.g. def func() -> ti.f32'
)
ret_expr = self.parse_expr('ti.cast(ti.Expr(0), 0)')
ret_expr.args[0].args[0] = node.value
ret_expr.args[1] = self.returns
Expand Down
62 changes: 62 additions & 0 deletions tests/python/test_kernel_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,65 @@ def compute_loss():
for i in range(16):
assert z[i] == i * 4 + 3
assert x.grad[i] == 4


@ti.all_archs
def test_func_template():
a = [ti.var(dt=ti.f32) for _ in range(2)]
b = [ti.var(dt=ti.f32) for _ in range(2)]

for l in range(2):
ti.root.dense(ti.ij, 16).place(a[l], b[l])

@ti.func
def sample(x: ti.template(), l: ti.template(), I):
return x[l][I]

@ti.kernel
def fill(l: ti.template()):
for I in ti.grouped(a[l]):
a[l][I] = l

@ti.kernel
def aTob(l: ti.template()):
for I in ti.grouped(b[l]):
b[l][I] = sample(a, l, I)

for l in range(2):
fill(l)
aTob(l)

for l in range(2):
for i in range(16):
for j in range(16):
assert b[l][i, j] == l


@ti.all_archs
def test_func_template2():
a = ti.var(dt=ti.f32)
b = ti.var(dt=ti.f32)

ti.root.dense(ti.ij, 16).place(a, b)

@ti.func
def sample(x: ti.template(), I):
return x[I]

@ti.kernel
def fill():
for I in ti.grouped(a):
a[I] = 1.0

@ti.kernel
def aTob():
for I in ti.grouped(b):
b[I] = sample(a, I)

for l in range(2):
fill()
aTob()

for i in range(16):
for j in range(16):
assert b[i, j] == 1.0