Skip to content

Commit

Permalink
[aot] Fix ndarray aot with information from type hints (#7214)
Browse files Browse the repository at this point in the history
Issue: fixes #7172 

### Brief Summary
Ideally we should reconstruct the dtype to the Tensortype from
taichi_core instead of python ones but that can be a separate PR.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ailzhang and pre-commit-ci[bot] authored Jan 20, 2023
1 parent c379f58 commit 800ff2b
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 44 deletions.
71 changes: 37 additions & 34 deletions python/taichi/aot/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from taichi.lang._ndarray import ScalarNdarray
from taichi.lang._texture import Texture
from taichi.lang.exception import TaichiCompilationError
from taichi.lang.matrix import Matrix, MatrixNdarray, MatrixType, VectorNdarray
from taichi.lang.matrix import (Matrix, MatrixNdarray, MatrixType,
VectorNdarray, VectorType)
from taichi.lang.util import cook_dtype
from taichi.types.annotations import template
from taichi.types.ndarray_type import NdarrayType
Expand All @@ -11,9 +12,14 @@


def check_type_match(lhs, rhs):
if cook_dtype(lhs) == cook_dtype(rhs):
return True
return False
if isinstance(lhs, MatrixType) and isinstance(rhs, MatrixType):
return lhs.n == rhs.n and lhs.m == rhs.m and (lhs.dtype == rhs.dtype
or lhs.dtype is None
or rhs.dtype is None)
if isinstance(lhs, MatrixType) or isinstance(rhs, MatrixType):
return False

return cook_dtype(lhs) == cook_dtype(rhs)


def produce_injected_args_from_template(kernel, template_args):
Expand Down Expand Up @@ -43,52 +49,49 @@ def produce_injected_args(kernel, symbolic_args=None):
for i, arg in enumerate(kernel.arguments):
anno = arg.annotation
if isinstance(anno, NdarrayType):
# TODO(Haidong) we should always use MatrixType and get rid of the element shapes
if symbolic_args is not None:
element_shape = tuple(symbolic_args[i].element_shape)
element_dim = len(element_shape)
# TODO: reconstruct dtype to be TensorType from taichi_core instead of the Python ones
element_dim = len(symbolic_args[i].element_shape)
if element_dim == 0 or symbolic_args[i].element_shape == (1, ):
dtype = symbolic_args[i].dtype()
elif element_dim == 1:
dtype = VectorType(symbolic_args[i].element_shape[0],
symbolic_args[i].dtype())
elif element_dim == 2:
dtype = MatrixType(symbolic_args[i].element_shape[0],
symbolic_args[i].element_shape[1], 2,
symbolic_args[i].dtype())
else:
raise TaichiCompilationError('Not supported')
ndim = symbolic_args[i].field_dim
dtype = symbolic_args[i].dtype()
else:
element_shape = anno.dtype.get_shape()
element_dim = anno.dtype.ndim
ndim = anno.ndim
dtype = anno.dtype

if element_shape is None or ndim is None:
raise TaichiCompilationError(
'Please either specify both `element_shape` and `ndim` '
'in the param annotation, or provide an example '
f'ndarray for param={arg.name}')
if anno.ndim is not None and ndim != anno.ndim:
raise TaichiCompilationError(
f'{ndim} from Arg {arg.name} doesn\'t match kernel\'s annotated ndim={anno.ndim}'
)
anno_dtype = anno.dtype
if isinstance(anno_dtype, MatrixType):
anno_dtype = anno.dtype.dtype
if anno_dtype is not None:
if not check_type_match(dtype, anno_dtype):
raise TaichiCompilationError(
f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno_dtype.to_string()}'
)

if element_dim is None or element_dim == 0 or element_shape == (
1, ):
injected_args.append(ScalarNdarray(dtype, (2, ) * ndim))
elif element_dim == 1:
if anno.dtype is not None and not check_type_match(
dtype, anno.dtype):
raise TaichiCompilationError(
f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno.dtype.to_string()}'
)

if isinstance(dtype, VectorType):
injected_args.append(
VectorNdarray(element_shape[0],
dtype=dtype,
VectorNdarray(dtype.n,
dtype=dtype.dtype,
shape=(2, ) * ndim))
elif element_dim == 2:
elif isinstance(dtype, MatrixType):
injected_args.append(
MatrixNdarray(element_shape[0],
element_shape[1],
dtype=dtype,
MatrixNdarray(dtype.n,
dtype.m,
dtype=dtype.dtype,
shape=(2, ) * ndim))
else:
raise RuntimeError('')
injected_args.append(ScalarNdarray(dtype, (2, ) * ndim))
elif isinstance(anno, RWTextureType):
texture_shape = (2, ) * anno.num_dimensions
fmt = anno.fmt
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,10 @@ def get_shape(self):
return (self.n, )
return (self.n, self.m)

def to_string(self):
dtype_str = self.dtype.to_string() if self.dtype is not None else ''
return f'MatrixType[{self.n},{self.m}, {dtype_str}]'


class VectorType(MatrixType):
def __init__(self, n, dtype):
Expand Down
17 changes: 7 additions & 10 deletions tests/cpp/aot/python_scripts/graph_aot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@ def run0(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
arr[i] += base + i

@ti.kernel
def run1(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
for i in arr:
arr[i] += base + i

@ti.kernel
def run2(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
def run1(base: int, arr: ti.types.ndarray(ndim=1,
dtype=ti.types.vector(1,
ti.i32))):
for i in arr:
arr[i] += base + i

Expand All @@ -41,12 +38,12 @@ def run2(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
g_builder = ti.graph.GraphBuilder()

g_builder.dispatch(run0, base0, arr0)
g_builder.dispatch(run1, base1, arr0)
g_builder.dispatch(run2, base2, arr0)
g_builder.dispatch(run0, base1, arr0)
g_builder.dispatch(run0, base2, arr0)

g_builder.dispatch(run0, base0, arr1)
g_builder.dispatch(run1, base0, arr1)
g_builder.dispatch(run1, base1, arr1)
g_builder.dispatch(run2, base2, arr1)
g_builder.dispatch(run1, base2, arr1)

run_graph = g_builder.compile()

Expand Down
23 changes: 23 additions & 0 deletions tests/python/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,29 @@ def run(arr: ti.types.ndarray(), val1: ti.f32, val2: ti.template()):
assert args_count == 2, res # `arr` and `val1`


@test_utils.test(arch=[ti.opengl, ti.vulkan])
def test_aot_ndarray_without_template_args():
@ti.kernel
def kernel1(arr: ti.types.ndarray(dtype=ti.f32, ndim=2)):
for I in ti.grouped(arr):
arr[I] = 0.

@ti.kernel
def kernel2(arr: ti.types.ndarray(dtype=ti.math.vec2, ndim=2)):
for I in ti.grouped(arr):
arr[I] = 0.

@ti.kernel
def kernel3(arr: ti.types.ndarray(dtype=ti.math.mat2, ndim=2)):
for I in ti.grouped(arr):
arr[I] = 0.

m = ti.aot.Module()
m.add_kernel(kernel1)
m.add_kernel(kernel2)
m.add_kernel(kernel3)


@test_utils.test(arch=[ti.opengl, ti.vulkan])
def test_archive():
density = ti.field(float, shape=(4, 4))
Expand Down

0 comments on commit 800ff2b

Please sign in to comment.