[lang] Add reference type support on real functions #4889

Merged · 15 commits · May 12, 2022
6 changes: 6 additions & 0 deletions python/taichi/lang/ast/ast_transformer.py
@@ -486,6 +486,12 @@ def transform_as_kernel():
arg.arg,
kernel_arguments.decl_matrix_arg(
ctx.func.arguments[i].annotation))
elif isinstance(ctx.func.arguments[i].annotation,
primitive_types.RefType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_scalar_arg(
ctx.func.arguments[i].annotation))
else:
ctx.global_vars[
arg.arg] = kernel_arguments.decl_scalar_arg(
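
In short: a RefType-annotated argument goes through decl_scalar_arg like any other scalar, but it is registered with ctx.create_variable so the parameter stays assignable inside the function body, while a plain scalar argument lands in ctx.global_vars. A minimal sketch of the full dispatch this hunk extends (illustrative only; the surrounding transform_as_kernel code is elided):

    # Sketch: choosing a declaration helper per annotation kind.
    annotation = ctx.func.arguments[i].annotation
    if isinstance(annotation, MatrixType):
        ctx.create_variable(arg.arg, kernel_arguments.decl_matrix_arg(annotation))
    elif isinstance(annotation, primitive_types.RefType):
        # By-reference scalar: a writable local bound to the caller's storage.
        ctx.create_variable(arg.arg, kernel_arguments.decl_scalar_arg(annotation))
    else:
        # By-value scalar.
        ctx.global_vars[arg.arg] = kernel_arguments.decl_scalar_arg(annotation)
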
12 changes: 8 additions & 4 deletions python/taichi/lang/kernel_arguments.py
@@ -8,7 +8,7 @@
from taichi.lang.expr import Expr
from taichi.lang.matrix import Matrix, MatrixType
from taichi.lang.util import cook_dtype
-from taichi.types.primitive_types import u64
+from taichi.types.primitive_types import RefType, u64


class KernelArgument:
@@ -47,9 +47,13 @@ def subscript(self, i, j):


def decl_scalar_arg(dtype):
is_ref = False
if isinstance(dtype, RefType):
is_ref = True
dtype = dtype.tp
dtype = cook_dtype(dtype)
arg_id = impl.get_runtime().prog.decl_arg(dtype, False)
-    return Expr(_ti_core.make_arg_load_expr(arg_id, dtype))
+    return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))


def decl_matrix_arg(matrixtype):
@@ -63,8 +67,8 @@ def decl_sparse_matrix(dtype):
ptr_type = cook_dtype(u64)
# Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer
arg_id = impl.get_runtime().prog.decl_arg(ptr_type, False)
-    return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type),
-                             value_type)
+    return SparseMatrixProxy(
+        _ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type)


def decl_ndarray_arg(dtype, dim, element_shape, layout):
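
The unwrap at the top of decl_scalar_arg is the core Python-side change. A standalone sketch of just that logic (_unwrap_ref is a hypothetical name, not part of this diff):

    def _unwrap_ref(dtype):
        # Peel a RefType annotation down to its inner primitive type and
        # report whether the argument must be loaded as a pointer.
        if isinstance(dtype, RefType):
            return dtype.tp, True
        return dtype, False

    # decl_scalar_arg forwards the flag to the C++ side via
    #     _ti_core.make_arg_load_expr(arg_id, dtype, is_ref)
    # non-reference call sites such as decl_sparse_matrix pass False explicitly.
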
6 changes: 5 additions & 1 deletion python/taichi/lang/kernel_impl.py
@@ -241,6 +241,9 @@ def func_call_rvalue(self, key, args):
if not isinstance(anno, template):
if id(anno) in primitive_types.type_ids:
non_template_args.append(ops.cast(args[i], anno))
elif isinstance(anno, primitive_types.RefType):
non_template_args.append(
_ti_core.make_reference(args[i].ptr))
else:
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args)
@@ -299,7 +302,8 @@ def extract_arguments(self):
else:
if not id(annotation
) in primitive_types.type_ids and not isinstance(
-                    annotation, template):
+                    annotation, template) and not isinstance(
+                        annotation, primitive_types.RefType):
raise TaichiSyntaxError(
f'Invalid type annotation (argument {i}) of Taichi function: {annotation}'
)
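
The net effect at a call site: a RefType-annotated argument is passed as a reference to the caller's variable instead of a casted copy. Roughly (an illustrative summary, not code from this diff):

    # Given a real function  foo(a: ti.ref(ti.f32))  and a kernel-side call
    # foo(x), the argument expression becomes
    #     _ti_core.make_reference(x.ptr)
    # i.e. a ReferenceExpression over x's stack slot -- whereas a plain
    # ti.f32 annotation would produce ops.cast(x, ti.f32), a copy by value.
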
11 changes: 11 additions & 0 deletions python/taichi/types/primitive_types.py
@@ -141,6 +141,16 @@

# ----------------------------------------


class RefType:
def __init__(self, tp):
self.tp = tp


def ref(tp):
return RefType(tp)


real_types = [f16, f32, f64, float]
real_type_ids = [id(t) for t in real_types]

@@ -173,4 +183,5 @@
'u32',
'uint64',
'u64',
'ref',
]
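
RefType itself is a thin marker: it only remembers the wrapped type. A quick check of the new public API (assuming RefType is importable from taichi.types.primitive_types, as in this diff):

    import taichi as ti
    from taichi.types.primitive_types import RefType

    anno = ti.ref(ti.f32)
    assert isinstance(anno, RefType)
    assert anno.tp == ti.f32  # the wrapped primitive type
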
2 changes: 2 additions & 0 deletions taichi/analysis/data_source_analysis.cpp
@@ -35,6 +35,8 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
return std::vector<Stmt *>(1, stack_pop->stack);
} else if (auto external_func = load_stmt->cast<ExternalFuncCallStmt>()) {
return external_func->arg_stmts;
} else if (auto ref = load_stmt->cast<ReferenceStmt>()) {
return {ref->var};
} else {
return std::vector<Stmt *>();
}
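
Presumably this keeps dataflow analyses honest: returning {ref->var} from get_load_pointers makes a ReferenceStmt depend on the variable it refers to, so passes that reason about data sources treat the reference as reading that variable rather than as an opaque value (a reading of the change, not a statement from the PR):

    # Sketch of the intent:
    #   $r = ref($x)   -- analyses now see $x as a data source of $r,
    #                     just like the stack/external-func cases above it.
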
11 changes: 9 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
@@ -1085,6 +1085,9 @@ llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) {

llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) {
auto intermediate_bits = 0;
if (type.is_pointer()) {
return builder->CreatePtrToInt(val, tlctx->get_data_type<int64>());
}
if (auto cit = type->cast<CustomIntType>()) {
intermediate_bits = data_type_bits(cit->get_compute_type());
} else {
@@ -1108,8 +1111,8 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {

llvm::Type *dest_ty = nullptr;
if (stmt->is_ptr) {
-    dest_ty =
-        llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
+    dest_ty = llvm::PointerType::get(
+        tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
} else {
llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type);
@@ -2458,6 +2461,10 @@ llvm::Value *CodeGenLLVM::create_mesh_xlogue(std::unique_ptr<Block> &block) {
return xlogue;
}

void CodeGenLLVM::visit(ReferenceStmt *stmt) {
llvm_val[stmt] = llvm_val[stmt->var];
}

void CodeGenLLVM::visit(FuncCallStmt *stmt) {
if (!func_map.count(stmt->func)) {
auto guard = get_function_creation_guard(
2 changes: 2 additions & 0 deletions taichi/codegen/codegen_llvm.h
@@ -368,6 +368,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(MeshPatchIndexStmt *stmt) override;

void visit(ReferenceStmt *stmt) override;

llvm::Value *create_xlogue(std::unique_ptr<Block> &block);

llvm::Value *create_mesh_xlogue(std::unique_ptr<Block> &block);
1 change: 1 addition & 0 deletions taichi/inc/expressions.inc.h
@@ -19,3 +19,4 @@ PER_EXPRESSION(FuncCallExpression)
PER_EXPRESSION(MeshPatchIndexExpression)
PER_EXPRESSION(MeshRelationAccessExpression)
PER_EXPRESSION(MeshIndexConversionExpression)
PER_EXPRESSION(ReferenceExpression)
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
@@ -18,6 +18,7 @@ PER_STATEMENT(FuncCallStmt)
PER_STATEMENT(ReturnStmt)

PER_STATEMENT(ArgLoadStmt)
PER_STATEMENT(ReferenceStmt)
PER_STATEMENT(ExternalPtrStmt)
PER_STATEMENT(PtrOffsetStmt)
PER_STATEMENT(ConstStmt)
6 changes: 6 additions & 0 deletions taichi/ir/expression_printer.h
@@ -216,6 +216,12 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
emit(")");
}

void visit(ReferenceExpression *expr) override {
emit("ref(");
expr->var->accept(this);
emit(")");
}

static std::string expr_to_string(Expr &expr) {
std::ostringstream oss;
ExpressionHumanFriendlyPrinter printer(&oss);
27 changes: 21 additions & 6 deletions taichi/ir/frontend_ir.cpp
@@ -121,7 +121,7 @@ void ArgLoadExpression::type_check(CompileConfig *) {
}

void ArgLoadExpression::flatten(FlattenContext *ctx) {
auto arg_load = std::make_unique<ArgLoadStmt>(arg_id, dt);
auto arg_load = std::make_unique<ArgLoadStmt>(arg_id, dt, is_ptr);
ctx->push_back(std::move(arg_load));
stmt = ctx->back_stmt();
}
@@ -485,17 +485,19 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
op_type = AtomicOpType::add;
}
// expand rhs
-  auto expr = val;
-  flatten_rvalue(expr, ctx);
+  flatten_rvalue(val, ctx);
+  auto src_val = val->stmt;
if (dest.is<IdExpression>()) { // local variable
// emit local store stmt
auto alloca = ctx->current_block->lookup_var(dest.cast<IdExpression>()->id);
-    ctx->push_back<AtomicOpStmt>(op_type, alloca, expr->stmt);
+    ctx->push_back<AtomicOpStmt>(op_type, alloca, src_val);
} else {
TI_ASSERT(dest.is<GlobalPtrExpression>() ||
-              dest.is<TensorElementExpression>());
+              dest.is<TensorElementExpression>() ||
+              (dest.is<ArgLoadExpression>() &&
+               dest.cast<ArgLoadExpression>()->is_ptr));
flatten_lvalue(dest, ctx);
-    ctx->push_back<AtomicOpStmt>(op_type, dest->stmt, expr->stmt);
+    ctx->push_back<AtomicOpStmt>(op_type, dest->stmt, src_val);
}
stmt = ctx->back_stmt();
stmt->tb = tb;
@@ -625,6 +627,16 @@ void MeshIndexConversionExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void ReferenceExpression::type_check(CompileConfig *) {
ret_type = var->ret_type;
}

void ReferenceExpression::flatten(FlattenContext *ctx) {
flatten_lvalue(var, ctx);
ctx->push_back<ReferenceStmt>(var->stmt);
stmt = ctx->back_stmt();
}

Block *ASTBuilder::current_block() {
if (stack_.empty())
return nullptr;
@@ -945,6 +957,9 @@ void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) {
else {
TI_NOT_IMPLEMENTED
}
} else if (ptr.is<ArgLoadExpression>() &&
ptr.cast<ArgLoadExpression>()->is_ptr) {
flatten_global_load(ptr, ctx);
}
}

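
Taken together, the frontend changes give this rough lowering picture for a parameter a: ti.ref(ti.f32) inside a real function (a sketch of the statement sequences, not literal IR):

    # read   `b = a`   -> ArgLoadStmt(is_ptr=True) + global load
    #                     (flatten_rvalue now routes pointer ArgLoadExpressions
    #                     through flatten_global_load)
    # store  `a = 7.0` -> ArgLoadStmt(is_ptr=True) + GlobalStoreStmt
    #                     (lower_ast, below, now accepts a pointer
    #                     ArgLoadExpression as a store destination)
    # rmw    `a += a`  -> AtomicOpStmt whose destination is the pointer
    #                     ArgLoadStmt (the relaxed TI_ASSERT above)
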
21 changes: 20 additions & 1 deletion taichi/ir/frontend_ir.h
@@ -276,14 +276,20 @@ class ArgLoadExpression : public Expression {
public:
int arg_id;
DataType dt;
bool is_ptr;

-  ArgLoadExpression(int arg_id, DataType dt) : arg_id(arg_id), dt(dt) {
+  ArgLoadExpression(int arg_id, DataType dt, bool is_ptr = false)
+      : arg_id(arg_id), dt(dt), is_ptr(is_ptr) {
}

void type_check(CompileConfig *config) override;

void flatten(FlattenContext *ctx) override;

bool is_lvalue() const override {
return is_ptr;
}

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

@@ -727,6 +733,19 @@ class MeshIndexConversionExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class ReferenceExpression : public Expression {
public:
Expr var;
void type_check(CompileConfig *config) override;

ReferenceExpression(const Expr &expr) : var(expr) {
}

void flatten(FlattenContext *ctx) override;

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class ASTBuilder {
private:
enum LoopState { None, Outermost, Inner };
20 changes: 20 additions & 0 deletions taichi/ir/statements.h
@@ -898,6 +898,26 @@ class FuncCallStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* A reference to a variable.
*/
class ReferenceStmt : public Stmt {
public:
Stmt *var;
bool global_side_effect{false};

ReferenceStmt(Stmt *var) : var(var) {
TI_STMT_REG_FIELDS;
}

bool has_global_side_effect() const override {
return global_side_effect;
}

TI_STMT_DEF_FIELDS(ret_type, var);
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* Exit the kernel or function with a return value.
*/
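
A note on the default global_side_effect{false}: taking a reference neither reads nor writes memory by itself, so ReferenceStmt can be treated like an address computation (my reading of the flag, not stated in the PR):

    #   $r = ref($x)   -- side-effect free; safe to drop if $r is unused,
    #                     unlike a store or an atomic on $x.
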
4 changes: 3 additions & 1 deletion taichi/python/export_lang.cpp
@@ -724,7 +724,9 @@ void export_lang(py::module &m) {
Stmt::make<FrontendAssignStmt, const Expr &, const Expr &>);

m.def("make_arg_load_expr",
-        Expr::make<ArgLoadExpression, int, const DataType &>);
+        Expr::make<ArgLoadExpression, int, const DataType &, bool>);

m.def("make_reference", Expr::make<ReferenceExpression, const Expr &>);

m.def("make_external_tensor_expr",
Expr::make<ExternalTensorExpression, const DataType &, int, int, int,
4 changes: 4 additions & 0 deletions taichi/transforms/ir_printer.cpp
@@ -774,6 +774,10 @@
print(")");
}

void visit(ReferenceStmt *stmt) override {
print("{}{} = ref({})", stmt->type_hint(), stmt->name(), stmt->var->name());
}

private:
std::string expr_to_string(Expr &expr) {
return expr_to_string(expr.expr.get());
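
With this hook, a ReferenceStmt shows up in IR dumps roughly as follows (the exact prefix comes from type_hint(), and type_check below marks the result type as a pointer, so treat the formatting as illustrative):

    <*f32> $7 = ref($3)
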
4 changes: 3 additions & 1 deletion taichi/transforms/lower_ast.cpp
@@ -408,7 +408,9 @@ class LowerAST : public IRVisitor {
TI_NOT_IMPLEMENTED
}
} else { // global variable
-      TI_ASSERT(dest.is<GlobalPtrExpression>());
+      TI_ASSERT(dest.is<GlobalPtrExpression>() ||
+                (dest.is<ArgLoadExpression>() &&
+                 dest.cast<ArgLoadExpression>()->is_ptr));
flatten_lvalue(dest, &fctx);
fctx.push_back<GlobalStoreStmt>(dest->stmt, expr->stmt);
}
5 changes: 5 additions & 0 deletions taichi/transforms/type_check.cpp
@@ -534,6 +534,11 @@ class TypeCheck : public IRVisitor {
void visit(BitStructStoreStmt *stmt) override {
// do nothing
}

void visit(ReferenceStmt *stmt) override {
stmt->ret_type = stmt->var->ret_type;
stmt->ret_type.set_is_pointer(true);
}
};

namespace irpass {
2 changes: 1 addition & 1 deletion tests/python/test_api.py
@@ -51,7 +51,7 @@ def _get_expected_matrix_apis():
'lang', 'length', 'linalg', 'log', 'loop_config', 'math', 'max',
'mesh_local', 'mesh_patch_idx', 'metal', 'min', 'ndarray', 'ndrange',
'no_activate', 'one', 'opengl', 'polar_decompose', 'pow', 'profiler',
-    'randn', 'random', 'raw_div', 'raw_mod', 'rescale_index', 'reset',
+    'randn', 'random', 'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset',
'rgb_to_hex', 'root', 'round', 'rsqrt', 'select', 'set_logging_level',
'simt', 'sin', 'solve', 'sparse_matrix_builder', 'sqrt', 'static',
'static_assert', 'static_print', 'stop_grad', 'svd', 'swizzle_generator',
30 changes: 30 additions & 0 deletions tests/python/test_function.py
@@ -304,3 +304,33 @@ def bar(a: ti.i32) -> ti.i32:

assert bar(10) == 11 * 5
assert bar(200) == 99 * 50


@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True)
def test_ref():
@ti.experimental.real_func
def foo(a: ti.ref(ti.f32)):
a = 7

@ti.kernel
def bar():
a = 5.
foo(a)
assert a == 7

bar()


@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True)
def test_ref_atomic():
@ti.experimental.real_func
def foo(a: ti.ref(ti.f32)):
a += a

@ti.kernel
def bar():
a = 5.
foo(a)
assert a == 10.

bar()
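
A self-contained variant that combines both tests, runnable on a build that includes this PR (the atomic add mirrors test_ref_atomic):

    import taichi as ti

    ti.init(arch=ti.cpu, debug=True)

    @ti.experimental.real_func
    def accumulate(x: ti.ref(ti.f32)):
        x += 2.0  # read-modify-write through the reference

    @ti.kernel
    def run():
        v = 1.0
        accumulate(v)
        accumulate(v)
        assert v == 5.0  # 1.0 + 2.0 + 2.0

    run()
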