diff --git a/cmake/TaichiTests.cmake b/cmake/TaichiTests.cmake index 7108d8eadf306..4404a1cdafef9 100644 --- a/cmake/TaichiTests.cmake +++ b/cmake/TaichiTests.cmake @@ -10,8 +10,13 @@ endif() # TODO(#2195): # 1. "cpp" -> "cpp_legacy", "cpp_new" -> "cpp" # 2. Re-implement the legacy CPP tests using googletest -file(GLOB_RECURSE TAICHI_TESTS_SOURCE "tests/cpp/analysis/*.cpp" "tests/cpp/common/*.cpp" "tests/cpp/ir/*.cpp" - "tests/cpp/program/*.cpp" "tests/cpp/transforms/*.cpp") +file(GLOB_RECURSE TAICHI_TESTS_SOURCE + "tests/cpp/analysis/*.cpp" + "tests/cpp/codegen/*.cpp" + "tests/cpp/common/*.cpp" + "tests/cpp/ir/*.cpp" + "tests/cpp/program/*.cpp" + "tests/cpp/transforms/*.cpp") include_directories( ${PROJECT_SOURCE_DIR}, diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 8e951ffd05af8..fcc52e81d5293 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -281,7 +281,7 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir) initialize_context(); context_ty = get_runtime_type("Context"); - physical_coordinate_ty = get_runtime_type("PhysicalCoordinates"); + physical_coordinate_ty = get_runtime_type(kLLVMPhysicalCoordinatesName); kernel_name = kernel->name + "_kernel"; } @@ -1652,7 +1652,7 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) { } } - auto coord_object = RuntimeObject("PhysicalCoordinates", this, + auto coord_object = RuntimeObject(kLLVMPhysicalCoordinatesName, this, builder.get(), new_coordinates); for (int i = 0; i < snode->num_active_indices; i++) { auto j = snode->physical_index_position[i]; diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index e25a6415054fa..3a6cb7526bfd4 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -93,6 +93,7 @@ class IRBuilder { [[nodiscard]] LoopGuard get_loop_guard(XStmt *loop) { return LoopGuard(*this, loop); } + [[nodiscard]] IfGuard get_if_guard(IfStmt *if_stmt, bool true_branch) { return IfGuard(*this, if_stmt, true_branch); } @@ -192,7 +193,7 @@ class IRBuilder { // Print values and strings. Arguments can be Stmt* or std::string. template - PrintStmt *create_print(Args &&...args) { + PrintStmt *create_print(Args &&... args) { return insert(Stmt::make_typed(std::forward(args)...)); } diff --git a/taichi/llvm/llvm_codegen_utils.cpp b/taichi/llvm/llvm_codegen_utils.cpp index 40db084c5413c..158184a682ead 100644 --- a/taichi/llvm/llvm_codegen_utils.cpp +++ b/taichi/llvm/llvm_codegen_utils.cpp @@ -1,6 +1,7 @@ #include "llvm_codegen_utils.h" -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { std::string type_name(llvm::Type *type) { std::string type_name_str; @@ -41,4 +42,5 @@ void check_func_call_signature(llvm::Value *func, } } -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/llvm/llvm_codegen_utils.h b/taichi/llvm/llvm_codegen_utils.h index 76e7b4ff12433..6fc2f88ecdccd 100644 --- a/taichi/llvm/llvm_codegen_utils.h +++ b/taichi/llvm/llvm_codegen_utils.h @@ -36,7 +36,10 @@ #include "llvm_context.h" -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { + +inline constexpr char kLLVMPhysicalCoordinatesName[] = "PhysicalCoordinates"; std::string type_name(llvm::Type *type); @@ -50,11 +53,11 @@ inline bool check_func_call_signature(llvm::Value *func, Args &&... args) { class LLVMModuleBuilder { public: - std::unique_ptr module; - llvm::BasicBlock *entry_block; - std::unique_ptr> builder; - TaichiLLVMContext *tlctx; - llvm::LLVMContext *llvm_context; + std::unique_ptr module{nullptr}; + llvm::BasicBlock *entry_block{nullptr}; + std::unique_ptr> builder{nullptr}; + TaichiLLVMContext *tlctx{nullptr}; + llvm::LLVMContext *llvm_context{nullptr}; LLVMModuleBuilder(std::unique_ptr &&module, TaichiLLVMContext *tlctx) @@ -145,10 +148,10 @@ class LLVMModuleBuilder { class RuntimeObject { public: std::string cls_name; - llvm::Value *ptr; - LLVMModuleBuilder *mb; - llvm::Type *type; - llvm::IRBuilder<> *builder; + llvm::Value *ptr{nullptr}; + LLVMModuleBuilder *mb{nullptr}; + llvm::Type *type{nullptr}; + llvm::IRBuilder<> *builder{nullptr}; RuntimeObject(const std::string &cls_name, LLVMModuleBuilder *mb, @@ -179,6 +182,10 @@ class RuntimeObject { call(fmt::format("set_{}", field), val); } + void set(const std::string &field, llvm::Value *index, llvm::Value *val) { + call(fmt::format("set_{}", field), index, val); + } + template llvm::Value *call(const std::string &func_name, Args &&... args) { auto func = get_func(func_name); @@ -192,4 +199,5 @@ class RuntimeObject { } }; -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/program/program.h b/taichi/program/program.h index b78c606f27db1..d852a32065f45 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -60,8 +60,8 @@ TLANG_NAMESPACE_END namespace std { template <> struct hash { - std::size_t operator()(taichi::lang::JITEvaluatorId const &id) const - noexcept { + std::size_t operator()( + taichi::lang::JITEvaluatorId const &id) const noexcept { return ((std::size_t)id.op | (id.ret.hash() << 8) | (id.lhs.hash() << 16) | (id.rhs.hash() << 24) | ((std::size_t)id.is_binary << 31)) ^ (std::hash{}(id.thread_id) << 32); @@ -84,29 +84,31 @@ class AsyncEngine; class Program { public: using Kernel = taichi::lang::Kernel; - Callable *current_callable; - std::unique_ptr snode_root; // pointer to the data structure. - void *llvm_runtime; + Callable *current_callable{nullptr}; + std::unique_ptr snode_root{nullptr}; // pointer to the data structure. + void *llvm_runtime{nullptr}; CompileConfig config; - std::unique_ptr llvm_context_host, llvm_context_device; - bool sync; // device/host synchronized? - bool finalized; - float64 total_compilation_time; + std::unique_ptr llvm_context_host{nullptr}; + std::unique_ptr llvm_context_device{nullptr}; + bool sync{false}; // device/host synchronized? + bool finalized{false}; + float64 total_compilation_time{0.0}; static std::atomic num_instances; - std::unique_ptr thread_pool; - std::unique_ptr memory_pool; - uint64 *result_buffer; // TODO: move this - void *preallocated_device_buffer; // TODO: move this to memory allocator + std::unique_ptr thread_pool{nullptr}; + std::unique_ptr memory_pool{nullptr}; + uint64 *result_buffer{nullptr}; // TODO: move this + void *preallocated_device_buffer{ + nullptr}; // TODO: move this to memory allocator std::unordered_map snodes; - std::unique_ptr runtime; - std::unique_ptr async_engine; + std::unique_ptr runtime{nullptr}; + std::unique_ptr async_engine{nullptr}; std::vector> kernels; std::vector> functions; std::unordered_map function_map; - std::unique_ptr profiler; + std::unique_ptr profiler{nullptr}; std::unordered_map> jit_evaluator_cache; @@ -119,7 +121,7 @@ class Program { Program() : Program(default_compile_config.arch) { } - Program(Arch arch); + explicit Program(Arch arch); void kernel_profiler_print() { profiler->print(); diff --git a/taichi/struct/struct.cpp b/taichi/struct/struct.cpp index 2b397896eb12a..7b075ce38feb4 100644 --- a/taichi/struct/struct.cpp +++ b/taichi/struct/struct.cpp @@ -8,10 +8,6 @@ TLANG_NAMESPACE_BEGIN -StructCompiler::StructCompiler(Program *prog) : prog(prog) { - root_size = 0; -} - void StructCompiler::collect_snodes(SNode &snode) { snodes.push_back(&snode); for (int ch_id = 0; ch_id < (int)snode.ch.size(); ch_id++) { diff --git a/taichi/struct/struct.h b/taichi/struct/struct.h index 8a3aa64bb822a..fdba07d3e9457 100644 --- a/taichi/struct/struct.h +++ b/taichi/struct/struct.h @@ -11,10 +11,7 @@ class StructCompiler { std::vector stack; std::vector snodes; std::vector ambient_snodes; - std::size_t root_size; - Program *prog; - - explicit StructCompiler(Program *prog); + std::size_t root_size{0}; virtual ~StructCompiler() = default; diff --git a/taichi/struct/struct_llvm.cpp b/taichi/struct/struct_llvm.cpp index e357406a26099..ecad5031e5abe 100644 --- a/taichi/struct/struct_llvm.cpp +++ b/taichi/struct/struct_llvm.cpp @@ -1,25 +1,28 @@ -// Codegen for the hierarchical data structure (LLVM) +#include "taichi/struct/struct_llvm.h" #include "llvm/IR/Verifier.h" #include "llvm/IR/IRBuilder.h" #include "taichi/ir/ir.h" #include "taichi/struct/struct.h" -#include "taichi/struct/struct_llvm.h" #include "taichi/program/program.h" #include "taichi/util/file_sequence_writer.h" -TLANG_NAMESPACE_BEGIN - -using namespace llvm; +namespace taichi { +namespace lang { + +StructCompilerLLVM::StructCompilerLLVM(Arch arch, + const CompileConfig *config, + TaichiLLVMContext *tlctx) + : LLVMModuleBuilder(tlctx->clone_runtime_module(), tlctx), + arch_(arch), + config_(config), + tlctx_(tlctx), + llvm_ctx_(tlctx_->get_this_thread_context()) { +} -StructCompilerLLVM::StructCompilerLLVM(Program *prog, Arch arch) - : StructCompiler(prog), - LLVMModuleBuilder(prog->get_llvm_context(arch)->clone_runtime_module(), - prog->get_llvm_context(arch)), - arch(arch) { - tlctx = prog->get_llvm_context(arch); - llvm_ctx = tlctx->get_this_thread_context(); +StructCompilerLLVM::StructCompilerLLVM(Arch arch, Program *prog) + : StructCompilerLLVM(arch, &(prog->config), prog->get_llvm_context(arch)) { } void StructCompilerLLVM::generate_types(SNode &snode) { @@ -29,8 +32,8 @@ void StructCompilerLLVM::generate_types(SNode &snode) { return; llvm::Type *node_type = nullptr; - auto ctx = llvm_ctx; - TI_ASSERT(ctx == tlctx->get_this_thread_context()); + auto ctx = llvm_ctx_; + TI_ASSERT(ctx == tlctx_->get_this_thread_context()); // create children type that supports forking... @@ -46,20 +49,20 @@ void StructCompilerLLVM::generate_types(SNode &snode) { auto ch_type = llvm::StructType::create(*ctx, ch_types, snode.node_type_name + "_ch"); - snode.cell_size_bytes = tlctx->get_type_size(ch_type); + snode.cell_size_bytes = tlctx_->get_type_size(ch_type); llvm::Type *body_type = nullptr, *aux_type = nullptr; if (type == SNodeType::dense || type == SNodeType::bitmasked) { TI_ASSERT(snode._morton == false); body_type = llvm::ArrayType::get(ch_type, snode.max_num_elements()); if (type == SNodeType::bitmasked) { - aux_type = llvm::ArrayType::get(llvm::Type::getInt32Ty(*llvm_ctx), + aux_type = llvm::ArrayType::get(llvm::Type::getInt32Ty(*llvm_ctx_), (snode.max_num_elements() + 31) / 32); } } else if (type == SNodeType::root) { body_type = ch_type; } else if (type == SNodeType::place) { - body_type = tlctx->get_data_type(snode.dt); + body_type = tlctx_->get_data_type(snode.dt); } else if (type == SNodeType::bit_struct) { // Generate the bit_struct type std::vector ch_types; @@ -78,7 +81,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) { TI_ERROR("Type {} not supported.", ch->dt->to_string()); } component_cit->set_physical_type(snode.physical_type); - if (!arch_is_cpu(arch)) { + if (!arch_is_cpu(arch_)) { TI_ERROR_IF(data_type_bits(snode.physical_type) < 32, "bit_struct physical type must be at least 32 bits on " "non-CPU backends."); @@ -95,14 +98,14 @@ void StructCompilerLLVM::generate_types(SNode &snode) { snode.physical_type, ch_types, ch_offsets); DataType container_primitive_type(snode.physical_type); - body_type = tlctx->get_data_type(container_primitive_type); + body_type = tlctx_->get_data_type(container_primitive_type); } else if (type == SNodeType::bit_array) { // A bit array SNode should have only one child TI_ASSERT(snode.ch.size() == 1); auto &ch = snode.ch[0]; Type *ch_type = ch->dt; ch->dt->as()->set_physical_type(snode.physical_type); - if (!arch_is_cpu(arch)) { + if (!arch_is_cpu(arch_)) { TI_ERROR_IF(data_type_bits(snode.physical_type) <= 16, "bit_array physical type must be at least 32 bits on " "non-CPU backends."); @@ -111,7 +114,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) { snode.physical_type, ch_type, snode.n); DataType container_primitive_type(snode.physical_type); - body_type = tlctx->get_data_type(container_primitive_type); + body_type = tlctx_->get_data_type(container_primitive_type); } else if (type == SNodeType::pointer) { // mutex aux_type = llvm::ArrayType::get(llvm::PointerType::getInt64Ty(*ctx), @@ -162,18 +165,18 @@ void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) { auto coord_type_ptr = llvm::PointerType::get(coord_type, 0); auto ft = llvm::FunctionType::get( - llvm::Type::getVoidTy(*llvm_ctx), - {coord_type_ptr, coord_type_ptr, llvm::Type::getInt32Ty(*llvm_ctx)}, + llvm::Type::getVoidTy(*llvm_ctx_), + {coord_type_ptr, coord_type_ptr, llvm::Type::getInt32Ty(*llvm_ctx_)}, false); auto func = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, snode->refine_coordinates_func_name(), *module); - auto bb = BasicBlock::Create(*llvm_ctx, "entry", func); + auto bb = llvm::BasicBlock::Create(*llvm_ctx_, "entry", func); llvm::IRBuilder<> builder(bb, bb->begin()); - std::vector args; + std::vector args; for (auto &arg : func->args()) { args.push_back(&arg); @@ -184,19 +187,19 @@ void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) { auto l = args[2]; for (int i = 0; i < taichi_max_num_indices; i++) { - auto addition = tlctx->get_constant(0); + auto addition = tlctx_->get_constant(0); if (snode->extractors[i].num_bits) { auto mask = ((1 << snode->extractors[i].num_bits) - 1); addition = builder.CreateAnd( builder.CreateAShr(l, snode->extractors[i].acc_offset), mask); } auto in = call(&builder, "PhysicalCoordinates_get_val", inp_coords, - tlctx->get_constant(i)); + tlctx_->get_constant(i)); in = builder.CreateShl(in, - tlctx->get_constant(snode->extractors[i].num_bits)); + tlctx_->get_constant(snode->extractors[i].num_bits)); auto added = builder.CreateOr(in, addition); call(&builder, "PhysicalCoordinates_set_val", outp_coords, - tlctx->get_constant(i), added); + tlctx_->get_constant(i), added); } builder.CreateRetVoid(); } @@ -220,29 +223,29 @@ void StructCompilerLLVM::generate_child_accessors(SNode &snode) { llvm::PointerType::get(get_llvm_element_type(module.get(), parent), 0); auto ft = - llvm::FunctionType::get(llvm::Type::getInt8PtrTy(*llvm_ctx), - {llvm::Type::getInt8PtrTy(*llvm_ctx)}, false); + llvm::FunctionType::get(llvm::Type::getInt8PtrTy(*llvm_ctx_), + {llvm::Type::getInt8PtrTy(*llvm_ctx_)}, false); auto func = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, snode.get_ch_from_parent_func_name(), *module); - auto bb = BasicBlock::Create(*llvm_ctx, "entry", func); + auto bb = llvm::BasicBlock::Create(*llvm_ctx_, "entry", func); llvm::IRBuilder<> builder(bb, bb->begin()); - std::vector args; + std::vector args; for (auto &arg : func->args()) { args.push_back(&arg); } llvm::Value *ret; - ret = builder.CreateGEP( - builder.CreateBitCast(args[0], inp_type), - {tlctx->get_constant(0), tlctx->get_constant(parent->child_id(&snode))}, - "getch"); + ret = builder.CreateGEP(builder.CreateBitCast(args[0], inp_type), + {tlctx_->get_constant(0), + tlctx_->get_constant(parent->child_id(&snode))}, + "getch"); builder.CreateRet( - builder.CreateBitCast(ret, llvm::Type::getInt8PtrTy(*llvm_ctx))); + builder.CreateBitCast(ret, llvm::Type::getInt8PtrTy(*llvm_ctx_))); } for (auto &ch : snode.ch) { @@ -275,7 +278,7 @@ void StructCompilerLLVM::run(SNode &root, bool host) { generate_child_accessors(root); - if (prog->config.print_struct_llvm_ir) { + if (config_->print_struct_llvm_ir) { static FileSequenceWriter writer("taichi_struct_llvm_ir_{:04d}.ll", "struct LLVM IR"); writer.write(module.get()); @@ -284,9 +287,9 @@ void StructCompilerLLVM::run(SNode &root, bool host) { TI_ASSERT((int)snodes.size() <= taichi_max_num_snodes); auto node_type = get_llvm_node_type(module.get(), &root); - root_size = tlctx->get_data_layout().getTypeAllocSize(node_type); + root_size = tlctx_->get_data_layout().getTypeAllocSize(node_type); - tlctx->set_struct_module(module); + tlctx_->set_struct_module(module); } llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module, @@ -324,7 +327,8 @@ llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module, } std::unique_ptr StructCompiler::make(Program *prog, Arch arch) { - return std::make_unique(prog, arch); + return std::make_unique(arch, prog); } -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/struct/struct_llvm.h b/taichi/struct/struct_llvm.h index 7930f1a6d74d8..3b8159fa9c924 100644 --- a/taichi/struct/struct_llvm.h +++ b/taichi/struct/struct_llvm.h @@ -1,17 +1,18 @@ // Codegen for the hierarchical data structure (LLVM) -#include "taichi/struct/struct.h" #include "taichi/llvm/llvm_codegen_utils.h" +#include "taichi/struct/struct.h" -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { class StructCompilerLLVM : public StructCompiler, public LLVMModuleBuilder { public: - StructCompilerLLVM(Program *prog, Arch arch); + StructCompilerLLVM(Arch arch, + const CompileConfig *config, + TaichiLLVMContext *tlctx); - Arch arch; - TaichiLLVMContext *tlctx; - llvm::LLVMContext *llvm_ctx; + StructCompilerLLVM(Arch arch, Program *prog); void generate_types(SNode &snode) override; @@ -32,6 +33,13 @@ class StructCompilerLLVM : public StructCompiler, public LLVMModuleBuilder { static llvm::Type *get_llvm_aux_type(llvm::Module *module, SNode *snode); static llvm::Type *get_llvm_element_type(llvm::Module *module, SNode *snode); + + private: + Arch arch_; + const CompileConfig *const config_; + TaichiLLVMContext *const tlctx_; + llvm::LLVMContext *const llvm_ctx_; }; -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/tests/cpp/codegen/refine_coordinates_test.cpp b/tests/cpp/codegen/refine_coordinates_test.cpp new file mode 100644 index 0000000000000..c9db938d6d1b7 --- /dev/null +++ b/tests/cpp/codegen/refine_coordinates_test.cpp @@ -0,0 +1,164 @@ +#include "gtest/gtest.h" + +#include + +#include "llvm/IR/Function.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/BasicBlock.h" + +#include "taichi/program/arch.h" +#include "taichi/program/program.h" +#include "taichi/struct/struct_llvm.h" +#include "taichi/ir/snode.h" +#include "taichi/program/compile_config.h" +#include "taichi/llvm/llvm_codegen_utils.h" + +namespace taichi { + +namespace lang { +namespace { + +constexpr char kFuncName[] = "run_refine_coords"; + +class InvokeRefineCoordinatesBuilder : public LLVMModuleBuilder { + public: + // 1st arg: Value of the first parent physical coordiantes + // 2nd arg: The child index + // ret : Value of the first child physical coordinates + using FuncType = int (*)(int, int); + + static FuncType build(const SNode *snode, TaichiLLVMContext *tlctx) { + InvokeRefineCoordinatesBuilder mb{tlctx}; + mb.run_jit(snode); + tlctx->add_module(std::move(mb.module)); + auto *fn = tlctx->lookup_function_pointer(kFuncName); + return reinterpret_cast(fn); + } + + private: + InvokeRefineCoordinatesBuilder(TaichiLLVMContext *tlctx) + : LLVMModuleBuilder(tlctx->clone_struct_module(), tlctx) { + this->llvm_context = this->tlctx->get_this_thread_context(); + this->builder = std::make_unique>(*llvm_context); + } + + void run_jit(const SNode *snode) { + // pseudo code: + // + // int run_refine_coords(int parent_coords_first_comp, int child_index) { + // PhysicalCoordinates parent_coords; + // PhysicalCoordinates child_coords; + // parent_coord.val[0] = parent_coords_first_comp; + // snode_refine_coordinates(&parent_coords, &child_coords, child_index); + // return child_coords.val[0]; + // } + auto *const int32_ty = llvm::Type::getInt32Ty(*llvm_context); + auto *const func_ty = + llvm::FunctionType::get(int32_ty, {int32_ty, int32_ty}, + /*isVarArg=*/false); + auto *const func = llvm::Function::Create( + func_ty, llvm::Function::ExternalLinkage, kFuncName, module.get()); + std::vector args; + for (auto &a : func->args()) { + args.push_back(&a); + } + auto *const parent_coords_first_component = args[0]; + auto *const child_index = args[1]; + + this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func); + builder->SetInsertPoint(entry_block); + + auto *const index0 = tlctx->get_constant(0); + + RuntimeObject parent_coords{kLLVMPhysicalCoordinatesName, this, + builder.get()}; + parent_coords.set("val", index0, parent_coords_first_component); + auto *refine_fn = + get_runtime_function(snode->refine_coordinates_func_name()); + RuntimeObject child_coords{kLLVMPhysicalCoordinatesName, this, + builder.get()}; + builder->CreateCall(refine_fn, + {parent_coords.ptr, child_coords.ptr, child_index}); + auto *ret_val = child_coords.get("val", index0); + builder->CreateRet(ret_val); + + llvm::verifyFunction(*func); + } +}; + +struct BitsRange { + int begin{0}; + int end{0}; + + int extract(int v) const { + const unsigned mask = (1U << (end - begin)) - 1; + return (v >> begin) & mask; + } +}; + +constexpr int kPointerSize = 5; +constexpr int kDenseSize = 7; + +class RefineCoordinatesTest : public ::testing::Test { + protected: + void SetUp() override { + arch_ = host_arch(); + config_.print_kernel_llvm_ir = false; + prog_ = std::make_unique(arch_); + tlctx_ = prog_->llvm_context_host.get(); + + root_snode_ = std::make_unique(/*depth=*/0, /*t=*/SNodeType::root); + const std::vector indices = {Index{0}}; + ptr_snode_ = &(root_snode_->pointer(indices, kPointerSize)); + dense_snode_ = &(ptr_snode_->dense(indices, kDenseSize)); + // Must end with a `place` SNode. + auto &leaf_snode = dense_snode_->insert_children(SNodeType::place); + leaf_snode.dt = PrimitiveType::f32; + + auto sc = std::make_unique(arch_, &config_, tlctx_); + sc->run(*root_snode_, /*host=*/true); + } + + Arch arch_; + CompileConfig config_; + // We shouldn't need a Program instance in this test. Unfortunately, a few + // places depend on the global |current_program|, so we have to. + // ¯\_(ツ)_/¯ + std::unique_ptr prog_{nullptr}; + TaichiLLVMContext *tlctx_{nullptr}; + + std::unique_ptr root_snode_{nullptr}; + SNode *ptr_snode_{nullptr}; + SNode *dense_snode_{nullptr}; +}; + +TEST_F(RefineCoordinatesTest, Basic) { + auto *refine_ptr_fn = + InvokeRefineCoordinatesBuilder::build(ptr_snode_, tlctx_); + auto *refine_dense_fn = + InvokeRefineCoordinatesBuilder::build(dense_snode_, tlctx_); + + const BitsRange dense_bit_range{/*begin=*/0, + /*end=*/dense_snode_->extractors[0].num_bits}; + const BitsRange ptr_bit_range{ + /*begin=*/dense_bit_range.end, + /*end=*/dense_bit_range.end + ptr_snode_->extractors[0].num_bits}; + constexpr int kRootPhyCoord = 0; + for (int i = 0; i < kPointerSize; ++i) { + const int ptr_phy_coord = refine_ptr_fn(kRootPhyCoord, i); + for (int j = 0; j < kDenseSize; ++j) { + const int loop_index = refine_dense_fn(ptr_phy_coord, j); + // TODO: This is basically doing a lower_scalar_ptr() manually. + // We should modularize that function, and use it to generate IRs that + // does the bit extraction procedure. + const int dense_portion = dense_bit_range.extract(loop_index); + const int ptr_portion = ptr_bit_range.extract(loop_index); + EXPECT_EQ(dense_portion, j); + EXPECT_EQ(ptr_portion, i); + } + } +} + +} // namespace +} // namespace lang +} // namespace taichi diff --git a/tests/cpp/struct/fake_struct_compiler.h b/tests/cpp/struct/fake_struct_compiler.h index cdbcc51baf974..0c0784b62aa06 100644 --- a/tests/cpp/struct/fake_struct_compiler.h +++ b/tests/cpp/struct/fake_struct_compiler.h @@ -5,9 +5,6 @@ namespace lang { class FakeStructCompiler : public StructCompiler { public: - FakeStructCompiler() : StructCompiler(/*prog=*/nullptr) { - } - void generate_types(SNode &) override { }