diff --git a/python/taichi/lang/_ndarray.py b/python/taichi/lang/_ndarray.py index 30514bac43393..7c0298ac37935 100644 --- a/python/taichi/lang/_ndarray.py +++ b/python/taichi/lang/_ndarray.py @@ -4,7 +4,7 @@ from taichi.lang.enums import Layout from taichi.lang.util import cook_dtype, python_scope, to_numpy_type from taichi.types import primitive_types -from taichi.types.ndarray_type import SpecializeNdarrayType +from taichi.types.ndarray_type import NdarrayTypeMetadata class Ndarray: @@ -23,8 +23,7 @@ def __init__(self): self.arr = None def get_type(self): - return SpecializeNdarrayType(self.element_type, self.shape, - self.layout) + return NdarrayTypeMetadata(self.element_type, self.shape, self.layout) @property def element_shape(self): diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index e5365981c7932..fab88a564e267 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -337,13 +337,13 @@ def extract_arg(arg, anno): return arg if isinstance(anno, ndarray_type.NdarrayType): if isinstance(arg, taichi.lang._ndarray.ScalarNdarray): - anno.match(arg.get_type()) + anno.check_matched(arg.get_type()) return arg.dtype, len(arg.shape), (), Layout.AOS if isinstance(arg, taichi.lang.matrix.VectorNdarray): - anno.match(arg.get_type()) + anno.check_matched(arg.get_type()) return arg.dtype, len(arg.shape) + 1, (arg.n, ), arg.layout if isinstance(arg, taichi.lang.matrix.MatrixNdarray): - anno.match(arg.get_type()) + anno.check_matched(arg.get_type()) return arg.dtype, len(arg.shape) + 2, (arg.n, arg.m), arg.layout # external arrays diff --git a/python/taichi/types/ndarray_type.py b/python/taichi/types/ndarray_type.py index 61be06f25213c..58ce958f09cfe 100644 --- a/python/taichi/types/ndarray_type.py +++ b/python/taichi/types/ndarray_type.py @@ -1,7 +1,7 @@ from taichi.types.primitive_types import f32 -class SpecializeNdarrayType: +class NdarrayTypeMetadata: def __init__(self, element_type, shape=None, layout=None): self.element_type = element_type self.shape = shape @@ -43,7 +43,7 @@ def __init__(self, self.field_dim = field_dim self.layout = layout - def match(self, ndarray_type: SpecializeNdarrayType): + def check_matched(self, ndarray_type: NdarrayTypeMetadata): if self.element_dim is not None and self.element_dim != len( ndarray_type.element_type.shape): raise ValueError(