Skip to content

Commit

Permalink
[Lang] MatrixField refactor 4/n: Disallow invalid matrix field defini…
Browse files Browse the repository at this point in the history
…tion (#6074)

Related issue = #5959, #4857

Support for different element types of matrix fields was introduced in
#2135 for quant. As discussed in
#4857 (comment),
the only case we need to support is different element types with **same
compute type**. This PR adds the validity check and removes test cases
which are actually not allowed.

<!--
Thank you for your contribution!

If it is your first time contributing to Taichi, please read our
Contributor Guidelines:
  https://docs.taichi-lang.org/docs/contributor_guide

- Please always prepend your PR title with tags such as [CUDA], [Lang],
[Doc], [Example]. For a complete list of valid PR tags, please check out
https://github.com/taichi-dev/taichi/blob/master/misc/prtags.json.
- Use upper-case tags (e.g., [Metal]) for PRs that change public APIs.
Otherwise, please use lower-case tags (e.g., [metal]).
- More details:
https://docs.taichi-lang.org/docs/contributor_guide#pr-title-format-and-tags

- Please fill in the issue number that this PR relates to.
- If your PR fixes the issue **completely**, use the `close` or `fixes`
prefix so that GitHub automatically closes the issue when the PR is
merged. For example,
    Related issue = close #2345
- If the PR does not belong to any existing issue, free to leave it
blank.
-->

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Sep 16, 2022
1 parent 50c925b commit 8fd7522
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 77 deletions.
13 changes: 13 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,19 @@ class MatrixFieldExpression : public Expression {
MatrixFieldExpression(const std::vector<Expr> &fields,
const std::vector<int> &element_shape)
: fields(fields), element_shape(element_shape) {
for (auto &field : fields) {
TI_ASSERT(field.is<FieldExpression>());
}
TI_ASSERT(!fields.empty());
auto compute_type =
fields[0].cast<FieldExpression>()->dt->get_compute_type();
for (auto &field : fields) {
if (field.cast<FieldExpression>()->dt->get_compute_type() !=
compute_type) {
throw TaichiRuntimeError(
"Member fields of a matrix field must have the same compute type");
}
}
}

void type_check(CompileConfig *config) override {
Expand Down
100 changes: 23 additions & 77 deletions tests/python/test_matrix_different_type.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,14 @@
from pytest import approx
import pytest

import taichi as ti
from tests import test_utils


# TODO: test more matrix operations
@test_utils.test()
def test_vector():
type_list = [ti.f32, ti.i32]

a = ti.Vector.field(len(type_list), dtype=type_list, shape=())
b = ti.Vector.field(len(type_list), dtype=type_list, shape=())
c = ti.Vector.field(len(type_list), dtype=type_list, shape=())

@ti.kernel
def init():
a[None] = [1.0, 3]
b[None] = [2.0, 4]
c[None] = a[None] + b[None]

def verify():
assert isinstance(a[None][0], float)
assert isinstance(a[None][1], int)
assert isinstance(b[None][0], float)
assert isinstance(b[None][1], int)
assert c[None][0] == 3.0
assert c[None][1] == 7

init()
verify()


# TODO: Support different element types of Matrix on opengl
@test_utils.test(require=ti.extension.data64, exclude=ti.opengl)
def test_matrix():
type_list = [[ti.f32, ti.i32], [ti.i64, ti.f32]]
a = ti.Matrix.field(len(type_list),
len(type_list[0]),
dtype=type_list,
shape=())
b = ti.Matrix.field(len(type_list),
len(type_list[0]),
dtype=type_list,
shape=())
c = ti.Matrix.field(len(type_list),
len(type_list[0]),
dtype=type_list,
shape=())

@ti.kernel
def init():
a[None] = [[1.0, 3], [1, 3.0]]
b[None] = [[2.0, 4], [-2, -3.0]]
c[None] = a[None] + b[None]

def verify():
assert isinstance(a[None][0, 0], float)
assert isinstance(a[None][0, 1], int)
assert isinstance(b[None][0, 0], float)
assert isinstance(b[None][0, 1], int)
assert c[None][0, 0] == 3.0
assert c[None][0, 1] == 7
assert c[None][1, 0] == -1
assert c[None][1, 1] == 0.0

init()
verify()


@test_utils.test(require=ti.extension.quant_basic)
def test_quant_type():
qit1 = ti.types.quant.int(bits=10, signed=True)
qfxt1 = ti.types.quant.fixed(bits=10, signed=True, scale=0.1)
qit2 = ti.types.quant.int(bits=22, signed=False)
qfxt2 = ti.types.quant.fixed(bits=22, signed=False, scale=0.1)
type_list = [[qit1, qfxt2], [qfxt1, qit2]]
def test_valid():
qflt = ti.types.quant.float(exp=8, frac=5, signed=True)
qfxt = ti.types.quant.fixed(bits=10, signed=True, scale=0.1)
type_list = [[qflt, qfxt], [qflt, qfxt]]
a = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list)
b = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list)
c = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list)
Expand All @@ -99,15 +33,27 @@ def test_quant_type():

@ti.kernel
def init():
a[0] = [[1, 3.], [2., 1]]
b[0] = [[2, 4.], [-2., 1]]
a[0] = [[1.0, 3.0], [2.0, 1.0]]
b[0] = [[2.0, 4.0], [-2.0, 1.0]]
c[0] = a[0] + b[0]

def verify():
assert c[0][0, 0] == approx(3, 1e-3)
assert c[0][0, 1] == approx(7.0, 1e-3)
assert c[0][1, 0] == approx(0, 1e-3)
assert c[0][1, 1] == approx(2, 1e-3)
assert c[0][0, 0] == pytest.approx(3.0)
assert c[0][0, 1] == pytest.approx(7.0)
assert c[0][1, 0] == pytest.approx(0.0)
assert c[0][1, 1] == pytest.approx(2.0)

init()
verify()


@test_utils.test(require=ti.extension.quant_basic)
def test_invalid():
qit = ti.types.quant.int(bits=10, signed=True)
qfxt = ti.types.quant.fixed(bits=10, signed=True, scale=0.1)
type_list = [qit, qfxt]
with pytest.raises(
RuntimeError,
match=
'Member fields of a matrix field must have the same compute type'):
a = ti.Vector.field(len(type_list), dtype=type_list)

0 comments on commit 8fd7522

Please sign in to comment.