diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 89c5589b94195..af10efa2d4f18 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -9,8 +9,8 @@ from taichi.lang._ndrange import GroupedNDRange, _Ndrange from taichi.lang.any_array import AnyArray, AnyArrayAccess from taichi.lang.enums import SNodeGradType -from taichi.lang.exception import (TaichiRuntimeError, TaichiSyntaxError, - TaichiTypeError) +from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError, + TaichiSyntaxError, TaichiTypeError) from taichi.lang.expr import Expr, make_expr_group from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy @@ -135,6 +135,21 @@ def begin_frontend_if(ast_builder, cond): ast_builder.begin_frontend_if(Expr(cond).ptr) +@taichi_scope +def _calc_slice(index, default_stop): + start, stop, step = index.start or 0, index.stop or default_stop, index.step or 1 + + def check_validity(x): + # TODO(mzmzm): support variable in slice + if isinstance(x, Expr): + raise TaichiCompilationError( + "Taichi does not support variables in slice now, please use constant instead of it." + ) + + check_validity(start), check_validity(stop), check_validity(step) + return [_ for _ in range(start, stop, step)] + + @taichi_scope def subscript(ast_builder, value, diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 05d2f6d6f7675..fd687fec6c8b4 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -249,10 +249,10 @@ def _subscript(self, is_global_mat, *indices, get_ref=False): j = 0 if len(indices) == 1 else indices[1] has_slice = False if isinstance(i, slice): - i = self._calc_slice(i, self.n) + i = impl._calc_slice(i, self.n) has_slice = True if isinstance(j, slice): - j = self._calc_slice(j, self.m) + j = impl._calc_slice(j, self.m) has_slice = True if has_slice: @@ -280,19 +280,6 @@ def _subscript(self, is_global_mat, *indices, get_ref=False): self.dynamic_index_stride) return self._get_entry(i, j) - def _calc_slice(self, index, default_stop): - start, stop, step = index.start or 0, index.stop or default_stop, index.step or 1 - - def check_validity(x): - # TODO(mzmzm): support variable in slice - if isinstance(x, expr.Expr): - raise TaichiCompilationError( - "Taichi does not support variables in slice now, please use constant instead of it." - ) - - check_validity(start), check_validity(stop), check_validity(step) - return [_ for _ in range(start, stop, step)] - class _MatrixEntriesInitializer: def pyscope_or_ref(self, arr):