Commit

[Lang] MatrixNdarray refactor part2: Remove redundant members in python-scope AnyArray (#5885)

* [Lang] MatrixNdarray refactor part0: Support direct TensorType construction in Ndarray and refactored the use of element_shape

* Fixed minor issue

* Fixed CI failures

* Minor refactor

* Fixed minor issue

* [Lang] MatrixNdarray refactor part1: Refactored Taichi kernel argument to use TensorType

* Fixed CI failure with Metal backend

* Addressed review comments

* Fixed format issue with clang-tidy

* Review comments

* [Lang] MatrixNdarray refactor part2: Remove redundant members in python-scope AnyArray

* Fixed CI failure

* Fix CI failures

* Renamed interface

* Minor bug fix
jim19930609 authored Aug 30, 2022
1 parent fb62f1c commit 804e713
Showing 8 changed files with 50 additions and 31 deletions.
28 changes: 19 additions & 9 deletions python/taichi/lang/any_array.py
@@ -14,17 +14,26 @@ class AnyArray:
element_shape (Tuple[Int]): () if scalar elements (default), (n) if vector elements, and (n, m) if matrix elements.
layout (Layout): Memory layout.
"""
def __init__(self, ptr, element_shape, layout):
def __init__(self, ptr):
assert ptr.is_external_var()
self.ptr = ptr
self.element_shape = element_shape
self.layout = layout

def element_shape(self):
return _ti_core.get_external_tensor_element_shape(self.ptr)

def layout(self):
# 0: scalar; 1: vector (SOA); 2: matrix (SOA); -1: vector
# (AOS); -2: matrix (AOS)
element_dim = _ti_core.get_external_tensor_element_dim(self.ptr)
if element_dim == 1 or element_dim == 2:
return Layout.SOA
return Layout.AOS

def get_type(self):
return NdarrayTypeMetadata(
self.ptr.get_ret_type(),
None, # AnyArray can take any shape
self.layout)
self.layout())

@property
@taichi_scope
@@ -39,11 +48,11 @@ def shape(self):
Expr(_ti_core.get_external_tensor_shape_along_axis(self.ptr, i))
for i in range(dim)
]
element_dim = len(self.element_shape)
element_dim = len(self.element_shape())
if element_dim == 0:
return ret
return ret[
element_dim:] if self.layout == Layout.SOA else ret[:-element_dim]
return ret[element_dim:] if self.layout(
) == Layout.SOA else ret[:-element_dim]

@taichi_scope
def _loop_range(self):
@@ -70,8 +79,9 @@ def __init__(self, arr, indices_first):

@taichi_scope
def subscript(self, i, j):
indices_second = (i, ) if len(self.arr.element_shape) == 1 else (i, j)
if self.arr.layout == Layout.SOA:
indices_second = (i, ) if len(self.arr.element_shape()) == 1 else (i,
j)
if self.arr.layout() == Layout.SOA:
indices = indices_second + self.indices_first
else:
indices = self.indices_first + indices_second
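The key change in the hunk above is that AnyArray no longer caches element_shape and layout in Python scope; both are re-derived from the underlying ExternalTensorExpression, with the layout encoded in the sign of element_dim. A minimal self-contained sketch of that convention (using a stand-in Layout enum rather than Taichi's own class):

    from enum import Enum

    class Layout(Enum):
        SOA = 0
        AOS = 1

    def layout_from_element_dim(element_dim: int) -> Layout:
        # Convention documented in AnyArray.layout():
        # 0: scalar; 1: vector (SOA); 2: matrix (SOA); -1: vector (AOS); -2: matrix (AOS)
        return Layout.SOA if element_dim in (1, 2) else Layout.AOS

    assert layout_from_element_dim(2) == Layout.SOA
    assert layout_from_element_dim(-1) == Layout.AOS
    assert layout_from_element_dim(0) == Layout.AOS
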
6 changes: 3 additions & 3 deletions python/taichi/lang/impl.py
@@ -204,7 +204,7 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False):
if isinstance(value, AnyArray):
# TODO: deprecate using get_attribute to get dim
field_dim = int(value.ptr.get_attribute("dim"))
element_dim = len(value.element_shape)
element_dim = len(value.element_shape())
if field_dim != index_dim + element_dim:
raise IndexError(
f'Field with dim {field_dim - element_dim} accessed with indices of dim {index_dim}'
@@ -213,8 +213,8 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False):
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
n = value.element_shape[0]
m = 1 if element_dim == 1 else value.element_shape[1]
n = value.element_shape()[0]
m = 1 if element_dim == 1 else value.element_shape()[1]
any_array_access = AnyArrayAccess(value, _indices)
ret = _IntermediateMatrix(n,
m, [
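For context, the subscript path above does its dimension bookkeeping as follows: the external tensor's total dim covers both the array axes and the element axes, so indexing with only the array indices must leave exactly element_dim axes over, and the access then produces an n-by-m intermediate matrix. A hedged pure-Python sketch of that check (hypothetical helper, not Taichi API):

    def element_access_shape(field_dim, index_dim, element_shape):
        # field_dim counts array axes plus element axes; the caller supplies
        # only the array indices, so the remainder must equal the element rank.
        element_dim = len(element_shape)
        if field_dim != index_dim + element_dim:
            raise IndexError(
                f'Field with dim {field_dim - element_dim} accessed with '
                f'indices of dim {index_dim}')
        n = element_shape[0]
        m = 1 if element_dim == 1 else element_shape[1]
        return n, m

    assert element_access_shape(4, 2, (3, 4)) == (3, 4)  # matrix elements
    assert element_access_shape(3, 2, (3,)) == (3, 1)    # vector elements
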
3 changes: 1 addition & 2 deletions python/taichi/lang/kernel_arguments.py
@@ -81,8 +81,7 @@ def decl_ndarray_arg(dtype, dim, element_shape, layout):
element_dim = -element_dim
return AnyArray(
_ti_core.make_external_tensor_expr(dtype, dim, arg_id, element_dim,
element_shape), element_shape,
layout)
element_shape))


def decl_texture_arg(num_dimensions):
8 changes: 4 additions & 4 deletions python/taichi/types/compound_types.py
@@ -13,11 +13,11 @@ class TensorType(CompoundType):
def __init__(self, shape, dtype):
self.ptr = _type_factory.get_tensor_type(shape, dtype)

def get_shape(self):
return tuple(self.ptr.get_shape())
def shape(self):
return tuple(self.ptr.shape())

def get_element_type(self):
return self.ptr.get_element_type()
def element_type(self):
return self.ptr.element_type()


# TODO: maybe move MatrixType, StructType here to avoid the circular import?
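The rename from get_shape()/get_element_type() to shape()/element_type() ripples through ndarray_type.py and the tests below. A minimal usage sketch of the new accessor names, mirroring the updated test_ndarray_compound_element (assumes a build that includes this commit):

    import taichi as ti

    ti.init(arch=ti.cpu)

    vec3 = ti.types.vector(3, ti.i32)
    b = ti.ndarray(vec3, shape=(4, 4))
    # The element type now exposes shape() / element_type() instead of
    # get_shape() / get_element_type().
    assert b.element_type.shape() == (3, 1)
    assert b.element_type.element_type() == ti.i32
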
8 changes: 4 additions & 4 deletions python/taichi/types/ndarray_type.py
@@ -49,9 +49,9 @@ def check_matched(self, ndarray_type: NdarrayTypeMetadata):
raise TypeError(
f"Expect TensorType element for Ndarray with element_dim: {self.element_dim} > 0"
)
if self.element_dim != len(ndarray_type.element_type.get_shape()):
if self.element_dim != len(ndarray_type.element_type.shape()):
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required element_dim={self.element_dim}, but {len(ndarray_type.element_type.get_shape())} is provided"
f"Invalid argument into ti.types.ndarray() - required element_dim={self.element_dim}, but {len(ndarray_type.element_type.shape())} is provided"
)

if self.element_shape is not None and len(self.element_shape) > 0:
@@ -60,9 +60,9 @@ def check_matched(self, ndarray_type: NdarrayTypeMetadata):
f"Expect TensorType element for Ndarray with element_shape: {self.element_shape}"
)

if self.element_shape != ndarray_type.element_type.get_shape():
if self.element_shape != ndarray_type.element_type.shape():
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required element_shape={self.element_shape}, but {ndarray_type.element_type.get_shape()} is provided"
f"Invalid argument into ti.types.ndarray() - required element_shape={self.element_shape}, but {ndarray_type.element_type.shape()} is provided"
)

if self.layout is not None and self.layout != ndarray_type.layout:
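The checks above fire when an ndarray is passed to a kernel whose argument annotation constrains element_dim or element_shape. A hedged usage sketch of what they guard against (assumes a build with this commit; exact error messages may differ):

    import taichi as ti

    ti.init(arch=ti.cpu)

    @ti.kernel
    def scale(x: ti.types.ndarray(element_dim=2)):
        for I in ti.grouped(x):
            x[I] = x[I] * 2.0

    good = ti.ndarray(ti.types.matrix(2, 2, ti.f32), shape=(8,))
    scale(good)  # matches: matrix elements, so element_dim == 2

    bad = ti.ndarray(ti.f32, shape=(8,))
    # scale(bad)  # would raise TypeError: scalar elements carry no TensorType,
    #             # so the required element_dim=2 cannot be satisfied
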
6 changes: 3 additions & 3 deletions taichi/program/callable.h
@@ -39,9 +39,9 @@ class TI_DLL_EXPORT Callable {
int total_dim = 0,
std::vector<int> element_shape = {}) {
if (dt->is<PrimitiveType>() && element_shape.size() > 0) {
this->dt_ = taichi::lang::TypeFactory::get_instance().get_tensor_type(
element_shape, dt.operator->());

this->dt_ =
taichi::lang::TypeFactory::get_instance().create_tensor_type(
element_shape, dt);
} else {
this->dt_ = dt;
}
14 changes: 12 additions & 2 deletions taichi/python/export_lang.cpp
@@ -112,8 +112,8 @@ void export_lang(py::module &m) {
.def("__hash__", &DataType::hash)
.def("to_string", &DataType::to_string)
.def("__str__", &DataType::to_string)
.def("get_shape", &DataType::get_shape)
.def("get_element_type", &DataType::get_element_type)
.def("shape", &DataType::get_shape)
.def("element_type", &DataType::get_element_type)
.def(
"get_ptr", [](DataType *dtype) -> Type * { return *dtype; },
py::return_value_policy::reference)
@@ -985,6 +985,16 @@ void export_lang(py::module &m) {
Expr::make<StrideExpression, const Expr &, const ExprGroup &,
const std::vector<int> &, int>);

m.def("get_external_tensor_element_dim", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
return expr.cast<ExternalTensorExpression>()->element_dim;
});

m.def("get_external_tensor_element_shape", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
return expr.cast<ExternalTensorExpression>()->element_shape;
});

m.def("get_external_tensor_dim", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
return expr.cast<ExternalTensorExpression>()->dim;
8 changes: 4 additions & 4 deletions tests/python/test_ndarray.py
@@ -174,15 +174,15 @@ def test_ndarray_compound_element():
b = ti.ndarray(vec3, shape=(n, n))
assert isinstance(b, ti.MatrixNdarray)
assert b.shape == (n, n)
assert b.element_type.get_element_type() == ti.i32
assert b.element_type.get_shape() == (3, 1)
assert b.element_type.element_type() == ti.i32
assert b.element_type.shape() == (3, 1)

matrix34 = ti.types.matrix(3, 4, float)
c = ti.ndarray(matrix34, shape=(n, n + 1), layout=ti.Layout.SOA)
assert isinstance(c, ti.MatrixNdarray)
assert c.shape == (n, n + 1)
assert c.element_type.get_element_type() == ti.f32
assert c.element_type.get_shape() == (3, 4)
assert c.element_type.element_type() == ti.f32
assert c.element_type.shape() == (3, 4)
assert c.layout == ti.Layout.SOA


