diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 9d378349b0b09..148a74e228542 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -50,17 +50,17 @@ FunctionCreationGuard::FunctionCreationGuard( // emit into loop body function mb->func = body; - allocas = BasicBlock::Create(*mb->llvm_context, "allocs", body); + allocas = llvm::BasicBlock::Create(*mb->llvm_context, "allocs", body); old_entry = mb->entry_block; mb->entry_block = allocas; - entry = BasicBlock::Create(*mb->llvm_context, "entry", mb->func); + entry = llvm::BasicBlock::Create(*mb->llvm_context, "entry", mb->func); ip = mb->builder->saveIP(); mb->builder->SetInsertPoint(entry); auto body_bb = - BasicBlock::Create(*mb->llvm_context, "function_body", mb->func); + llvm::BasicBlock::Create(*mb->llvm_context, "function_body", mb->func); mb->builder->CreateBr(body_bb); mb->builder->SetInsertPoint(body_bb); } @@ -336,8 +336,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { llvm_val[stmt] = builder->CreateBitCast( llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); } else if (op == UnaryOpType::rsqrt) { - llvm::Function *sqrt_fn = Intrinsic::getDeclaration( - module.get(), Intrinsic::sqrt, input->getType()); + llvm::Function *sqrt_fn = llvm::Intrinsic::getDeclaration( + module.get(), llvm::Intrinsic::sqrt, input->getType()); auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt"); llvm_val[stmt] = builder->CreateFDiv( tlctx->get_constant(stmt->ret_type.data_type, 1.0), intermediate); @@ -607,11 +607,12 @@ void CodeGenLLVM::visit(TernaryOpStmt *stmt) { void CodeGenLLVM::visit(IfStmt *if_stmt) { // TODO: take care of vectorized cases - BasicBlock *true_block = - BasicBlock::Create(*llvm_context, "true_block", func); - BasicBlock *false_block = - BasicBlock::Create(*llvm_context, "false_block", func); - BasicBlock *after_if = BasicBlock::Create(*llvm_context, "after_if", func); + llvm::BasicBlock *true_block = + llvm::BasicBlock::Create(*llvm_context, "true_block", func); + llvm::BasicBlock *false_block = + llvm::BasicBlock::Create(*llvm_context, "false_block", func); + llvm::BasicBlock *after_if = + llvm::BasicBlock::Create(*llvm_context, "after_if", func); builder->CreateCondBr( builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)), true_block, false_block); @@ -632,7 +633,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag, DataType dt, llvm::Value *value) { TI_ASSERT(arch_use_host_memory(kernel->arch)); - std::vector args; + std::vector args; std::string format = data_type_format(dt); auto runtime_printf = call("LLVMRuntime_get_host_printf", get_runtime()); args.push_back(builder->CreateGlobalStringPtr( @@ -647,7 +648,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag, void CodeGenLLVM::visit(PrintStmt *stmt) { TI_ASSERT(stmt->width() == 1); - std::vector args; + std::vector args; std::string formats; for (auto const &content : stmt->contents) { if (std::holds_alternative(content)) { @@ -700,6 +701,8 @@ void CodeGenLLVM::visit(ConstStmt *stmt) { } void CodeGenLLVM::visit(WhileControlStmt *stmt) { + using namespace llvm; + BasicBlock *after_break = BasicBlock::Create(*llvm_context, "after_break", func); TI_ASSERT(current_while_after_loop); @@ -710,6 +713,7 @@ void CodeGenLLVM::visit(WhileControlStmt *stmt) { } void CodeGenLLVM::visit(ContinueStmt *stmt) { + using namespace llvm; if (stmt->as_return()) { builder->CreateRetVoid(); } else { @@ -724,6 +728,7 @@ void CodeGenLLVM::visit(ContinueStmt *stmt) { } void CodeGenLLVM::visit(WhileStmt *stmt) { + using namespace llvm; BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func); builder->CreateBr(body); builder->SetInsertPoint(body); @@ -785,6 +790,7 @@ void CodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) { } void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { + using namespace llvm; BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func); BasicBlock *loop_inc = BasicBlock::Create(*llvm_context, "for_loop_inc", func); @@ -856,7 +862,8 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { llvm::Type *dest_ty = nullptr; if (stmt->is_ptr) { - dest_ty = PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); + dest_ty = + llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); } else { dest_ty = tlctx->get_data_type(stmt->ret_type.data_type); @@ -868,7 +875,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { } void CodeGenLLVM::visit(KernelReturnStmt *stmt) { - if (stmt->is_ptr) { + if (stmt->ret_type.is_pointer()) { TI_NOT_IMPLEMENTED } else { auto intermediate_bits = @@ -907,7 +914,7 @@ void CodeGenLLVM::visit(AssertStmt *stmt) { // TODO: maybe let all asserts in a single offload share a single buffer? auto arguments = create_entry_block_alloca(argument_buffer_size); - std::vector args; + std::vector args; args.emplace_back(get_runtime()); args.emplace_back(llvm_val[stmt->cond]); args.emplace_back(builder->CreateGlobalStringPtr(stmt->text)); @@ -1131,9 +1138,10 @@ llvm::Value *CodeGenLLVM::call(SNode *snode, void CodeGenLLVM::visit(GetRootStmt *stmt) { llvm_val[stmt] = builder->CreateBitCast( - get_root(), PointerType::get(StructCompilerLLVM::get_llvm_node_type( - module.get(), prog->snode_root.get()), - 0)); + get_root(), + llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type( + module.get(), prog->snode_root.get()), + 0)); } void CodeGenLLVM::visit(BitExtractStmt *stmt) { @@ -1196,11 +1204,11 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { auto ch = create_call( stmt->output_snode->get_ch_from_parent_func_name(), {builder->CreateBitCast(llvm_val[stmt->input_ptr], - PointerType::getInt8PtrTy(*llvm_context))}); + llvm::PointerType::getInt8PtrTy(*llvm_context))}); llvm_val[stmt] = builder->CreateBitCast( - ch, PointerType::get(StructCompilerLLVM::get_llvm_node_type( - module.get(), stmt->output_snode), - 0)); + ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type( + module.get(), stmt->output_snode), + 0)); } void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { @@ -1247,13 +1255,14 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, task_function_type = llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context), - {PointerType::get(context_ty, 0)}, false); + {llvm::PointerType::get(context_ty, 0)}, false); auto task_kernel_name = fmt::format("{}_{}_{}{}", kernel_name, task_counter, stmt->task_name(), suffix); task_counter += 1; - func = Function::Create(task_function_type, Function::ExternalLinkage, - task_kernel_name, module.get()); + func = llvm::Function::Create(task_function_type, + llvm::Function::ExternalLinkage, + task_kernel_name, module.get()); current_task = std::make_unique(this); current_task->begin(task_kernel_name); @@ -1267,10 +1276,10 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, func->addParamAttr(0, llvm::Attribute::ByVal); // entry_block has all the allocas - this->entry_block = BasicBlock::Create(*llvm_context, "entry", func); + this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func); // The real function body - func_body_bb = BasicBlock::Create(*llvm_context, "body", func); + func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func); builder->SetInsertPoint(func_body_bb); return task_kernel_name; } @@ -1287,7 +1296,7 @@ void CodeGenLLVM::finalize_offloaded_task_function() { "unoptimized LLVM IR (generic)"); writer.write(module.get()); } - TI_ASSERT(!llvm::verifyFunction(*func, &errs())); + TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs())); // TI_INFO("Kernel function verified."); } @@ -1314,6 +1323,7 @@ std::tuple CodeGenLLVM::get_range_for_bounds( } void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) { + using namespace llvm; // TODO: instead of constructing tons of LLVM IR, writing the logic in // runtime.cpp may be a cleaner solution. See // CodeGenLLVMCPU::create_offload_range_for as an example. diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 943b58dbec3a7..5afb14ac9aeea 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -10,8 +10,6 @@ TLANG_NAMESPACE_BEGIN -using namespace llvm; - class CodeGenLLVM; class OffloadedTask { @@ -57,7 +55,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { IRNode *ir; Program *prog; std::string kernel_name; - std::vector kernel_args; + std::vector kernel_args; llvm::Type *context_ty; llvm::Type *physical_coordinate_ty; llvm::Value *current_coordinates; @@ -73,7 +71,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { OffloadedStmt *current_offload{nullptr}; std::unique_ptr current_task; std::vector offloaded_tasks; - BasicBlock *func_body_bb; + llvm::BasicBlock *func_body_bb; std::unordered_map> loop_vars_llvm; @@ -131,10 +129,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void emit_gc(OffloadedStmt *stmt); - llvm::Value *create_call(llvm::Value *func, std::vector args = {}); + llvm::Value *create_call(llvm::Value *func, + std::vector args = {}); llvm::Value *create_call(std::string func_name, - std::vector args = {}); + std::vector args = {}); llvm::Value *call(SNode *snode, llvm::Value *node_ptr, const std::string &method, diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index cb27783d0a546..1f864055d3f2e 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -172,7 +172,6 @@ Stmt::Stmt() : field_manager(this), fields_registered(false) { instance_id = instance_id_counter++; id = instance_id; erased = false; - is_ptr = false; } Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) { @@ -180,7 +179,6 @@ Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) { instance_id = instance_id_counter++; id = instance_id; erased = stmt.erased; - is_ptr = stmt.is_ptr; tb = stmt.tb; ret_type = stmt.ret_type; } @@ -240,7 +238,7 @@ std::string Stmt::type_hint() const { if (ret_type.data_type == PrimitiveType::unknown) return ""; else - return fmt::format("<{}>{}", ret_type.str(), is_ptr ? "ptr " : " "); + return fmt::format("<{}>", ret_type.str()); } std::string Stmt::type() { diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 12c4d8c2d04a9..be0e8b8b32e5b 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -7,6 +7,7 @@ #include #include #include + #include "taichi/common/core.h" #include "taichi/util/bit.h" #include "taichi/lang_util.h" @@ -530,7 +531,6 @@ class Stmt : public IRNode { bool erased; bool fields_registered; std::string tb; - bool is_ptr; LegacyVectorType ret_type; Stmt(); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 0c6924bb77049..795387d2edf29 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -101,6 +101,7 @@ class UnaryOpStmt : public Stmt { class ArgLoadStmt : public Stmt { public: int arg_id; + bool is_ptr; ArgLoadStmt(int arg_id, DataType dt, bool is_ptr = false) : arg_id(arg_id) { this->ret_type = LegacyVectorType(1, dt); diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp new file mode 100644 index 0000000000000..545a9018fcb82 --- /dev/null +++ b/taichi/ir/type.cpp @@ -0,0 +1,45 @@ +#include "taichi/ir/type.h" +#include "taichi/program/program.h" + +TLANG_NAMESPACE_BEGIN + +// Note: these primitive types should never be freed. They are supposed to live +// together with the process. This is a temporary solution. Later we should +// manage its ownership more systematically. + +// This part doesn't look good, but we will remove it soon anyway. +#define PER_TYPE(x) \ + DataType PrimitiveType::x = \ + DataType(Program::get_type_factory().get_primitive_type( \ + PrimitiveType::primitive_type::x)); + +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE + +DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) { +} + +DataType PrimitiveType::get(PrimitiveType::primitive_type t) { + if (false) { + } +#define PER_TYPE(x) else if (t == primitive_type::x) return PrimitiveType::x; +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE + else { + TI_NOT_IMPLEMENTED + } +} + +std::size_t DataType::hash() const { + if (auto primitive = dynamic_cast(ptr_)) { + return (std::size_t)primitive->type; + } else { + TI_NOT_IMPLEMENTED + } +} + +std::string PrimitiveType::to_string() const { + return data_type_name(DataType(this)); +} + +TLANG_NAMESPACE_END diff --git a/taichi/ir/type.h b/taichi/ir/type.h new file mode 100644 index 0000000000000..31cc19ed963ff --- /dev/null +++ b/taichi/ir/type.h @@ -0,0 +1,118 @@ +#pragma once + +#include "taichi/common/core.h" + +TLANG_NAMESPACE_BEGIN + +class Type { + public: + virtual std::string to_string() const = 0; + virtual ~Type() { + } +}; + +// A "Type" handle. This should be removed later. +class DataType { + public: + DataType(); + + DataType(const Type *ptr) : ptr_(ptr) { + } + + bool operator==(const DataType &o) const { + return ptr_ == o.ptr_; + } + + bool operator!=(const DataType &o) const { + return !(*this == o); + } + + std::size_t hash() const; + + std::string to_string() const { + return ptr_->to_string(); + }; + + // TODO: DataType itself should be a pointer in the future + const Type *get_ptr() const { + return ptr_; + } + + private: + const Type *ptr_; +}; + +class PrimitiveType : public Type { + public: + enum class primitive_type : int { +#define PER_TYPE(x) x, +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE + }; + +#define PER_TYPE(x) static DataType x; +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE + + primitive_type type; + + PrimitiveType(primitive_type type) : type(type) { + } + + std::string to_string() const override; + + static DataType get(primitive_type type); +}; + +class PointerType : public Type { + public: + PointerType(Type *pointee, bool is_bit_pointer) + : pointee_(pointee), is_bit_pointer_(is_bit_pointer) { + } + + Type *get_pointee_type() const { + return pointee_; + } + + auto get_addr_space() const { + return addr_space_; + } + + bool is_bit_pointer() const { + return is_bit_pointer_; + } + + std::string to_string() const override { + return fmt::format("*{}", pointee_->to_string()); + }; + + private: + Type *pointee_{nullptr}; + int addr_space_{0}; // TODO: make this an enum + bool is_bit_pointer_{false}; +}; + +class VectorType : public Type { + public: + VectorType(int num_elements, Type *element) + : num_elements_(num_elements), element_(element) { + } + + Type *get_element_type() const { + return element_; + } + + int get_num_elements() const { + return num_elements_; + } + + std::string to_string() const override { + return fmt::format("[{} x {}]", num_elements_, element_->to_string()); + }; + + private: + int num_elements_{0}; + Type *element_{nullptr}; +}; + +TLANG_NAMESPACE_END diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index bb4755cab7575..645b8f576bcb6 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/type_factory.h" +#include "type_factory.h" TLANG_NAMESPACE_BEGIN @@ -12,4 +13,20 @@ Type *TypeFactory::get_primitive_type(PrimitiveType::primitive_type id) { return primitive_types_[id].get(); } +Type *TypeFactory::get_vector_type(int num_elements, Type *element) { + auto key = std::make_pair(num_elements, element); + if (vector_types_.find(key) == vector_types_.end()) { + vector_types_[key] = std::make_unique(num_elements, element); + } + return vector_types_[key].get(); +} + +Type *TypeFactory::get_pointer_type(Type *element) { + auto key = element; // may need to add is_bit_ptr later + if (pointer_types_.find(key) == pointer_types_.end()) { + pointer_types_[key] = std::make_unique(element, false); + } + return pointer_types_[key].get(); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 8d1a5903b7cdf..8a74b4ddc8e7f 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -1,3 +1,5 @@ +#pragma once + #include "taichi/lang_util.h" #include @@ -8,10 +10,20 @@ class TypeFactory { public: Type *get_primitive_type(PrimitiveType::primitive_type id); + Type *get_vector_type(int num_elements, Type *element); + + Type *get_pointer_type(Type *element); + private: std::unordered_map> primitive_types_; + // TODO: use unordered map + std::map, std::unique_ptr> vector_types_; + + // TODO: is_bit_ptr? + std::map> pointer_types_; + std::mutex mut_; }; diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index 560d356516b34..d6ea762c8164b 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -30,45 +30,6 @@ real get_cpu_frequency() { real default_measurement_time = 1; -// Note: these primitive types should never be freed. They are supposed to live -// together with the process. This is a temporary solution. Later we should -// manage its ownership more systematically. - -// This part doesn't look good, but we will remove it soon anyway. -#define PER_TYPE(x) \ - DataType PrimitiveType::x = \ - DataType(Program::get_type_factory().get_primitive_type( \ - PrimitiveType::primitive_type::x)); - -#include "taichi/inc/data_type.inc.h" -#undef PER_TYPE - -DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) { -} - -DataType PrimitiveType::get(PrimitiveType::primitive_type t) { - if (false) { - } -#define PER_TYPE(x) else if (t == primitive_type::x) return PrimitiveType::x; -#include "taichi/inc/data_type.inc.h" -#undef PER_TYPE - else { - TI_NOT_IMPLEMENTED - } -} - -std::size_t DataType::hash() const { - if (auto primitive = dynamic_cast(ptr_)) { - return (std::size_t)primitive->type; - } else { - TI_NOT_IMPLEMENTED - } -} - -std::string PrimitiveType::to_string() const { - return data_type_name(DataType(this)); -} - real measure_cpe(std::function target, int64 elements_per_call, real time_second) { diff --git a/taichi/lang_util.h b/taichi/lang_util.h index d51a4bc3cff79..106d389cf585c 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -4,6 +4,7 @@ #include "taichi/util/io.h" #include "taichi/common/core.h" #include "taichi/system/profiler.h" +#include "taichi/ir/type.h" TLANG_NAMESPACE_BEGIN @@ -19,66 +20,6 @@ struct Context; using FunctionType = std::function; -class Type { - public: - virtual std::string to_string() const = 0; - virtual ~Type() { - } -}; - -// A "Type" handle. This should be removed later. -class DataType { - public: - DataType(); - - DataType(const Type *ptr) : ptr_(ptr) { - } - - bool operator==(const DataType &o) const { - return ptr_ == o.ptr_; - } - - bool operator!=(const DataType &o) const { - return !(*this == o); - } - - std::size_t hash() const; - - std::string to_string() const { - return ptr_->to_string(); - }; - - // TODO: DataType itself should be a pointer in the future - const Type *get_ptr() const { - return ptr_; - } - - private: - const Type *ptr_; -}; - -class PrimitiveType : public Type { - public: - enum class primitive_type : int { -#define PER_TYPE(x) x, -#include "taichi/inc/data_type.inc.h" -#undef PER_TYPE - }; - -#define PER_TYPE(x) static DataType x; -#include "taichi/inc/data_type.inc.h" -#undef PER_TYPE - - primitive_type type; - - PrimitiveType(primitive_type type) : type(type) { - } - - std::string to_string() const override; - - static DataType get(primitive_type type); -}; - template inline DataType get_data_type() { if (std::is_same()) {