diff --git a/python/taichi/aot/utils.py b/python/taichi/aot/utils.py index 05bc0a7c1e074..57dc17b65b222 100644 --- a/python/taichi/aot/utils.py +++ b/python/taichi/aot/utils.py @@ -54,7 +54,7 @@ def produce_injected_args(kernel, symbolic_args=None): raise TaichiCompilationError( f'{field_dim} from Arg {arg.name} doesn\'t match kernel\'s annotated field_dim={anno.field_dim}' ) - if dtype != anno.dtype: + if anno.dtype is not None and dtype != anno.dtype: raise TaichiCompilationError( f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno.dtype.to_string()}' ) diff --git a/python/taichi/types/ndarray_type.py b/python/taichi/types/ndarray_type.py index 7c097b2e52f1a..95d51ff531a99 100644 --- a/python/taichi/types/ndarray_type.py +++ b/python/taichi/types/ndarray_type.py @@ -1,6 +1,3 @@ -from taichi.types.primitive_types import f32 - - class NdarrayTypeMetadata: def __init__(self, element_type, shape=None, layout=None): self.element_type = element_type @@ -20,13 +17,12 @@ class NdarrayType: field_dim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external arrays for now. layout (Union[Layout, NoneType], optional): None if not specified (will be treated as Layout.AOS for external arrays), Layout.AOS or Layout.SOA. """ - def __init__( - self, - dtype=f32, # TODO: default should be None - element_dim=None, - element_shape=None, - field_dim=None, - layout=None): + def __init__(self, + dtype=None, + element_dim=None, + element_shape=None, + field_dim=None, + layout=None): if element_dim is not None and (element_dim < 0 or element_dim > 2): raise ValueError( "Only scalars, vectors, and matrices are allowed as elements of ti.types.ndarray()" diff --git a/tests/python/test_graph.py b/tests/python/test_graph.py index 7e27516e5455a..d8fa102e1c32b 100644 --- a/tests/python/test_graph.py +++ b/tests/python/test_graph.py @@ -142,7 +142,7 @@ def test_arg_mismatched_ndarray_dtype(): n = 4 @ti.kernel - def test(pos: ti.types.ndarray(field_dim=1)): + def test(pos: ti.types.ndarray(dtype=ti.f32, field_dim=1)): for i in range(n): pos[i] = 2.5