diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index b6962fefe1872..798972c5a4eb5 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -313,4 +313,16 @@ void IRBuilder::create_local_store(AllocaStmt *ptr, Stmt *data) { insert(Stmt::make_typed(ptr, data)); } +GlobalPtrStmt *IRBuilder::create_global_ptr( + SNode *snode, + const std::vector &indices) { + return insert(Stmt::make_typed(snode, indices)); +} + +ExternalPtrStmt *IRBuilder::create_external_ptr( + ArgLoadStmt *ptr, + const std::vector &indices) { + return insert(Stmt::make_typed(ptr, indices)); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index f4793e28edcee..c924eb8c3dd21 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -46,12 +46,13 @@ class IRBuilder { void set_insertion_point_to_false_branch(IfStmt *if_stmt); template void set_insertion_point_to_loop_begin(XxxStmt *loop) { - if constexpr (!std::is_base_of_v) { + using DecayedType = typename std::decay_t; + if constexpr (!std::is_base_of_v) { TI_ERROR("The argument is not a statement."); } - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { set_insertion_point({loop->body.get(), 0}); } else { TI_ERROR("Statement {} is not a loop.", loop->name()); @@ -153,10 +154,42 @@ class IRBuilder { return insert(Stmt::make_typed(std::forward(args)...)); } - // Local variable. + // Local variables. AllocaStmt *create_local_var(DataType dt); LocalLoadStmt *create_local_load(AllocaStmt *ptr); void create_local_store(AllocaStmt *ptr, Stmt *data); + + // Global variables. + GlobalPtrStmt *create_global_ptr(SNode *snode, + const std::vector &indices); + ExternalPtrStmt *create_external_ptr(ArgLoadStmt *ptr, + const std::vector &indices); + template + GlobalLoadStmt *create_global_load(XxxStmt *ptr) { + using DecayedType = typename std::decay_t; + if constexpr (!std::is_base_of_v) { + TI_ERROR("The argument is not a statement."); + } + if constexpr (std::is_same_v || + std::is_same_v) { + return insert(Stmt::make_typed(ptr)); + } else { + TI_ERROR("Statement {} is not a global pointer.", ptr->name()); + } + } + template + void create_global_store(XxxStmt *ptr, Stmt *data) { + using DecayedType = typename std::decay_t; + if constexpr (!std::is_base_of_v) { + TI_ERROR("The argument is not a statement."); + } + if constexpr (std::is_same_v || + std::is_same_v) { + insert(Stmt::make_typed(ptr, data)); + } else { + TI_ERROR("Statement {} is not a global pointer.", ptr->name()); + } + } }; TLANG_NAMESPACE_END diff --git a/tests/cpp_new/ir/ir_builder_test.cpp b/tests/cpp_new/ir/ir_builder_test.cpp index ad095fcd15b0a..543859df99794 100644 --- a/tests/cpp_new/ir/ir_builder_test.cpp +++ b/tests/cpp_new/ir/ir_builder_test.cpp @@ -2,6 +2,7 @@ #include "taichi/ir/ir_builder.h" #include "taichi/ir/statements.h" +#include "taichi/program/program.h" namespace taichi { namespace lang { @@ -53,5 +54,37 @@ TEST(IRBuilder, RangeFor) { EXPECT_EQ(loopc->body->size(), 1); EXPECT_EQ(loopc->body->statements[0].get(), index); } + +TEST(IRBuilder, ExternalPtr) { + auto prog = Program(arch_from_name("x64")); + prog.materialize_layout(); + IRBuilder builder; + const int size = 10; + auto array = std::make_unique(size); + array[0] = 2; + array[2] = 40; + auto *arg = builder.create_arg_load(/*arg_id=*/0, get_data_type(), + /*is_ptr=*/true); + auto *zero = builder.get_int32(0); + auto *one = builder.get_int32(1); + auto *two = builder.get_int32(2); + auto *a1ptr = builder.create_external_ptr(arg, {one}); + builder.create_global_store(a1ptr, one); // a[1] = 1 + auto *a0 = + builder.create_global_load(builder.create_external_ptr(arg, {zero})); + auto *a2ptr = builder.create_external_ptr(arg, {two}); + auto *a2 = builder.create_global_load(a2ptr); + auto *a0plusa2 = builder.create_add(a0, a2); + builder.create_global_store(a2ptr, a0plusa2); // a[2] = a[0] + a[2] + auto block = builder.extract_ir(); + auto ker = std::make_unique(prog, std::move(block)); + ker->insert_arg(get_data_type(), /*is_nparray=*/true); + auto launch_ctx = ker->make_launch_context(); + launch_ctx.set_arg_nparray(0, (uint64)array.get(), size); + (*ker)(launch_ctx); + EXPECT_EQ(array[0], 2); + EXPECT_EQ(array[1], 1); + EXPECT_EQ(array[2], 42); +} } // namespace lang } // namespace taichi