[Lang] Indexing for new local matrix implementation (#5783)
Related issue = #5478 
A part of PR #5551 


Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yi Xu <xy_xuyi@foxmail.com>
3 people authored Sep 14, 2022
1 parent e8d2a54 commit 6389f26
Showing 20 changed files with 255 additions and 63 deletions.
15 changes: 9 additions & 6 deletions python/taichi/_funcs.py
@@ -3,7 +3,7 @@
 from taichi.lang import impl, matrix, ops
 from taichi.lang.impl import expr_init, get_runtime, grouped, static
 from taichi.lang.kernel_impl import func, pyfunc
-from taichi.lang.matrix import Matrix, Vector
+from taichi.lang.matrix import Matrix, Vector, is_vector
 from taichi.types import f32, f64
 from taichi.types.annotations import template

@@ -59,6 +59,9 @@ def _matrix_transpose(mat):
     Returns:
         Transpose of the input matrix.
     """
+    if static(is_vector(mat)):
+        # Convert to row vector
+        return matrix.Matrix([[mat(i) for i in range(mat.n)]])
     return matrix.Matrix([[mat(i, j) for i in range(mat.n)]
                           for j in range(mat.m)],
                          ndim=mat.ndim)
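With this branch in place, transposing a 1-D vector yields a 1 x n row matrix instead of failing on a missing second index. A minimal sketch of the resulting behavior (assumes a default ti.init(); not part of this diff):

    import taichi as ti
    ti.init()

    @ti.kernel
    def demo():
        v = ti.Vector([1.0, 2.0, 3.0])
        row = v.transpose()  # dispatches into the is_vector branch above
        print(row)           # [[1.0, 2.0, 3.0]]

    demo()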
@@ -79,21 +82,21 @@ def _matrix_cross2d(self, other):
 
 
 @pyfunc
-def _matrix_outer_product(self, other):
-    """Perform the outer product with the input Vector (1-D Matrix).
+def _vector_outer_product(self, other):
+    """Perform the outer product with the input Vector.
     Args:
-        other (:class:`~taichi.lang.matrix.Matrix`): The input Vector (1-D Matrix) to perform the outer product.
+        other (:class:`~taichi.lang.matrix.Vector`): The input Vector to perform the outer product.
     Returns:
         :class:`~taichi.lang.matrix.Matrix`: The outer product result (Matrix) of the two Vectors.
     """
     impl.static(
-        impl.static_assert(self.m == 1,
+        impl.static_assert(self.m == 1 and isinstance(self, Vector),
                            "lhs for outer_product is not a vector"))
     impl.static(
-        impl.static_assert(other.m == 1,
+        impl.static_assert(other.m == 1 and isinstance(other, Vector),
                            "rhs for outer_product is not a vector"))
     return matrix.Matrix([[self[i] * other[j] for j in range(other.n)]
                           for i in range(self.n)])
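Since both operands must now be true vectors, a usage sketch matching the assertions above (assumed kernel, not from this diff):

    @ti.kernel
    def outer() -> ti.f32:
        a = ti.Vector([1.0, 2.0])
        b = ti.Vector([3.0, 4.0])
        m = a.outer_product(b)  # 2x2 matrix: m[i, j] == a[i] * b[j]
        return m[0, 1]          # 4.0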
12 changes: 10 additions & 2 deletions python/taichi/_kernels.py
@@ -1,3 +1,5 @@
+from typing import Iterable
+
 from taichi._lib.utils import get_os_name
 from taichi.lang import ops
 from taichi.lang._ndrange import ndrange
@@ -241,9 +243,15 @@ def fill_matrix(mat: template(), vals: template()):
     for p in static(range(mat.n)):
         for q in static(range(mat.m)):
             if static(mat[I].ndim == 2):
-                mat[I][p, q] = vals[p][q]
+                if static(isinstance(vals[p], Iterable)):
+                    mat[I][p, q] = vals[p][q]
+                else:
+                    mat[I][p, q] = vals[p]
             else:
-                mat[I][p] = vals[p][q]
+                if static(isinstance(vals[p], Iterable)):
+                    mat[I][p] = vals[p][q]
+                else:
+                    mat[I][p] = vals[p]
 
 
 @kernel
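The Iterable check lets fill_matrix accept both per-row sequences and plain scalars per component. A python-scope sketch that routes through this kernel (assumes ti.init(); not part of this diff):

    f = ti.Vector.field(3, ti.f32, shape=4)
    f.fill([1.0, 2.0, 3.0])             # each vals[p] is a plain number -> scalar branch
    f.fill(ti.Vector([1.0, 2.0, 3.0]))  # each vals[p] is a 1-tuple -> Iterable branch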
2 changes: 1 addition & 1 deletion python/taichi/lang/_ndrange.py
@@ -144,7 +144,7 @@ def __init__(self, r):
 
     def __iter__(self):
         for ind in self.r:
-            yield _IntermediateMatrix(len(ind), 1, list(ind))
+            yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
 
 
 __all__ = ['ndrange']
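Grouped ndrange indices are therefore tagged as 1-D. A sketch of the affected loop form (standard API, example values assumed):

    @ti.kernel
    def iterate():
        for I in ti.grouped(ti.ndrange(2, 3)):
            print(I)  # I is a 2-component vector index with ndim == 1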
7 changes: 5 additions & 2 deletions python/taichi/lang/ast/ast_transformer.py
@@ -670,10 +670,13 @@ def build_Return(ctx, node):
                     ti_ops.cast(expr.Expr(node.value.ptr),
                                 ctx.func.return_type).ptr))
         elif isinstance(ctx.func.return_type, MatrixType):
+            item_iter = iter(node.value.ptr.to_list())\
+                if isinstance(node.value.ptr, Vector) or node.value.ptr.ndim == 1\
+                else itertools.chain.from_iterable(node.value.ptr.to_list())
             ctx.ast_builder.create_kernel_exprgroup_return(
                 expr.make_expr_group([
-                    ti_ops.cast(exp, ctx.func.return_type.dtype) for exp in
-                    itertools.chain.from_iterable(node.value.ptr.to_list())
+                    ti_ops.cast(exp, ctx.func.return_type.dtype)
+                    for exp in item_iter
                 ]))
         else:
             raise TaichiSyntaxError(
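Because to_list() is already flat for 1-D values, vector returns no longer go through chain.from_iterable. A sketch of a kernel exercising this path (assumed annotations, not from this diff):

    vec3 = ti.types.vector(3, ti.f32)

    @ti.kernel
    def unit_x() -> vec3:
        return ti.Vector([1.0, 0.0, 0.0])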
7 changes: 5 additions & 2 deletions python/taichi/lang/impl.py
@@ -15,8 +15,8 @@
 from taichi.lang.field import Field, ScalarField
 from taichi.lang.kernel_arguments import SparseMatrixProxy
 from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
-                                _IntermediateMatrix, _MatrixFieldElement,
-                                make_matrix)
+                                Vector, _IntermediateMatrix,
+                                _MatrixFieldElement, make_matrix)
 from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
                               MeshRelationAccessProxy,
                               MeshReorderedMatrixFieldProxy,
@@ -64,6 +64,9 @@ def expr_init(rhs):
             entries = [[rhs(i, j) for j in range(rhs.m)]
                        for i in range(rhs.n)]
             return make_matrix(entries)
+        if isinstance(rhs, Vector) or getattr(rhs, "ndim", None) == 1:
+            # _IntermediateMatrix may reach here
+            return Vector(rhs.to_list(), ndim=rhs.ndim)
         return Matrix(rhs.to_list(), ndim=rhs.ndim)
     if isinstance(rhs, SharedArray):
         return rhs
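expr_init is what materializes a right-hand side into a local variable, so 1-D values now stay vectors through assignment. A sketch (assumed kernel, not from this diff):

    @ti.kernel
    def local_vec():
        v = ti.Vector([1, 2, 3])  # materialized by expr_init as a Vector
        w = v * 2                 # element-wise ops keep ndim == 1
        print(w)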
5 changes: 4 additions & 1 deletion python/taichi/lang/kernel_arguments.py
@@ -7,7 +7,7 @@
 from taichi.lang.any_array import AnyArray
 from taichi.lang.enums import Layout
 from taichi.lang.expr import Expr
-from taichi.lang.matrix import Matrix, MatrixType
+from taichi.lang.matrix import Matrix, MatrixType, Vector, VectorType
 from taichi.lang.util import cook_dtype
 from taichi.types.primitive_types import RefType, f32, u64
@@ -58,6 +58,9 @@ def decl_scalar_arg(dtype):
 
 
 def decl_matrix_arg(matrixtype):
+    if isinstance(matrixtype, VectorType):
+        return Vector(
+            [decl_scalar_arg(matrixtype.dtype) for _ in range(matrixtype.n)])
     return Matrix(
         [[decl_scalar_arg(matrixtype.dtype) for _ in range(matrixtype.m)]
          for _ in range(matrixtype.n)],
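With VectorType arguments declared as true vectors, a kernel can take one directly. A sketch (assumed ti.types annotation and example values):

    @ti.kernel
    def norm_sq(v: ti.types.vector(3, ti.f32)) -> ti.f32:
        return v.dot(v)

    print(norm_sq(ti.Vector([1.0, 2.0, 2.0])))  # 9.0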
45 changes: 34 additions & 11 deletions python/taichi/lang/matrix.py
@@ -113,6 +113,10 @@ def make_matrix(arr, dt=None):
         [expr.Expr(elt).ptr for row in arr for elt in row]))
 
 
+def is_vector(x):
+    return isinstance(x, Vector) or getattr(x, "ndim", None) == 1
+
+
 class _MatrixBaseImpl:
     def __init__(self, m, n, entries):
         self.m = m
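is_vector is an internal helper (not public API); it treats both Vector instances and ndim == 1 intermediates as vectors. Roughly:

    from taichi.lang.matrix import is_vector
    is_vector(ti.Vector([1, 2]))    # True
    is_vector(ti.Matrix([[1, 2]]))  # False: a 1 x 2 matrix has ndim == 2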
@@ -257,8 +261,7 @@ def _subscript(self, is_global_mat, *indices, get_ref=False):
                               is_ref=get_ref)
             return Matrix([[self._subscript(is_global_mat, a, b) for b in j]
                            for a in i],
-                          is_ref=get_ref,
-                          ndim=1)
+                          is_ref=get_ref)
 
         if self.any_array_access:
             return self.any_array_access.subscript(i, j)
@@ -441,7 +444,7 @@ def __init__(self,
             elif isinstance(arr[0], Matrix):
                 raise Exception('cols/rows required when using list of vectors')
             else:
-                is_matrix = isinstance(arr[0], Iterable)
+                is_matrix = isinstance(arr[0], Iterable) and not is_vector(self)
                 initializer = _make_entries_initializer(is_matrix)
             self.ndim = 2 if is_matrix else 1
 
@@ -490,17 +493,26 @@ def __init__(self,
 
     def _element_wise_binary(self, foo, other):
         other = self._broadcast_copy(other)
+        if is_vector(self):
+            return Vector([foo(self(i), other(i)) for i in range(self.n)],
+                          ndim=self.ndim)
         return Matrix([[foo(self(i, j), other(i, j)) for j in range(self.m)]
                        for i in range(self.n)],
                       ndim=self.ndim)
 
     def _broadcast_copy(self, other):
         if isinstance(other, (list, tuple)):
-            other = Matrix(other)
+            if is_vector(self):
+                other = Vector(other, ndim=self.ndim)
+            else:
+                other = Matrix(other, ndim=self.ndim)
         if not isinstance(other, Matrix):
-            other = Matrix([[other for _ in range(self.m)]
-                            for _ in range(self.n)],
-                           ndim=self.ndim)
+            if isinstance(self, Vector):
+                other = Vector([other for _ in range(self.n)])
+            else:
+                other = Matrix([[other for _ in range(self.m)]
+                                for _ in range(self.n)],
+                               ndim=self.ndim)
         assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})"
         return other
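A quick sketch of what the vector-aware broadcast changes in user code (assumed kernel, not from this diff):

    @ti.kernel
    def broadcast():
        v = ti.Vector([1.0, 2.0, 3.0])
        w = v + 1.0  # the scalar is broadcast to a Vector, not an n x 1 Matrix
        print(w)     # [2.0, 3.0, 4.0]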

@@ -645,6 +657,8 @@ def to_list(self):
         This is similar to `numpy.ndarray`'s `flatten` and `ravel` methods,
         the difference is that this function always returns a new list.
         """
+        if is_vector(self):
+            return [self(i) for i in range(self.n)]
         return [[self(i, j) for j in range(self.m)] for i in range(self.n)]
 
     @taichi_scope
@@ -665,6 +679,10 @@ def cast(self, dtype):
         >>> B
         [0.0, 1.0, 2.0]
         """
+        if is_vector(self):
+            # when using _IntermediateMatrix, we can only check `self.ndim`
+            return Vector(
+                [ops_mod.cast(self(i), dtype) for i in range(self.n)])
         return Matrix(
             [[ops_mod.cast(self(i, j), dtype) for j in range(self.m)]
              for i in range(self.n)],
@@ -1421,8 +1439,8 @@ def outer_product(self, other):
             :class:`~taichi.Matrix`: The outer product of the two Vectors.
         """
         from taichi._funcs import \
-            _matrix_outer_product  # pylint: disable=C0415
-        return _matrix_outer_product(self, other)
+            _vector_outer_product  # pylint: disable=C0415
+        return _vector_outer_product(self, other)
 
 
 class Vector(Matrix):
@@ -1600,7 +1618,9 @@ def fill(self, val):
         elif isinstance(val,
                         (list, tuple)) and isinstance(val[0], numbers.Number):
             assert self.m == 1
-            val = tuple([(v, ) for v in val])
+            val = tuple(val)
+        elif is_vector(val) or self.ndim == 1:
+            val = tuple([(val(i), ) for i in range(self.n)])
         elif isinstance(val, Matrix):
             val_tuple = []
             for i in range(val.n):
@@ -1611,7 +1631,8 @@
                 val_tuple.append(row)
             val = tuple(val_tuple)
         assert len(val) == self.n
-        assert len(val[0]) == self.m
+        if self.ndim != 1:
+            assert len(val[0]) == self.m
 
         if in_python_scope():
             from taichi._kernels import fill_matrix  # pylint: disable=C0415
@@ -1724,6 +1745,8 @@ def __getitem__(self, key):
         self._initialize_host_accessors()
         key = self._pad_key(key)
         _host_access = self._host_access(key)
+        if self.ndim == 1:
+            return Vector([_host_access[i] for i in range(self.n)])
         return Matrix([[_host_access[i * self.m + j] for j in range(self.m)]
                        for i in range(self.n)],
                       ndim=self.ndim)
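Host-scope reads of a vector field element now come back as a Vector. A sketch (python scope, after ti.init(); example values assumed):

    f = ti.Vector.field(3, ti.f32, shape=2)
    v = f[0]     # a 3-component Vector rather than a 3 x 1 Matrix
    print(v[2])  # plain scalar access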
56 changes: 40 additions & 16 deletions taichi/codegen/cuda/codegen_cuda.cpp
@@ -66,6 +66,26 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
         llvm::Type::getInt8PtrTy(*llvm_context)));
   }
 
+  std::tuple<llvm::Value *, llvm::Type *> create_value_and_type(
+      llvm::Value *value,
+      DataType dt) {
+    auto value_type = tlctx->get_data_type(dt);
+    if (dt->is_primitive(PrimitiveTypeID::f32) ||
+        dt->is_primitive(PrimitiveTypeID::f16)) {
+      value_type = tlctx->get_data_type(PrimitiveType::f64);
+      value = builder->CreateFPExt(value, value_type);
+    }
+    if (dt->is_primitive(PrimitiveTypeID::i8)) {
+      value_type = tlctx->get_data_type(PrimitiveType::i16);
+      value = builder->CreateSExt(value, value_type);
+    }
+    if (dt->is_primitive(PrimitiveTypeID::u8)) {
+      value_type = tlctx->get_data_type(PrimitiveType::u16);
+      value = builder->CreateZExt(value, value_type);
+    }
+    return std::make_tuple(value, value_type);
+  }
+
   void visit(PrintStmt *stmt) override {
     TI_ASSERT_INFO(stmt->contents.size() < 32,
                    "CUDA `print()` doesn't support more than 32 entries");
@@ -74,31 +94,33 @@
     std::vector<llvm::Value *> values;
 
     std::string formats;
+    size_t num_contents = 0;
     for (auto const &content : stmt->contents) {
       if (std::holds_alternative<Stmt *>(content)) {
         auto arg_stmt = std::get<Stmt *>(content);
 
         formats += data_type_format(arg_stmt->ret_type);
 
-        auto value_type = tlctx->get_data_type(arg_stmt->ret_type);
         auto value = llvm_val[arg_stmt];
-        if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) ||
-            arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
-          value_type = tlctx->get_data_type(PrimitiveType::f64);
-          value = builder->CreateFPExt(value, value_type);
-        }
-        if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::i8)) {
-          value_type = tlctx->get_data_type(PrimitiveType::i16);
-          value = builder->CreateSExt(value, value_type);
-        }
-        if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::u8)) {
-          value_type = tlctx->get_data_type(PrimitiveType::u16);
-          value = builder->CreateZExt(value, value_type);
+        if (arg_stmt->ret_type->is<TensorType>()) {
+          auto dtype = arg_stmt->ret_type->cast<TensorType>();
+          num_contents += dtype->get_num_elements();
+          auto elem_type = dtype->get_element_type();
+          for (int i = 0; i < dtype->get_num_elements(); ++i) {
+            auto elem_value = builder->CreateExtractElement(value, i);
+            auto [casted_value, elem_value_type] =
+                create_value_and_type(elem_value, elem_type);
+            types.push_back(elem_value_type);
+            values.push_back(casted_value);
+          }
+        } else {
+          num_contents++;
+          auto [val, dtype] = create_value_and_type(value, arg_stmt->ret_type);
+          types.push_back(dtype);
+          values.push_back(val);
         }
-
-        types.push_back(value_type);
-        values.push_back(value);
       } else {
+        num_contents += 1;
         auto arg_str = std::get<std::string>(content);
 
         auto value = builder->CreateGlobalStringPtr(arg_str, "content_string");
@@ -110,6 +132,8 @@
         values.push_back(value);
         formats += "%s";
       }
+      TI_ASSERT_INFO(num_contents < 32,
+                     "CUDA `print()` doesn't support more than 32 entries");
     }
 
     llvm_val[stmt] = create_print(formats, types, values);
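In user terms, this codegen change makes print() of a local vector work on CUDA by extracting and widening each element into its own printf entry. A sketch (requires a CUDA-capable device; otherwise choose another backend):

    ti.init(arch=ti.cuda)

    @ti.kernel
    def show():
        v = ti.Vector([1.0, 2.0, 3.0])
        print(v)  # lowered element by element, counted against the 32-entry limit

    show()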
(Diffs for the remaining 12 changed files were not loaded.)