[Lang] Support ndarray argument for real function
ghstack-source-id: 78a359fac9fc628ed0bc578dec9d968b67dcb97b
Pull Request resolved: #8188
lin-hitonami committed Jun 15, 2023
1 parent d02804d commit 7dded38
Showing 7 changed files with 108 additions and 59 deletions.
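
This change lets a ti.experimental.real_func take an ndarray parameter directly when called from a kernel. A minimal usage sketch, copied almost verbatim from the test_pass_ndarray_to_real_func test added below (the ti.init call and the arch choice are illustrative):

import taichi as ti

ti.init(arch=ti.cpu)

@ti.experimental.real_func
def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
    return weight[1, 1, 1]

@ti.kernel
def foo(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
    return bar(weight)

weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2))
weight.fill(42.0)
assert foo(weight) == 42.0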
5 changes: 3 additions & 2 deletions python/taichi/lang/any_array.py
@@ -15,8 +15,9 @@ class AnyArray:
layout (Layout): Memory layout.
"""

def __init__(self, ptr):
def __init__(self, ptr, element_type):
assert ptr.is_external_tensor_expr()
self.element_type = element_type
self.ptr = ptr
self.ptr.type_check(impl.get_runtime().prog.config())

@@ -33,7 +34,7 @@ def layout(self):

def get_type(self):
return NdarrayTypeMetadata(
self.ptr.get_ret_type().ptr_removed(), None, _ti_core.get_external_tensor_needs_grad(self.ptr)
self.element_type, None, _ti_core.get_external_tensor_needs_grad(self.ptr)
) # AnyArray can take any shape

@property
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_arguments.py
@@ -108,7 +108,7 @@ def decl_sparse_matrix(dtype, name):

def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary):
arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad)
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary))
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary), element_type)


def decl_texture_arg(num_dimensions, name):
17 changes: 13 additions & 4 deletions python/taichi/lang/kernel_impl.py
@@ -12,6 +12,7 @@
import taichi.lang
from taichi._lib import core as _ti_core
from taichi.lang import impl, ops, runtime_ops
from taichi.lang.any_array import AnyArray
from taichi.lang._wrap_inspect import getsourcefile, getsourcelines
from taichi.lang.argpack import ArgPackType, ArgPack
from taichi.lang.ast import (
@@ -226,12 +227,12 @@ def __call__(self, *args, **kwargs):
if self.is_real_function:
if impl.get_runtime().current_kernel.autodiff_mode != AutodiffMode.NONE:
raise TaichiSyntaxError("Real function in gradient kernels unsupported.")
instance_id, _ = self.mapper.lookup(args)
instance_id, arg_features = self.mapper.lookup(args)
key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
if self.compiled is None:
self.compiled = {}
if key.instance_id not in self.compiled:
self.do_compile(key=key, args=args)
self.do_compile(key=key, args=args, arg_features=arg_features)
return self.func_call_rvalue(key=key, args=args)
tree, ctx = _get_tree_and_ctx(
self,
@@ -257,6 +258,8 @@ def func_call_rvalue(self, key, args):
non_template_args.append(ops.cast(args[i], anno))
elif isinstance(anno, primitive_types.RefType):
non_template_args.append(_ti_core.make_reference(args[i].ptr))
elif isinstance(anno, ndarray_type.NdarrayType):
non_template_args.append(args[i].ptr)
else:
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args, real_func_arg=True)
@@ -274,8 +277,10 @@
return self.return_type.from_taichi_object(func_call, (0,))
raise TaichiTypeError(f"Unsupported return type: {self.return_type}")

def do_compile(self, key, args):
tree, ctx = _get_tree_and_ctx(self, is_kernel=False, args=args, is_real_function=self.is_real_function)
def do_compile(self, key, args, arg_features):
tree, ctx = _get_tree_and_ctx(
self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
)
fn = impl.get_runtime().prog.create_function(key)

def func_body():
@@ -403,6 +408,10 @@ def extract_arg(arg, anno, arg_name):
anno.check_matched(arg.get_type(), arg_name)
needs_grad = (arg.grad is not None) if anno.needs_grad is None else anno.needs_grad
return arg.element_type, len(arg.shape), needs_grad, anno.boundary
if isinstance(arg, AnyArray):
ty = arg.get_type()
anno.check_matched(arg.get_type(), arg_name)
return ty.element_type, len(arg.shape), ty.needs_grad, anno.boundary
# external arrays
shape = getattr(arg, "shape", None)
if shape is None:
113 changes: 67 additions & 46 deletions taichi/ir/frontend_ir.cpp
@@ -608,6 +608,7 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
void ExternalTensorExpression::flatten(FlattenContext *ctx) {
auto type =
TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad);
type = TypeFactory::get_instance().get_pointer_type((Type *)type);

auto ptr = Stmt::make<ArgLoadStmt>(arg_id, type, /*is_ptr=*/true,
/*create_load=*/false);
@@ -1671,62 +1672,82 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
std::vector<Expr> expanded_exprs;
for (auto expr : exprs) {
TI_ASSERT_TYPE_CHECKED(expr);
if (auto struct_type = expr->ret_type.ptr_removed()->cast<StructType>()) {
auto num_elem = struct_type->elements().size();
for (int i = 0; i < num_elem; i++) {
std::vector<int> indices = {i};
auto elem = Expr(std::make_shared<GetElementExpression>(expr, indices));
elem.expr->ret_type = struct_type->get_element_type(indices);
expanded_exprs.push_back(elem);
}
} else if (!expr->ret_type.ptr_removed()->is<TensorType>()) {
expanded_exprs.push_back(expr);
} else {
// Expand TensorType expr
/*
Before:
TensorType<4 x i32> index = Expr;
After:
TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>)
i32 ind0 = IndexExpression(id_expr, 0)
i32 ind1 = IndexExpression(id_expr, 1)
i32 ind2 = IndexExpression(id_expr, 2)
i32 ind3 = IndexExpression(id_expr, 3)
return {ind0, ind1, ind2, ind3}
*/
auto tensor_type = expr->ret_type.ptr_removed()->cast<TensorType>();

Expr id_expr;
if (expr.is<IdExpression>()) {
id_expr = expr;

auto expand_tensor_or_scalar = [&](Expr expr) {
if (!expr->ret_type.ptr_removed()->is<TensorType>()) {
expanded_exprs.push_back(expr);
} else {
id_expr = make_var(expr, expr->tb);
}
auto shape = tensor_type->get_shape();
if (shape.size() == 1) {
for (int i = 0; i < shape[0]; i++) {
auto ind = Expr(std::make_shared<IndexExpression>(
id_expr, ExprGroup(Expr(i)), expr->tb));
ind->type_check(nullptr);
expanded_exprs.push_back(ind);
// Expand TensorType expr
/*
Before:
TensorType<4 x i32> index = Expr;
After:
TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>)
i32 ind0 = IndexExpression(id_expr, 0)
i32 ind1 = IndexExpression(id_expr, 1)
i32 ind2 = IndexExpression(id_expr, 2)
i32 ind3 = IndexExpression(id_expr, 3)
return {ind0, ind1, ind2, ind3}
*/
auto tensor_type = expr->ret_type.ptr_removed()->cast<TensorType>();

Expr id_expr;
if (expr.is<IdExpression>()) {
id_expr = expr;
} else {
id_expr = make_var(expr, expr->tb);
}
} else {
TI_ASSERT(shape.size() == 2);
for (int i = 0; i < shape[0]; i++) {
for (int j = 0; j < shape[1]; j++) {
auto shape = tensor_type->get_shape();
if (shape.size() == 1) {
for (int i = 0; i < shape[0]; i++) {
auto ind = Expr(std::make_shared<IndexExpression>(
id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb));
id_expr, ExprGroup(Expr(i)), expr->tb));
ind->type_check(nullptr);
expanded_exprs.push_back(ind);
}
} else {
TI_ASSERT(shape.size() == 2);
for (int i = 0; i < shape[0]; i++) {
for (int j = 0; j < shape[1]; j++) {
auto ind = Expr(std::make_shared<IndexExpression>(
id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb));
ind->type_check(nullptr);
expanded_exprs.push_back(ind);
}
}
}
}
};

std::function<void(const Expr &, const StructType *, std::vector<int> &)>
expand_struct = [&](const Expr &expr, const StructType *struct_type,
std::vector<int> &indices) {
auto num_elem = struct_type->elements().size();
for (int i = 0; i < num_elem; i++) {
indices.push_back(i);
auto element_type = struct_type->get_element_type({i});
if (auto element_struct_type = element_type->cast<StructType>()) {
expand_struct(expr, element_struct_type, indices);
} else {
auto elem =
Expr(std::make_shared<GetElementExpression>(expr, indices));
elem.expr->ret_type = element_type;
expand_tensor_or_scalar(elem);
}
indices.pop_back();
}
};
auto type = expr->ret_type.ptr_removed();
if (auto struct_type = type->cast<StructType>()) {
std::vector<int> indices;
expand_struct(expr, struct_type, indices);
} else {
expand_tensor_or_scalar(expr);
}
}

return expanded_exprs;
}

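The expand_exprs rewrite above splits the old flat logic into two helpers: expand_tensor_or_scalar, which keeps the existing element-wise expansion of TensorType values, and a recursive expand_struct, which descends into nested StructTypes instead of expanding only one level. A rough Python sketch of that recursion, illustrative only and not Taichi or compiler API (a dict stands in for a StructType, a list for a TensorType, anything else for a scalar):

# Illustrative sketch of the recursion added to ASTBuilder::expand_exprs;
# not real compiler code.
def expand_expr(value, expanded):
    if isinstance(value, dict):        # struct: recurse into every element
        for element in value.values():
            expand_expr(element, expanded)
    elif isinstance(value, list):      # tensor: one entry per element
        expanded.extend(value)
    else:                              # scalar: passed through unchanged
        expanded.append(value)

expanded = []
expand_expr({"inner": {"vec": [1, 2, 3, 4]}, "scale": 0.5}, expanded)
print(expanded)  # [1, 2, 3, 4, 0.5]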
3 changes: 2 additions & 1 deletion taichi/ir/frontend_ir.h
@@ -502,7 +502,8 @@ class ExternalTensorExpression : public Expression {
}

void type_check(const CompileConfig *config) override {
ret_type = dt;
ret_type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim,
needs_grad);
ret_type.set_is_pointer(true);
config_ = config;
}
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
@@ -717,6 +717,7 @@ void export_lang(py::module &m) {
py::class_<Function>(m, "Function")
.def("insert_scalar_param", &Function::insert_scalar_param)
.def("insert_arr_param", &Function::insert_arr_param)
.def("insert_ndarray_param", &Function::insert_ndarray_param)
.def("insert_texture_param", &Function::insert_texture_param)
.def("insert_pointer_param", &Function::insert_pointer_param)
.def("insert_rw_texture_param", &Function::insert_rw_texture_param)
26 changes: 21 additions & 5 deletions tests/python/test_ndarray.py
@@ -1005,15 +1005,31 @@ def test(x: ti.types.ndarray(dtype=ti.types.vector())):
@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_pass_ndarray_to_func():
@ti.func
def bar(weight: ti.types.ndarray(ti.f32, ndim=3)):
pass
def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return weight[1, 1, 1]

@ti.kernel
def foo(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return bar(weight)

weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2))
weight.fill(42.0)
assert foo(weight) == 42.0


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_pass_ndarray_to_real_func():
@ti.experimental.real_func
def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return weight[1, 1, 1]

@ti.kernel
def foo(weight: ti.types.ndarray(ti.f32, ndim=3)):
bar(weight)
def foo(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return bar(weight)

weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2))
foo(weight)
weight.fill(42.0)
assert foo(weight) == 42.0


@test_utils.test(arch=supported_archs_taichi_ndarray)
