Skip to content

Commit

Permalink
[Lang] Support vector/matrix ndarray arguments in real function
Browse files Browse the repository at this point in the history
ghstack-source-id: 8aad3bdb368cdad5f6627d819bf5bdc86353ffcb
Pull Request resolved: #8231
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Jun 29, 2023
1 parent 9945f44 commit d6e24b7
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def func_call_rvalue(self, key, args):
raise TaichiTypeError(
f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
)
non_template_args.append(args[i].ptr)
non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr)
else:
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args)
Expand Down
5 changes: 5 additions & 0 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->axis);
}

void visit(ExternalTensorBasePtrExpression *expr) override {
emit(ExprOpCode::ExternalTensorBasePtrExpression);
emit(expr->ptr);
}

void visit(FrontendFuncCallStmt *expr) override {
emit(StmtOpCode::FrontendFuncCallStmt);
emit(expr->func);
Expand Down
7 changes: 7 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,13 @@ void TaskCodeGenLLVM::visit(ExternalTensorShapeAlongAxisStmt *stmt) {
{arg_id, TypeFactory::SHAPE_POS_IN_NDARRAY, axis}, /*create_load=*/true);
}

void TaskCodeGenLLVM::visit(ExternalTensorBasePtrStmt *stmt) {
const auto arg_id = stmt->arg_id;
int pos = stmt->is_grad ? TypeFactory::GRAD_PTR_POS_IN_NDARRAY
: TypeFactory::DATA_PTR_POS_IN_NDARRAY;
llvm_val[stmt] = get_struct_arg({arg_id, pos}, /*create_load=*/true);
}

std::string TaskCodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
std::string suffix) {
current_loop_reentry = nullptr;
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(ExternalTensorShapeAlongAxisStmt *stmt) override;

void visit(ExternalTensorBasePtrStmt *stmt) override;

virtual bool kernel_argument_by_val() const {
return false; // on CPU devices just pass in a pointer
}
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/expressions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ PER_EXPRESSION(AtomicOpExpression)
PER_EXPRESSION(SNodeOpExpression)
PER_EXPRESSION(ConstExpression)
PER_EXPRESSION(ExternalTensorShapeAlongAxisExpression)
PER_EXPRESSION(ExternalTensorBasePtrExpression)
PER_EXPRESSION(MeshPatchIndexExpression)
PER_EXPRESSION(MeshRelationAccessExpression)
PER_EXPRESSION(MeshIndexConversionExpression)
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ PER_STATEMENT(LoopUniqueStmt)
PER_STATEMENT(AssertStmt)
PER_STATEMENT(ExternalFuncCallStmt)
PER_STATEMENT(ExternalTensorShapeAlongAxisStmt)
PER_STATEMENT(ExternalTensorBasePtrStmt)
PER_STATEMENT(MatrixInitStmt)

// Locals with reverse-mode autodiff
Expand Down
6 changes: 6 additions & 0 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
emit(", ", expr->axis, ')');
}

void visit(ExternalTensorBasePtrExpression *expr) override {
emit("external_tensor_base_ptr(");
expr->ptr->accept(this);
emit(')');
}

void visit(MeshPatchIndexExpression *expr) override {
emit("mesh_patch_idx()");
}
Expand Down
15 changes: 15 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,21 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void ExternalTensorBasePtrExpression::type_check(const CompileConfig *) {
TI_ASSERT_INFO(ptr.is<ExternalTensorExpression>(),
"Invalid ptr [{}] for ExternalTensorBasePtrExpression",
ExpressionHumanFriendlyPrinter::expr_to_string(ptr));
ret_type = ptr.cast<ExternalTensorExpression>()->dt.get_element_type();
ret_type.set_is_pointer(true);
}

void ExternalTensorBasePtrExpression::flatten(FlattenContext *ctx) {
auto tensor = ptr.cast<ExternalTensorExpression>();
ctx->push_back<ExternalTensorBasePtrStmt>(tensor->arg_id, is_grad);
stmt = ctx->back_stmt();
stmt->ret_type = ret_type;
}

void GetElementExpression::type_check(const CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(src);
auto src_type = src->ret_type;
Expand Down
16 changes: 16 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,22 @@ class ExternalTensorShapeAlongAxisExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class ExternalTensorBasePtrExpression : public Expression {
public:
Expr ptr;
bool is_grad;

explicit ExternalTensorBasePtrExpression(const Expr &ptr, bool is_grad)
: ptr(ptr), is_grad(is_grad) {
}

void type_check(const CompileConfig *config) override;

void flatten(FlattenContext *ctx) override;

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class FrontendFuncCallStmt : public Stmt {
public:
std::optional<Identifier> ident;
Expand Down
5 changes: 5 additions & 0 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ ExternalTensorShapeAlongAxisStmt::ExternalTensorShapeAlongAxisStmt(int axis,
TI_STMT_REG_FIELDS;
}

ExternalTensorBasePtrStmt::ExternalTensorBasePtrStmt(int arg_id, bool is_grad)
: arg_id(arg_id), is_grad(is_grad) {
TI_STMT_REG_FIELDS;
}

LoopUniqueStmt::LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers)
: input(input) {
for (const auto &sn : covers) {
Expand Down
15 changes: 15 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,21 @@ class ExternalTensorShapeAlongAxisStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

class ExternalTensorBasePtrStmt : public Stmt {
public:
int arg_id;
bool is_grad;

ExternalTensorBasePtrStmt(int arg_id, bool is_grad);

bool has_global_side_effect() const override {
return false;
}

TI_STMT_DEF_FIELDS(ret_type, arg_id, is_grad);
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* An assertion.
* If |cond| is false, print the formatted |text| with |args|, and terminate
Expand Down
24 changes: 24 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,30 @@ void export_lang(py::module &m) {
m.def("get_external_tensor_shape_along_axis",
Expr::make<ExternalTensorShapeAlongAxisExpression, const Expr &, int>);

m.def("get_external_tensor_real_func_args", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
auto external_tensor_expr = expr.cast<ExternalTensorExpression>();

std::vector<Expr> args;
for (int i = 0; i < external_tensor_expr->ndim; i++) {
args.push_back(
Expr::make<ExternalTensorShapeAlongAxisExpression>(expr, i));
args.back()->type_check(nullptr);
}

args.push_back(
Expr::make<ExternalTensorBasePtrExpression>(expr, /*is_grad=*/false));
args.back()->type_check(nullptr);

if (external_tensor_expr->needs_grad) {
args.push_back(
Expr::make<ExternalTensorBasePtrExpression>(expr, /*is_grad=*/true));
args.back()->type_check(nullptr);
}

return args;
});

// Mesh related.
m.def("get_relation_size", [](mesh::MeshPtr mesh_ptr, const Expr &mesh_idx,
mesh::MeshElementType to_type) {
Expand Down
5 changes: 5 additions & 0 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,11 @@ class IRPrinter : public IRVisitor {
stmt->type_hint(), stmt->name(), stmt->axis, stmt->arg_id);
}

void visit(ExternalTensorBasePtrStmt *stmt) override {
print("{}{} = external_tensor_base_ptr (arg_id={})", stmt->type_hint(),
stmt->name(), stmt->arg_id);
}

void visit(BitStructStoreStmt *stmt) override {
std::string ch_ids;
std::string values;
Expand Down
15 changes: 15 additions & 0 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,3 +1114,18 @@ def foo(x: ti.types.ndarray(float, ndim=0)) -> ti.f32:
x = ti.ndarray(ti.f32, shape=())
x[None] = 42
assert foo(x) == 42


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_vector_ndarray_arg():
@ti.experimental.real_func
def foo(x: ti.types.ndarray(ndim=1)) -> vec3:
return x[0]

@ti.kernel
def test(x: ti.types.ndarray(ndim=1)) -> vec3:
return foo(x)

x = ti.Vector.ndarray(3, ti.f32, shape=(1))
x[0] = vec3(1, 2, 3)
assert (test(x) == vec3(1, 2, 3)).all()

0 comments on commit d6e24b7

Please sign in to comment.