Skip to content

Commit

Permalink
[lang] Add reference type support on real functions (#4889)
Browse files Browse the repository at this point in the history
* wip

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* add test

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test_api

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored May 12, 2022
1 parent 80f20f2 commit acedc0e
Show file tree
Hide file tree
Showing 19 changed files with 158 additions and 17 deletions.
6 changes: 6 additions & 0 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,12 @@ def transform_as_kernel():
arg.arg,
kernel_arguments.decl_matrix_arg(
ctx.func.arguments[i].annotation))
elif isinstance(ctx.func.arguments[i].annotation,
primitive_types.RefType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_scalar_arg(
ctx.func.arguments[i].annotation))
else:
ctx.global_vars[
arg.arg] = kernel_arguments.decl_scalar_arg(
Expand Down
12 changes: 8 additions & 4 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from taichi.lang.expr import Expr
from taichi.lang.matrix import Matrix, MatrixType
from taichi.lang.util import cook_dtype
from taichi.types.primitive_types import u64
from taichi.types.primitive_types import RefType, u64


class KernelArgument:
Expand Down Expand Up @@ -47,9 +47,13 @@ def subscript(self, i, j):


def decl_scalar_arg(dtype):
is_ref = False
if isinstance(dtype, RefType):
is_ref = True
dtype = dtype.tp
dtype = cook_dtype(dtype)
arg_id = impl.get_runtime().prog.decl_arg(dtype, False)
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype))
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))


def decl_matrix_arg(matrixtype):
Expand All @@ -63,8 +67,8 @@ def decl_sparse_matrix(dtype):
ptr_type = cook_dtype(u64)
# Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer
arg_id = impl.get_runtime().prog.decl_arg(ptr_type, False)
return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type),
value_type)
return SparseMatrixProxy(
_ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type)


def decl_ndarray_arg(dtype, dim, element_shape, layout):
Expand Down
6 changes: 5 additions & 1 deletion python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ def func_call_rvalue(self, key, args):
if not isinstance(anno, template):
if id(anno) in primitive_types.type_ids:
non_template_args.append(ops.cast(args[i], anno))
elif isinstance(anno, primitive_types.RefType):
non_template_args.append(
_ti_core.make_reference(args[i].ptr))
else:
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args)
Expand Down Expand Up @@ -302,7 +305,8 @@ def extract_arguments(self):
else:
if not id(annotation
) in primitive_types.type_ids and not isinstance(
annotation, template):
annotation, template) and not isinstance(
annotation, primitive_types.RefType):
raise TaichiSyntaxError(
f'Invalid type annotation (argument {i}) of Taichi function: {annotation}'
)
Expand Down
11 changes: 11 additions & 0 deletions python/taichi/types/primitive_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@

# ----------------------------------------


class RefType:
def __init__(self, tp):
self.tp = tp


def ref(tp):
return RefType(tp)


real_types = [f16, f32, f64, float]
real_type_ids = [id(t) for t in real_types]

Expand Down Expand Up @@ -173,4 +183,5 @@
'u32',
'uint64',
'u64',
'ref',
]
2 changes: 2 additions & 0 deletions taichi/analysis/data_source_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
return std::vector<Stmt *>(1, stack_pop->stack);
} else if (auto external_func = load_stmt->cast<ExternalFuncCallStmt>()) {
return external_func->arg_stmts;
} else if (auto ref = load_stmt->cast<ReferenceStmt>()) {
return {ref->var};
} else {
return std::vector<Stmt *>();
}
Expand Down
11 changes: 9 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,9 @@ llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) {

llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) {
auto intermediate_bits = 0;
if (type.is_pointer()) {
return builder->CreatePtrToInt(val, tlctx->get_data_type<int64>());
}
if (auto cit = type->cast<CustomIntType>()) {
intermediate_bits = data_type_bits(cit->get_compute_type());
} else {
Expand All @@ -1109,8 +1112,8 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {

llvm::Type *dest_ty = nullptr;
if (stmt->is_ptr) {
dest_ty =
llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
dest_ty = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
} else {
llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type);
Expand Down Expand Up @@ -2460,6 +2463,10 @@ llvm::Value *CodeGenLLVM::create_mesh_xlogue(std::unique_ptr<Block> &block) {
return xlogue;
}

void CodeGenLLVM::visit(ReferenceStmt *stmt) {
llvm_val[stmt] = llvm_val[stmt->var];
}

void CodeGenLLVM::visit(FuncCallStmt *stmt) {
if (!func_map.count(stmt->func)) {
auto guard = get_function_creation_guard(
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(MeshPatchIndexStmt *stmt) override;

void visit(ReferenceStmt *stmt) override;

llvm::Value *create_xlogue(std::unique_ptr<Block> &block);

llvm::Value *create_mesh_xlogue(std::unique_ptr<Block> &block);
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/expressions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ PER_EXPRESSION(FuncCallExpression)
PER_EXPRESSION(MeshPatchIndexExpression)
PER_EXPRESSION(MeshRelationAccessExpression)
PER_EXPRESSION(MeshIndexConversionExpression)
PER_EXPRESSION(ReferenceExpression)
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ PER_STATEMENT(FuncCallStmt)
PER_STATEMENT(ReturnStmt)

PER_STATEMENT(ArgLoadStmt)
PER_STATEMENT(ReferenceStmt)
PER_STATEMENT(ExternalPtrStmt)
PER_STATEMENT(PtrOffsetStmt)
PER_STATEMENT(ConstStmt)
Expand Down
6 changes: 6 additions & 0 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
emit(")");
}

void visit(ReferenceExpression *expr) override {
emit("ref(");
expr->var->accept(this);
emit(")");
}

static std::string expr_to_string(Expr &expr) {
std::ostringstream oss;
ExpressionHumanFriendlyPrinter printer(&oss);
Expand Down
27 changes: 21 additions & 6 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void ArgLoadExpression::type_check(CompileConfig *) {
}

void ArgLoadExpression::flatten(FlattenContext *ctx) {
auto arg_load = std::make_unique<ArgLoadStmt>(arg_id, dt);
auto arg_load = std::make_unique<ArgLoadStmt>(arg_id, dt, is_ptr);
ctx->push_back(std::move(arg_load));
stmt = ctx->back_stmt();
}
Expand Down Expand Up @@ -485,17 +485,19 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
op_type = AtomicOpType::add;
}
// expand rhs
auto expr = val;
flatten_rvalue(expr, ctx);
flatten_rvalue(val, ctx);
auto src_val = val->stmt;
if (dest.is<IdExpression>()) { // local variable
// emit local store stmt
auto alloca = ctx->current_block->lookup_var(dest.cast<IdExpression>()->id);
ctx->push_back<AtomicOpStmt>(op_type, alloca, expr->stmt);
ctx->push_back<AtomicOpStmt>(op_type, alloca, src_val);
} else {
TI_ASSERT(dest.is<GlobalPtrExpression>() ||
dest.is<TensorElementExpression>());
dest.is<TensorElementExpression>() ||
(dest.is<ArgLoadExpression>() &&
dest.cast<ArgLoadExpression>()->is_ptr));
flatten_lvalue(dest, ctx);
ctx->push_back<AtomicOpStmt>(op_type, dest->stmt, expr->stmt);
ctx->push_back<AtomicOpStmt>(op_type, dest->stmt, src_val);
}
stmt = ctx->back_stmt();
stmt->tb = tb;
Expand Down Expand Up @@ -625,6 +627,16 @@ void MeshIndexConversionExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void ReferenceExpression::type_check(CompileConfig *) {
ret_type = var->ret_type;
}

void ReferenceExpression::flatten(FlattenContext *ctx) {
flatten_lvalue(var, ctx);
ctx->push_back<ReferenceStmt>(var->stmt);
stmt = ctx->back_stmt();
}

Block *ASTBuilder::current_block() {
if (stack_.empty())
return nullptr;
Expand Down Expand Up @@ -945,6 +957,9 @@ void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) {
else {
TI_NOT_IMPLEMENTED
}
} else if (ptr.is<ArgLoadExpression>() &&
ptr.cast<ArgLoadExpression>()->is_ptr) {
flatten_global_load(ptr, ctx);
}
}

Expand Down
21 changes: 20 additions & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,20 @@ class ArgLoadExpression : public Expression {
public:
int arg_id;
DataType dt;
bool is_ptr;

ArgLoadExpression(int arg_id, DataType dt) : arg_id(arg_id), dt(dt) {
ArgLoadExpression(int arg_id, DataType dt, bool is_ptr = false)
: arg_id(arg_id), dt(dt), is_ptr(is_ptr) {
}

void type_check(CompileConfig *config) override;

void flatten(FlattenContext *ctx) override;

bool is_lvalue() const override {
return is_ptr;
}

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

Expand Down Expand Up @@ -727,6 +733,19 @@ class MeshIndexConversionExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class ReferenceExpression : public Expression {
public:
Expr var;
void type_check(CompileConfig *config) override;

ReferenceExpression(const Expr &expr) : var(expr) {
}

void flatten(FlattenContext *ctx) override;

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class ASTBuilder {
private:
enum LoopState { None, Outermost, Inner };
Expand Down
20 changes: 20 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,26 @@ class FuncCallStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* A reference to a variable.
*/
class ReferenceStmt : public Stmt {
public:
Stmt *var;
bool global_side_effect{false};

ReferenceStmt(Stmt *var) : var(var) {
TI_STMT_REG_FIELDS;
}

bool has_global_side_effect() const override {
return global_side_effect;
}

TI_STMT_DEF_FIELDS(ret_type, var);
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* Exit the kernel or function with a return value.
*/
Expand Down
4 changes: 3 additions & 1 deletion taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,9 @@ void export_lang(py::module &m) {
Stmt::make<FrontendAssignStmt, const Expr &, const Expr &>);

m.def("make_arg_load_expr",
Expr::make<ArgLoadExpression, int, const DataType &>);
Expr::make<ArgLoadExpression, int, const DataType &, bool>);

m.def("make_reference", Expr::make<ReferenceExpression, const Expr &>);

m.def("make_external_tensor_expr",
Expr::make<ExternalTensorExpression, const DataType &, int, int, int,
Expand Down
4 changes: 4 additions & 0 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,10 @@ class IRPrinter : public IRVisitor {
print(")");
}

void visit(ReferenceStmt *stmt) override {
print("{}{} = ref({})", stmt->type_hint(), stmt->name(), stmt->var->name());
}

private:
std::string expr_to_string(Expr &expr) {
return expr_to_string(expr.expr.get());
Expand Down
4 changes: 3 additions & 1 deletion taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,9 @@ class LowerAST : public IRVisitor {
TI_NOT_IMPLEMENTED
}
} else { // global variable
TI_ASSERT(dest.is<GlobalPtrExpression>());
TI_ASSERT(dest.is<GlobalPtrExpression>() ||
(dest.is<ArgLoadExpression>() &&
dest.cast<ArgLoadExpression>()->is_ptr));
flatten_lvalue(dest, &fctx);
fctx.push_back<GlobalStoreStmt>(dest->stmt, expr->stmt);
}
Expand Down
5 changes: 5 additions & 0 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,11 @@ class TypeCheck : public IRVisitor {
void visit(BitStructStoreStmt *stmt) override {
// do nothing
}

void visit(ReferenceStmt *stmt) override {
stmt->ret_type = stmt->var->ret_type;
stmt->ret_type.set_is_pointer(true);
}
};

namespace irpass {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _get_expected_matrix_apis():
'lang', 'length', 'linalg', 'log', 'loop_config', 'math', 'max',
'mesh_local', 'mesh_patch_idx', 'metal', 'min', 'ndarray', 'ndrange',
'no_activate', 'one', 'opengl', 'polar_decompose', 'pow', 'profiler',
'randn', 'random', 'raw_div', 'raw_mod', 'rescale_index', 'reset',
'randn', 'random', 'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset',
'rgb_to_hex', 'root', 'round', 'rsqrt', 'select', 'set_logging_level',
'simt', 'sin', 'solve', 'sparse_matrix_builder', 'sqrt', 'static',
'static_assert', 'static_print', 'stop_grad', 'svd', 'swizzle_generator',
Expand Down
30 changes: 30 additions & 0 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,33 @@ def bar(a: ti.i32) -> ti.i32:

assert bar(10) == 11 * 5
assert bar(200) == 99 * 50


@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True)
def test_ref():
@ti.experimental.real_func
def foo(a: ti.ref(ti.f32)):
a = 7

@ti.kernel
def bar():
a = 5.
foo(a)
assert a == 7

bar()


@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True)
def test_ref_atomic():
@ti.experimental.real_func
def foo(a: ti.ref(ti.f32)):
a += a

@ti.kernel
def bar():
a = 5.
foo(a)
assert a == 10.

bar()

0 comments on commit acedc0e

Please sign in to comment.