Skip to content

Commit

Permalink
[refactor] Resolve comments from #5065
Browse files Browse the repository at this point in the history
  • Loading branch information
ailzhang committed May 31, 2022
1 parent 2faf489 commit 468cb8e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
5 changes: 2 additions & 3 deletions python/taichi/lang/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from taichi.lang.enums import Layout
from taichi.lang.util import cook_dtype, python_scope, to_numpy_type
from taichi.types import primitive_types
from taichi.types.ndarray_type import SpecializeNdarrayType
from taichi.types.ndarray_type import NdarrayTypeMetadata


class Ndarray:
Expand All @@ -23,8 +23,7 @@ def __init__(self):
self.arr = None

def get_type(self):
return SpecializeNdarrayType(self.element_type, self.shape,
self.layout)
return NdarrayTypeMetadata(self.element_type, self.shape, self.layout)

@property
def element_shape(self):
Expand Down
6 changes: 3 additions & 3 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,13 @@ def extract_arg(arg, anno):
return arg
if isinstance(anno, ndarray_type.NdarrayType):
if isinstance(arg, taichi.lang._ndarray.ScalarNdarray):
anno.match(arg.get_type())
anno.check_matched(arg.get_type())
return arg.dtype, len(arg.shape), (), Layout.AOS
if isinstance(arg, taichi.lang.matrix.VectorNdarray):
anno.match(arg.get_type())
anno.check_matched(arg.get_type())
return arg.dtype, len(arg.shape) + 1, (arg.n, ), arg.layout
if isinstance(arg, taichi.lang.matrix.MatrixNdarray):
anno.match(arg.get_type())
anno.check_matched(arg.get_type())
return arg.dtype, len(arg.shape) + 2, (arg.n,
arg.m), arg.layout
# external arrays
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/types/ndarray_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from taichi.types.primitive_types import f32


class SpecializeNdarrayType:
class NdarrayTypeMetadata:
def __init__(self, element_type, shape=None, layout=None):
self.element_type = element_type
self.shape = shape
Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(self,
self.field_dim = field_dim
self.layout = layout

def match(self, ndarray_type: SpecializeNdarrayType):
def check_matched(self, ndarray_type: NdarrayTypeMetadata):
if self.element_dim is not None and self.element_dim != len(
ndarray_type.element_type.shape):
raise ValueError(
Expand Down

0 comments on commit 468cb8e

Please sign in to comment.