-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[type] Add basic implementations of VectorType and PointerType #1948
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,17 +50,17 @@ FunctionCreationGuard::FunctionCreationGuard( | |
// emit into loop body function | ||
mb->func = body; | ||
|
||
allocas = BasicBlock::Create(*mb->llvm_context, "allocs", body); | ||
allocas = llvm::BasicBlock::Create(*mb->llvm_context, "allocs", body); | ||
old_entry = mb->entry_block; | ||
mb->entry_block = allocas; | ||
|
||
entry = BasicBlock::Create(*mb->llvm_context, "entry", mb->func); | ||
entry = llvm::BasicBlock::Create(*mb->llvm_context, "entry", mb->func); | ||
|
||
ip = mb->builder->saveIP(); | ||
mb->builder->SetInsertPoint(entry); | ||
|
||
auto body_bb = | ||
BasicBlock::Create(*mb->llvm_context, "function_body", mb->func); | ||
llvm::BasicBlock::Create(*mb->llvm_context, "function_body", mb->func); | ||
mb->builder->CreateBr(body_bb); | ||
mb->builder->SetInsertPoint(body_bb); | ||
} | ||
|
@@ -336,8 +336,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { | |
llvm_val[stmt] = builder->CreateBitCast( | ||
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); | ||
} else if (op == UnaryOpType::rsqrt) { | ||
llvm::Function *sqrt_fn = Intrinsic::getDeclaration( | ||
module.get(), Intrinsic::sqrt, input->getType()); | ||
llvm::Function *sqrt_fn = llvm::Intrinsic::getDeclaration( | ||
module.get(), llvm::Intrinsic::sqrt, input->getType()); | ||
auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt"); | ||
llvm_val[stmt] = builder->CreateFDiv( | ||
tlctx->get_constant(stmt->ret_type.data_type, 1.0), intermediate); | ||
|
@@ -607,11 +607,12 @@ void CodeGenLLVM::visit(TernaryOpStmt *stmt) { | |
|
||
void CodeGenLLVM::visit(IfStmt *if_stmt) { | ||
// TODO: take care of vectorized cases | ||
BasicBlock *true_block = | ||
BasicBlock::Create(*llvm_context, "true_block", func); | ||
BasicBlock *false_block = | ||
BasicBlock::Create(*llvm_context, "false_block", func); | ||
BasicBlock *after_if = BasicBlock::Create(*llvm_context, "after_if", func); | ||
llvm::BasicBlock *true_block = | ||
llvm::BasicBlock::Create(*llvm_context, "true_block", func); | ||
llvm::BasicBlock *false_block = | ||
llvm::BasicBlock::Create(*llvm_context, "false_block", func); | ||
llvm::BasicBlock *after_if = | ||
llvm::BasicBlock::Create(*llvm_context, "after_if", func); | ||
builder->CreateCondBr( | ||
builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)), | ||
true_block, false_block); | ||
|
@@ -632,7 +633,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag, | |
DataType dt, | ||
llvm::Value *value) { | ||
TI_ASSERT(arch_use_host_memory(kernel->arch)); | ||
std::vector<Value *> args; | ||
std::vector<llvm::Value *> args; | ||
std::string format = data_type_format(dt); | ||
auto runtime_printf = call("LLVMRuntime_get_host_printf", get_runtime()); | ||
args.push_back(builder->CreateGlobalStringPtr( | ||
|
@@ -647,7 +648,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag, | |
|
||
void CodeGenLLVM::visit(PrintStmt *stmt) { | ||
TI_ASSERT(stmt->width() == 1); | ||
std::vector<Value *> args; | ||
std::vector<llvm::Value *> args; | ||
std::string formats; | ||
for (auto const &content : stmt->contents) { | ||
if (std::holds_alternative<Stmt *>(content)) { | ||
|
@@ -700,6 +701,8 @@ void CodeGenLLVM::visit(ConstStmt *stmt) { | |
} | ||
|
||
void CodeGenLLVM::visit(WhileControlStmt *stmt) { | ||
using namespace llvm; | ||
|
||
BasicBlock *after_break = | ||
BasicBlock::Create(*llvm_context, "after_break", func); | ||
TI_ASSERT(current_while_after_loop); | ||
|
@@ -710,6 +713,7 @@ void CodeGenLLVM::visit(WhileControlStmt *stmt) { | |
} | ||
|
||
void CodeGenLLVM::visit(ContinueStmt *stmt) { | ||
using namespace llvm; | ||
if (stmt->as_return()) { | ||
builder->CreateRetVoid(); | ||
} else { | ||
|
@@ -724,6 +728,7 @@ void CodeGenLLVM::visit(ContinueStmt *stmt) { | |
} | ||
|
||
void CodeGenLLVM::visit(WhileStmt *stmt) { | ||
using namespace llvm; | ||
BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func); | ||
builder->CreateBr(body); | ||
builder->SetInsertPoint(body); | ||
|
@@ -785,6 +790,7 @@ void CodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) { | |
} | ||
|
||
void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { | ||
using namespace llvm; | ||
BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func); | ||
BasicBlock *loop_inc = | ||
BasicBlock::Create(*llvm_context, "for_loop_inc", func); | ||
|
@@ -856,7 +862,8 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { | |
|
||
llvm::Type *dest_ty = nullptr; | ||
if (stmt->is_ptr) { | ||
dest_ty = PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); | ||
dest_ty = | ||
llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); | ||
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); | ||
} else { | ||
dest_ty = tlctx->get_data_type(stmt->ret_type.data_type); | ||
|
@@ -868,7 +875,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { | |
} | ||
|
||
void CodeGenLLVM::visit(KernelReturnStmt *stmt) { | ||
if (stmt->is_ptr) { | ||
if (stmt->ret_type.is_pointer()) { | ||
Comment on lines
-871
to
+878
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the only substantial change in this file. Note that this part of the code is never used yet anyway. |
||
TI_NOT_IMPLEMENTED | ||
} else { | ||
auto intermediate_bits = | ||
|
@@ -907,7 +914,7 @@ void CodeGenLLVM::visit(AssertStmt *stmt) { | |
// TODO: maybe let all asserts in a single offload share a single buffer? | ||
auto arguments = create_entry_block_alloca(argument_buffer_size); | ||
|
||
std::vector<Value *> args; | ||
std::vector<llvm::Value *> args; | ||
args.emplace_back(get_runtime()); | ||
args.emplace_back(llvm_val[stmt->cond]); | ||
args.emplace_back(builder->CreateGlobalStringPtr(stmt->text)); | ||
|
@@ -1131,9 +1138,10 @@ llvm::Value *CodeGenLLVM::call(SNode *snode, | |
|
||
void CodeGenLLVM::visit(GetRootStmt *stmt) { | ||
llvm_val[stmt] = builder->CreateBitCast( | ||
get_root(), PointerType::get(StructCompilerLLVM::get_llvm_node_type( | ||
module.get(), prog->snode_root.get()), | ||
0)); | ||
get_root(), | ||
llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type( | ||
module.get(), prog->snode_root.get()), | ||
0)); | ||
} | ||
|
||
void CodeGenLLVM::visit(BitExtractStmt *stmt) { | ||
|
@@ -1196,11 +1204,11 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { | |
auto ch = create_call( | ||
stmt->output_snode->get_ch_from_parent_func_name(), | ||
{builder->CreateBitCast(llvm_val[stmt->input_ptr], | ||
PointerType::getInt8PtrTy(*llvm_context))}); | ||
llvm::PointerType::getInt8PtrTy(*llvm_context))}); | ||
llvm_val[stmt] = builder->CreateBitCast( | ||
ch, PointerType::get(StructCompilerLLVM::get_llvm_node_type( | ||
module.get(), stmt->output_snode), | ||
0)); | ||
ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type( | ||
module.get(), stmt->output_snode), | ||
0)); | ||
} | ||
|
||
void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { | ||
|
@@ -1247,13 +1255,14 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, | |
|
||
task_function_type = | ||
llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context), | ||
{PointerType::get(context_ty, 0)}, false); | ||
{llvm::PointerType::get(context_ty, 0)}, false); | ||
|
||
auto task_kernel_name = fmt::format("{}_{}_{}{}", kernel_name, task_counter, | ||
stmt->task_name(), suffix); | ||
task_counter += 1; | ||
func = Function::Create(task_function_type, Function::ExternalLinkage, | ||
task_kernel_name, module.get()); | ||
func = llvm::Function::Create(task_function_type, | ||
llvm::Function::ExternalLinkage, | ||
task_kernel_name, module.get()); | ||
|
||
current_task = std::make_unique<OffloadedTask>(this); | ||
current_task->begin(task_kernel_name); | ||
|
@@ -1267,10 +1276,10 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, | |
func->addParamAttr(0, llvm::Attribute::ByVal); | ||
|
||
// entry_block has all the allocas | ||
this->entry_block = BasicBlock::Create(*llvm_context, "entry", func); | ||
this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func); | ||
|
||
// The real function body | ||
func_body_bb = BasicBlock::Create(*llvm_context, "body", func); | ||
func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func); | ||
builder->SetInsertPoint(func_body_bb); | ||
return task_kernel_name; | ||
} | ||
|
@@ -1287,7 +1296,7 @@ void CodeGenLLVM::finalize_offloaded_task_function() { | |
"unoptimized LLVM IR (generic)"); | ||
writer.write(module.get()); | ||
} | ||
TI_ASSERT(!llvm::verifyFunction(*func, &errs())); | ||
TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs())); | ||
// TI_INFO("Kernel function verified."); | ||
} | ||
|
||
|
@@ -1314,6 +1323,7 @@ std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::get_range_for_bounds( | |
} | ||
|
||
void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) { | ||
using namespace llvm; | ||
// TODO: instead of constructing tons of LLVM IR, writing the logic in | ||
// runtime.cpp may be a cleaner solution. See | ||
// CodeGenLLVMCPU::create_offload_range_for as an example. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,8 +10,6 @@ | |
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
using namespace llvm; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing since |
||
|
||
class CodeGenLLVM; | ||
|
||
class OffloadedTask { | ||
|
@@ -57,7 +55,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { | |
IRNode *ir; | ||
Program *prog; | ||
std::string kernel_name; | ||
std::vector<Value *> kernel_args; | ||
std::vector<llvm::Value *> kernel_args; | ||
llvm::Type *context_ty; | ||
llvm::Type *physical_coordinate_ty; | ||
llvm::Value *current_coordinates; | ||
|
@@ -73,7 +71,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { | |
OffloadedStmt *current_offload{nullptr}; | ||
std::unique_ptr<OffloadedTask> current_task; | ||
std::vector<OffloadedTask> offloaded_tasks; | ||
BasicBlock *func_body_bb; | ||
llvm::BasicBlock *func_body_bb; | ||
|
||
std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm; | ||
|
||
|
@@ -131,10 +129,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { | |
|
||
void emit_gc(OffloadedStmt *stmt); | ||
|
||
llvm::Value *create_call(llvm::Value *func, std::vector<Value *> args = {}); | ||
llvm::Value *create_call(llvm::Value *func, | ||
std::vector<llvm::Value *> args = {}); | ||
|
||
llvm::Value *create_call(std::string func_name, | ||
std::vector<Value *> args = {}); | ||
std::vector<llvm::Value *> args = {}); | ||
llvm::Value *call(SNode *snode, | ||
llvm::Value *node_ptr, | ||
const std::string &method, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#include "taichi/ir/type.h" | ||
#include "taichi/program/program.h" | ||
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
// Note: these primitive types should never be freed. They are supposed to live | ||
// together with the process. This is a temporary solution. Later we should | ||
// manage its ownership more systematically. | ||
|
||
// This part doesn't look good, but we will remove it soon anyway. | ||
#define PER_TYPE(x) \ | ||
DataType PrimitiveType::x = \ | ||
DataType(Program::get_type_factory().get_primitive_type( \ | ||
PrimitiveType::primitive_type::x)); | ||
|
||
#include "taichi/inc/data_type.inc.h" | ||
#undef PER_TYPE | ||
|
||
DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) { | ||
} | ||
|
||
DataType PrimitiveType::get(PrimitiveType::primitive_type t) { | ||
if (false) { | ||
} | ||
#define PER_TYPE(x) else if (t == primitive_type::x) return PrimitiveType::x; | ||
#include "taichi/inc/data_type.inc.h" | ||
#undef PER_TYPE | ||
else { | ||
TI_NOT_IMPLEMENTED | ||
} | ||
} | ||
|
||
std::size_t DataType::hash() const { | ||
if (auto primitive = dynamic_cast<const PrimitiveType *>(ptr_)) { | ||
return (std::size_t)primitive->type; | ||
} else { | ||
TI_NOT_IMPLEMENTED | ||
} | ||
} | ||
|
||
std::string PrimitiveType::to_string() const { | ||
return data_type_name(DataType(this)); | ||
} | ||
Comment on lines
+6
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part is completely copy-pasted from |
||
|
||
TLANG_NAMESPACE_END |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes in this file are mostly adding the
llvm::
namespace.