Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Enable definition of local matrices/vectors #5782

Merged
merged 46 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
4ce08c8
cherrypick Matrix repr support
AD1024 Aug 15, 2022
549a359
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2022
07a4dc1
matrix assign
AD1024 Aug 15, 2022
d2264f4
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 15, 2022
efca3f0
move checks to caller side
AD1024 Aug 15, 2022
82c9413
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2022
38ec750
use ==
AD1024 Aug 15, 2022
13159fd
merge and format
AD1024 Aug 15, 2022
9c91103
refine impl
AD1024 Aug 17, 2022
28f3e0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2022
bf719a3
no long in use
AD1024 Aug 17, 2022
72e8f26
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 17, 2022
65199ea
add some comments
AD1024 Aug 17, 2022
cbf1ea8
get rid of always-true condition
AD1024 Aug 23, 2022
4c8d6b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2022
1a9df8c
save
AD1024 Aug 23, 2022
834699e
some fixes for print and matrix expr
AD1024 Aug 23, 2022
23d7bf7
fix codegen alloca size
AD1024 Aug 23, 2022
6a8a8cb
unsupport empty matrix
AD1024 Aug 24, 2022
08926ef
only check and cast elements
AD1024 Aug 24, 2022
1ae02aa
fmt Vectors to one line
AD1024 Aug 24, 2022
17412e8
lift duplicate part
AD1024 Aug 24, 2022
5421fe1
clean-up
AD1024 Aug 24, 2022
b9fd3a9
clean-up cse code
AD1024 Aug 24, 2022
43a456a
breaks ci; keep as original impl
AD1024 Aug 24, 2022
78ad14a
handle alloca
AD1024 Aug 24, 2022
40825f4
move checks to front
AD1024 Aug 24, 2022
88d01b6
Merge branch 'master' into matrix-repr
AD1024 Aug 24, 2022
f395d2a
reuse code
AD1024 Aug 24, 2022
2a6a8e6
Revert "clean-up cse code"
AD1024 Aug 24, 2022
7f8ca37
clean up together
AD1024 Aug 24, 2022
988abb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2022
93b3a03
also checks for tlctx
AD1024 Aug 24, 2022
13d2efd
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 24, 2022
b2e101a
format
AD1024 Aug 24, 2022
6fcf070
fix codegen: allocate pointer to vector
AD1024 Aug 24, 2022
f43d2a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2022
c613318
check real matrix when allocating memory
AD1024 Aug 24, 2022
41b6641
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 24, 2022
9fc758d
format and fix tc for variable holding matrix expression
AD1024 Aug 24, 2022
f635b7c
refactor: change to `make_local_matrix` which returns only an Expr; p…
AD1024 Aug 25, 2022
f82dc25
get rid of duplicated check
AD1024 Aug 25, 2022
bd68c23
save changes
AD1024 Aug 25, 2022
5d00c98
format
AD1024 Aug 25, 2022
3451397
also rename cxx part
AD1024 Aug 25, 2022
5ed0e93
Apply suggestions from code review
strongoier Aug 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.field import Field
from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl,
from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl,
_TiScopeMatrixImpl)
from taichi.lang.snode import append
from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
Expand Down Expand Up @@ -488,6 +488,12 @@ def build_Call(ctx, node):
node.ptr = impl.ti_format(*args, **keywords)
return node.ptr

if (isinstance(node.func, ast.Attribute) and
(func == Matrix or func == Vector)
) and impl.current_cfg().real_matrix and in_taichi_scope():
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
node.ptr = matrix.make_matrix(*args, **keywords)
return node.ptr

if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
return node.ptr

Expand Down
8 changes: 8 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def expr_init_local_tensor(shape, element_type, elements):
shape, element_type, elements)


@taichi_scope
def expr_init_matrix(shape, element_type, elements):
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
return get_runtime().prog.current_ast_builder().expr_alloca_matrix(
shape, element_type, elements)


@taichi_scope
def expr_init_shared_array(shape, element_type):
return get_runtime().prog.current_ast_builder().expr_alloca_shared_array(
Expand All @@ -48,6 +54,8 @@ def expr_init(rhs):
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
return Matrix(*rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, Matrix):
if current_cfg().real_matrix:
return rhs
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
return Matrix(rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, SharedArray):
return rhs
Expand Down
14 changes: 14 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ def prop_setter(instance, value):
return cls


def make_matrix(arr, dt=None):
cast = (lambda x: ops_mod.cast(x, dt)) if dt else (
lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x))
if len(arr) == 0:
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
return impl.expr_init(impl.expr_init_matrix([0], dt, []))
if not isinstance(arr[0], Iterable):
return impl.expr_init(
impl.expr_init_matrix([len(arr)], dt,
[cast(elt).ptr for elt in arr]))
return impl.expr_init(
impl.expr_init_matrix([len(arr), len(arr[0])], dt,
[cast(elt).ptr for row in arr for elt in row]))


class _MatrixBaseImpl:
def __init__(self, m, n, entries):
self.m = m
Expand Down
6 changes: 5 additions & 1 deletion taichi/analysis/data_source_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
return external_func->arg_stmts;
} else if (auto ref = load_stmt->cast<ReferenceStmt>()) {
return {ref->var};
} else if (auto matrix_init = load_stmt->cast<MatrixInitStmt>()) {
return matrix_init->values;
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
} else if (auto ptr_offset = load_stmt->cast<PtrOffsetStmt>()) {
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
return {ptr_offset->origin};
} else {
return std::vector<Stmt *>();
}
Expand All @@ -59,7 +63,7 @@ Stmt *get_store_data(Stmt *store_stmt) {

std::vector<Stmt *> get_store_destination(Stmt *store_stmt) {
// If store_stmt provides some data sources, return the pointers of the data.
if (store_stmt->is<AllocaStmt>() && !store_stmt->ret_type->is<TensorType>()) {
if (store_stmt->is<AllocaStmt>()) {
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
// The statement itself provides a data source (const [0]).
return std::vector<Stmt *>(1, store_stmt);
} else if (auto local_store = store_stmt->cast<LocalStoreStmt>()) {
Expand Down
8 changes: 8 additions & 0 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->indices.exprs);
}

void visit(MatrixExpression *expr) override {
emit(ExprOpCode::MatrixExpression);
emit(expr->dt);
for (auto elt : expr->elements) {
emit(elt);
}
}

void visit(StrideExpression *expr) override {
emit(ExprOpCode::StrideExpression);
emit(expr->var);
Expand Down
1 change: 1 addition & 0 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
serializer(config->demote_no_access_mesh_fors);
serializer(config->experimental_auto_mesh_local);
serializer(config->auto_mesh_local_default_occupacy);
serializer(config->real_matrix);
serializer.finalize();

return serializer.data;
Expand Down
18 changes: 18 additions & 0 deletions taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,24 @@ class IRNodeComparator : public IRVisitor {
basic_check(stmt);
}

void visit(MatrixInitStmt *stmt) override {
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
basic_check(stmt);
if (!same)
return;
auto o = other_node_->as<MatrixInitStmt>();
if (stmt->values.size() != o->values.size()) {
same = false;
return;
}
for (int i = 0; i < stmt->values.size(); ++i) {
other_node_ = o->values[i];
stmt->values[i]->accept(this);
other_node_ = o;
if (!same)
return;
}
}

void visit(IfStmt *stmt) override {
basic_check(stmt);
if (!same)
Expand Down
37 changes: 30 additions & 7 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) {
void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
if (stmt->ret_type->is<TensorType>()) {
auto tensor_type = stmt->ret_type->cast<TensorType>();
auto type = tlctx->get_data_type(tensor_type->get_element_type());
auto type = tlctx->get_data_type(tensor_type);
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
auto array_size = tlctx->get_constant(tensor_type->get_num_elements());
// Return type is [array_size x type]*.
if (stmt->is_shared) {
Expand Down Expand Up @@ -688,6 +688,11 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) {
return llvm::Type::getDoubleTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
return llvm::Type::getHalfTy(*llvm_context);
} else if (dt->is<TensorType>()) {
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
auto tensor_type = dt->cast<TensorType>();
auto element_type = llvm_type(tensor_type->get_element_type());
return llvm::VectorType::get(element_type, tensor_type->get_num_elements(),
false);
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
} else {
TI_NOT_IMPLEMENTED;
}
Expand Down Expand Up @@ -800,12 +805,20 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
if (std::holds_alternative<Stmt *>(content)) {
auto arg_stmt = std::get<Stmt *>(content);
auto value = llvm_val[arg_stmt];
if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) ||
arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16))
value = builder->CreateFPExt(value,
tlctx->get_data_type(PrimitiveType::f64));
args.push_back(value);
formats += data_type_format(arg_stmt->ret_type);
if (arg_stmt->ret_type->is<TensorType>()) {
auto dtype = arg_stmt->ret_type->cast<TensorType>();
for (int i = 0; i < dtype->get_num_elements(); ++i) {
args.push_back(builder->CreateExtractElement(value, i));
}
formats += data_type_format(arg_stmt->ret_type);
} else {
if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) ||
arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16))
value = builder->CreateFPExt(
value, tlctx->get_data_type(PrimitiveType::f64));
args.push_back(value);
formats += data_type_format(arg_stmt->ret_type);
}
} else {
auto arg_str = std::get<std::string>(content);
auto value = builder->CreateGlobalStringPtr(arg_str, "content_string");
Expand Down Expand Up @@ -2515,6 +2528,16 @@ void TaskCodeGenLLVM::visit(MeshPatchIndexStmt *stmt) {
llvm_val[stmt] = get_arg(2);
}

void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) {
auto type = tlctx->get_data_type(stmt->ret_type->as<TensorType>());
llvm::Value *vec = llvm::UndefValue::get(type);
for (int i = 0; i < stmt->values.size(); ++i) {
auto *elem = llvm_val[stmt->values[i]];
vec = builder->CreateInsertElement(vec, elem, i);
}
llvm_val[stmt] = vec;
}

void TaskCodeGenLLVM::eliminate_unused_functions() {
TaichiLLVMContext::eliminate_unused_functions(
module.get(), [&](std::string func_name) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(ReferenceStmt *stmt) override;

void visit(MatrixInitStmt *stmt) override;

llvm::Value *create_xlogue(std::unique_ptr<Block> &block);

llvm::Value *create_mesh_xlogue(std::unique_ptr<Block> &block);
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/expressions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ PER_EXPRESSION(InternalFuncCallExpression)
PER_EXPRESSION(ExternalTensorExpression)
PER_EXPRESSION(GlobalVariableExpression)
PER_EXPRESSION(IndexExpression)
PER_EXPRESSION(MatrixExpression)
PER_EXPRESSION(StrideExpression)
PER_EXPRESSION(RangeAssumptionExpression)
PER_EXPRESSION(LoopUniqueExpression)
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ PER_STATEMENT(LoopUniqueStmt)
PER_STATEMENT(AssertStmt)
PER_STATEMENT(ExternalFuncCallStmt)
PER_STATEMENT(ExternalTensorShapeAlongAxisStmt)
PER_STATEMENT(MatrixInitStmt)

// Locals with reverse-mode autodiff
PER_STATEMENT(AdStackAllocaStmt)
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
}
}
auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1) {
if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 &&
!stmt->is<PtrOffsetStmt>()) {
strongoier marked this conversation as resolved.
Show resolved Hide resolved
// Identical load elimination
auto load_ptr = load_ptrs.front();
if (!after_lower_access ||
Expand Down
7 changes: 7 additions & 0 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
}
}

void visit(MatrixExpression *expr) override {
emit('[');
emit_vector(expr->elements);
emit(']');
emit(fmt::format(" (dt={})", expr->dt->to_string()));
}

void visit(IndexExpression *expr) override {
expr->var->accept(this);
emit('[');
Expand Down
26 changes: 26 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,25 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
return ctx->push_back<PtrOffsetStmt>(var->stmt, offset_stmt);
}

void MatrixExpression::type_check(CompileConfig *config) {
// TODO: typecheck matrix
for (auto &arg : elements) {
TI_ASSERT_TYPE_CHECKED(arg);
}
}

void MatrixExpression::flatten(FlattenContext *ctx) {
// TODO: implement flatten
TI_ASSERT(this->dt->is<TensorType>());
std::vector<Stmt *> values;
for (auto &elt : elements) {
flatten_rvalue(elt, ctx);
values.push_back(elt->stmt);
}
stmt = ctx->push_back<MatrixInitStmt>(values);
stmt->ret_type = this->dt;
}

bool IndexExpression::is_field() const {
return var.is<GlobalVariableExpression>();
}
Expand Down Expand Up @@ -960,6 +979,13 @@ Expr ASTBuilder::expr_alloca() {
return var;
}

Expr ASTBuilder::expr_alloca_local_matrix(const std::vector<int> &shape,
const std::optional<DataType> &dt,
const std::vector<Expr> &elements) {
auto dtype = dt.value_or(PrimitiveType::unknown);
return Expr(std::make_shared<MatrixExpression>(elements, shape, dtype));
}

Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
const DataType &element_type,
const ExprGroup &elements) {
Expand Down
23 changes: 23 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,26 @@ class GlobalVariableExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class MatrixExpression : public Expression {
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
public:
std::vector<Expr> elements;
DataType dt;

MatrixExpression(const std::vector<Expr> &elements,
std::vector<int> shape,
DataType element_type)
: elements(elements) {
this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type));
this->ret_type = this->dt;
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
}

void type_check(CompileConfig *config) override;

void flatten(FlattenContext *ctx) override;

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class IndexExpression : public Expression {
public:
// `var` is one of GlobalVariableExpression, ExternalTensorExpression,
Expand Down Expand Up @@ -876,6 +896,9 @@ class ASTBuilder {
const ExprGroup &args,
const ExprGroup &outputs);
Expr expr_alloca();
Expr expr_alloca_local_matrix(const std::vector<int> &shape,
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
const std::optional<DataType> &dt,
const std::vector<Expr> &elements);
Expr expr_alloca_local_tensor(const std::vector<int> &shape,
const DataType &element_type,
const ExprGroup &elements);
Expand Down
12 changes: 12 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1807,5 +1807,17 @@ class MeshPatchIndexStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

class MatrixInitStmt : public Stmt {
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
public:
std::vector<Stmt *> values;

MatrixInitStmt(const std::vector<Stmt *> &values) : values(values) {
TI_STMT_REG_FIELDS;
}

TI_STMT_DEF_FIELDS(ret_type, values);
TI_DEFINE_ACCEPT_AND_CLONE
};

} // namespace lang
} // namespace taichi
8 changes: 8 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ std::string TensorType::to_string() const {
return s;
}

int TensorType::vector_width() const {
int vw = 1;
for (auto dim : shape_) {
vw *= dim;
}
return vw;
}

int Type::vector_width() const {
return 1; // TODO: CPU vectorization
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
}
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class TensorType : public Type {
return shape_;
}

int vector_width() const;
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved

Type *get_compute_type() override {
return this;
}
Expand Down
Loading