Skip to content

Commit

Permalink
Refactor _calc_slice
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Oct 19, 2022
1 parent b4d1087 commit d4c54f2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
19 changes: 17 additions & 2 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from taichi.lang._ndrange import GroupedNDRange, _Ndrange
from taichi.lang.any_array import AnyArray, AnyArrayAccess
from taichi.lang.enums import SNodeGradType
from taichi.lang.exception import (TaichiRuntimeError, TaichiSyntaxError,
TaichiTypeError)
from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError,
TaichiSyntaxError, TaichiTypeError)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
Expand Down Expand Up @@ -135,6 +135,21 @@ def begin_frontend_if(ast_builder, cond):
ast_builder.begin_frontend_if(Expr(cond).ptr)


@taichi_scope
def _calc_slice(index, default_stop):
start, stop, step = index.start or 0, index.stop or default_stop, index.step or 1

def check_validity(x):
# TODO(mzmzm): support variable in slice
if isinstance(x, Expr):
raise TaichiCompilationError(
"Taichi does not support variables in slice now, please use constant instead of it."
)

check_validity(start), check_validity(stop), check_validity(step)
return [_ for _ in range(start, stop, step)]


@taichi_scope
def subscript(ast_builder,
value,
Expand Down
17 changes: 2 additions & 15 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ def _subscript(self, is_global_mat, *indices, get_ref=False):
j = 0 if len(indices) == 1 else indices[1]
has_slice = False
if isinstance(i, slice):
i = self._calc_slice(i, self.n)
i = impl._calc_slice(i, self.n)
has_slice = True
if isinstance(j, slice):
j = self._calc_slice(j, self.m)
j = impl._calc_slice(j, self.m)
has_slice = True

if has_slice:
Expand Down Expand Up @@ -280,19 +280,6 @@ def _subscript(self, is_global_mat, *indices, get_ref=False):
self.dynamic_index_stride)
return self._get_entry(i, j)

def _calc_slice(self, index, default_stop):
start, stop, step = index.start or 0, index.stop or default_stop, index.step or 1

def check_validity(x):
# TODO(mzmzm): support variable in slice
if isinstance(x, expr.Expr):
raise TaichiCompilationError(
"Taichi does not support variables in slice now, please use constant instead of it."
)

check_validity(start), check_validity(stop), check_validity(step)
return [_ for _ in range(start, stop, step)]


class _MatrixEntriesInitializer:
def pyscope_or_ref(self, arr):
Expand Down

0 comments on commit d4c54f2

Please sign in to comment.