Skip to content

Commit

Permalink
[Lang] Fix invalid assertion for matrix values
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 committed Sep 21, 2022
1 parent a4b24d1 commit b37816f
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@
]


def check_matrix(mat):
if isinstance(mat, ti.lang.matrix.Vector):
assert all(mat == 1)
elif isinstance(mat, ti.lang.matrix.Matrix):
for i in range(mat.m):
for j in range(mat.n):
assert (mat[i, j] == 1)
else:
assert False


@test_utils.test(arch=get_host_arch_list())
def test_python_scope_vector_operations():
for ops in vector_operation_types:
Expand Down Expand Up @@ -831,11 +842,11 @@ def func(a: ti.types.ndarray()):
x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
func(x)

assert (x[0] == [[0, 1], [2, 3]])
assert (x[1] == [[1, 2], [3, 4]])
assert (x[2] == [[2, 3], [4, 5]])
assert (x[3] == [[3, 4], [5, 6]])
assert (x[4] == [[4, 5], [6, 7]])
check_matrix(x[0] == [[0, 1], [2, 3]])
check_matrix(x[1] == [[1, 2], [3, 4]])
check_matrix(x[2] == [[2, 3], [4, 5]])
check_matrix(x[3] == [[3, 4], [5, 6]])
check_matrix(x[4] == [[4, 5], [6, 7]])


@test_utils.test(arch=[ti.cuda, ti.cpu],
Expand All @@ -853,8 +864,8 @@ def func(a: ti.types.ndarray()):
x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
func(x)

assert (x[3] == [[1, 2], [3, 4]])
assert (x[4] == [[2, 3], [4, 5]])
check_matrix(x[3] == [[1, 2], [3, 4]])
check_matrix(x[4] == [[2, 3], [4, 5]])


@test_utils.test(arch=[ti.cuda, ti.cpu],
Expand All @@ -872,8 +883,10 @@ def func(a: ti.types.ndarray()):
x = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
func(x)

assert (x[0] == [[0., 1.], [2., 3.]])
assert (x[1] == [[3., 4.], [5., 6.]])
assert (x[2] == [[-0., -1.], [-2., -3.]])
assert (x[3] == [[20.08553696, 54.59814835], [148.41316223, 403.42880249]])
assert (x[4] == [[4.48168898, 7.38905621], [12.18249416, 20.08553696]])
check_matrix(x[0] == [[0., 1.], [2., 3.]])
check_matrix(x[1] == [[3., 4.], [5., 6.]])
check_matrix(x[2] == [[-0., -1.], [-2., -3.]])
check_matrix(x[3] < [[20.086, 54.60], [148.42, 403.43]])
check_matrix(x[3] > [[20.085, 54.59], [148.41, 403.42]])
check_matrix(x[4] < [[4.49, 7.39], [12.19, 20.09]])
check_matrix(x[4] > [[4.48, 7.38], [12.18, 20.08]])

0 comments on commit b37816f

Please sign in to comment.