Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Support simple matrix slicing #4488

Merged
merged 11 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from taichi.lang._ndarray import Ndarray, NdarrayHostAccess
from taichi.lang.common_ops import TaichiOperations
from taichi.lang.enums import Layout
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.exception import (TaichiCompilationError, TaichiSyntaxError,
TaichiTypeError)
from taichi.lang.field import Field, ScalarField, SNodeHostAccess
from taichi.lang.util import (cook_dtype, in_python_scope, python_scope,
taichi_scope, to_numpy_type, to_pytorch_type,
Expand Down Expand Up @@ -259,21 +260,60 @@ def _get_slice(self, a, b):
b = range(b.start or 0, b.stop or self.m, b.step or 1)
return Matrix([[self(i, j) for j in b] for i in a])

def _cal_slice(self, index, dim):
start, stop, step = index.start or 0, index.stop or (
self.n if dim is 0 else self.m), index.step or 1

def helper(x):
if isinstance(x, expr.Expr):
if isinstance(x.ptr, int):
return x.ptr
mzmzm marked this conversation as resolved.
Show resolved Hide resolved
raise TaichiSyntaxError(
"The element type of slice of Matrix/Vector index must be a compile-time constant integer!"
)
return x

start, stop, step = helper(start), helper(stop), helper(step)
return [_ for _ in range(start, stop, step)]

@taichi_scope
def _subscript(self, *indices):
assert len(indices) in [1, 2]
i = indices[0]
j = 0 if len(indices) == 1 else indices[1]
if isinstance(i, slice) or isinstance(j, slice):
for a in (i, j):
if isinstance(a, slice):
if isinstance(a.start, expr.Expr) or isinstance(
a.step, expr.Expr) or isinstance(
a.stop, expr.Expr):
raise TaichiSyntaxError(
"The element type of slice of Matrix/Vector index must be a compile-time constant integer!"
)
return self._get_slice(i, j)
has_slice = False
if isinstance(i, slice):
i = self._cal_slice(i, 0)
has_slice = True
if isinstance(j, slice):
j = self._cal_slice(j, 1)
has_slice = True

if has_slice:
if not isinstance(i, list):
i = [i]
if not isinstance(j, list):
j = [j]

if self.local_tensor_proxy is not None:
assert self.dynamic_index_stride is not None
mzmzm marked this conversation as resolved.
Show resolved Hide resolved
if len(indices) == 1:
return Vector([
impl.make_tensor_element_expr(
self.local_tensor_proxy, (a, ), (self.n, ),
self.dynamic_index_stride) for a in i
])
return Matrix([[
impl.make_tensor_element_expr(self.local_tensor_proxy,
(a, b), (self.n, self.m),
self.dynamic_index_stride)
for b in j
] for a in i])
if isinstance(i[0], expr.Expr) or isinstance(j[0], expr.Expr):
raise TaichiCompilationError(
"Please turn on ti.init(..., dynamic_index=True) to support indexing with variables!"
)
return Matrix([[self(a, b) for b in j] for a in i])

if self.any_array_access:
return self.any_array_access.subscript(i, j)
Expand Down
37 changes: 36 additions & 1 deletion tests/python/test_simple_matrix_slice.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest

import taichi as ti
from tests import test_utils


@test_utils.test()
@test_utils.test(dynamic_index=True)
def test_slice():
b = 3

Expand All @@ -20,3 +22,36 @@ def foo2() -> ti.types.matrix(2, 2, dtype=ti.i32):
assert (v1 == ti.Vector([0, 2, 4])).all() == 1
m1 = foo2()
assert (m1 == ti.Matrix([[1, 3], [4, 6]])).all() == 1

@ti.kernel
def test_one_row_slice() -> ti.types.matrix(2, 1, dtype=ti.i32):
m = ti.Matrix([[1, 2, 3], [4, 5, 6]])
index = 1
return m[:, index]

@ti.kernel
def test_one_col_slice() -> ti.types.matrix(1, 3, dtype=ti.i32):
m = ti.Matrix([[1, 2, 3], [4, 5, 6]])
mzmzm marked this conversation as resolved.
Show resolved Hide resolved
index = 1
return m[index, :]

r1 = test_one_row_slice()
assert (r1 == ti.Matrix([[2], [5]])).all() == 1
c1 = test_one_col_slice()
assert (c1 == ti.Matrix([[4, 5, 6]])).all() == 1


@test_utils.test(dynamic_index=False)
def test_no_dyn():
@ti.kernel
def test_one_col_slice() -> ti.types.matrix(1, 3, dtype=ti.i32):
m = ti.Matrix([[1, 2, 3], [4, 5, 6]])
index = 1
return m[index, :]

with pytest.raises(
ti.TaichiCompilationError,
match=
"Please turn on ti.init\(..., dynamic_index=True\) to support indexing with variables!"
):
test_one_col_slice()