Skip to content

Commit

Permalink
[type] Add basic implementations of VectorType and PointerType (#1948)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu authored Oct 13, 2020
1 parent 2a71e2a commit dcd5d7d
Show file tree
Hide file tree
Showing 11 changed files with 239 additions and 137 deletions.
66 changes: 38 additions & 28 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ FunctionCreationGuard::FunctionCreationGuard(
// emit into loop body function
mb->func = body;

allocas = BasicBlock::Create(*mb->llvm_context, "allocs", body);
allocas = llvm::BasicBlock::Create(*mb->llvm_context, "allocs", body);
old_entry = mb->entry_block;
mb->entry_block = allocas;

entry = BasicBlock::Create(*mb->llvm_context, "entry", mb->func);
entry = llvm::BasicBlock::Create(*mb->llvm_context, "entry", mb->func);

ip = mb->builder->saveIP();
mb->builder->SetInsertPoint(entry);

auto body_bb =
BasicBlock::Create(*mb->llvm_context, "function_body", mb->func);
llvm::BasicBlock::Create(*mb->llvm_context, "function_body", mb->func);
mb->builder->CreateBr(body_bb);
mb->builder->SetInsertPoint(body_bb);
}
Expand Down Expand Up @@ -337,8 +337,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
llvm_val[stmt] = builder->CreateBitCast(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else if (op == UnaryOpType::rsqrt) {
llvm::Function *sqrt_fn = Intrinsic::getDeclaration(
module.get(), Intrinsic::sqrt, input->getType());
llvm::Function *sqrt_fn = llvm::Intrinsic::getDeclaration(
module.get(), llvm::Intrinsic::sqrt, input->getType());
auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt");
llvm_val[stmt] = builder->CreateFDiv(
tlctx->get_constant(stmt->ret_type.data_type, 1.0), intermediate);
Expand Down Expand Up @@ -608,11 +608,12 @@ void CodeGenLLVM::visit(TernaryOpStmt *stmt) {

void CodeGenLLVM::visit(IfStmt *if_stmt) {
// TODO: take care of vectorized cases
BasicBlock *true_block =
BasicBlock::Create(*llvm_context, "true_block", func);
BasicBlock *false_block =
BasicBlock::Create(*llvm_context, "false_block", func);
BasicBlock *after_if = BasicBlock::Create(*llvm_context, "after_if", func);
llvm::BasicBlock *true_block =
llvm::BasicBlock::Create(*llvm_context, "true_block", func);
llvm::BasicBlock *false_block =
llvm::BasicBlock::Create(*llvm_context, "false_block", func);
llvm::BasicBlock *after_if =
llvm::BasicBlock::Create(*llvm_context, "after_if", func);
builder->CreateCondBr(
builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)),
true_block, false_block);
Expand All @@ -633,7 +634,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag,
DataType dt,
llvm::Value *value) {
TI_ASSERT(arch_use_host_memory(kernel->arch));
std::vector<Value *> args;
std::vector<llvm::Value *> args;
std::string format = data_type_format(dt);
auto runtime_printf = call("LLVMRuntime_get_host_printf", get_runtime());
args.push_back(builder->CreateGlobalStringPtr(
Expand All @@ -648,7 +649,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag,

void CodeGenLLVM::visit(PrintStmt *stmt) {
TI_ASSERT(stmt->width() == 1);
std::vector<Value *> args;
std::vector<llvm::Value *> args;
std::string formats;
for (auto const &content : stmt->contents) {
if (std::holds_alternative<Stmt *>(content)) {
Expand Down Expand Up @@ -701,6 +702,8 @@ void CodeGenLLVM::visit(ConstStmt *stmt) {
}

void CodeGenLLVM::visit(WhileControlStmt *stmt) {
using namespace llvm;

BasicBlock *after_break =
BasicBlock::Create(*llvm_context, "after_break", func);
TI_ASSERT(current_while_after_loop);
Expand All @@ -711,6 +714,7 @@ void CodeGenLLVM::visit(WhileControlStmt *stmt) {
}

void CodeGenLLVM::visit(ContinueStmt *stmt) {
using namespace llvm;
if (stmt->as_return()) {
builder->CreateRetVoid();
} else {
Expand All @@ -725,6 +729,7 @@ void CodeGenLLVM::visit(ContinueStmt *stmt) {
}

void CodeGenLLVM::visit(WhileStmt *stmt) {
using namespace llvm;
BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func);
builder->CreateBr(body);
builder->SetInsertPoint(body);
Expand Down Expand Up @@ -786,6 +791,7 @@ void CodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) {
}

void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
using namespace llvm;
BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func);
BasicBlock *loop_inc =
BasicBlock::Create(*llvm_context, "for_loop_inc", func);
Expand Down Expand Up @@ -857,7 +863,8 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {

llvm::Type *dest_ty = nullptr;
if (stmt->is_ptr) {
dest_ty = PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
dest_ty =
llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
} else {
dest_ty = tlctx->get_data_type(stmt->ret_type.data_type);
Expand All @@ -869,7 +876,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
}

void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
if (stmt->is_ptr) {
if (stmt->ret_type.is_pointer()) {
TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits =
Expand Down Expand Up @@ -908,7 +915,7 @@ void CodeGenLLVM::visit(AssertStmt *stmt) {
// TODO: maybe let all asserts in a single offload share a single buffer?
auto arguments = create_entry_block_alloca(argument_buffer_size);

std::vector<Value *> args;
std::vector<llvm::Value *> args;
args.emplace_back(get_runtime());
args.emplace_back(llvm_val[stmt->cond]);
args.emplace_back(builder->CreateGlobalStringPtr(stmt->text));
Expand Down Expand Up @@ -1132,9 +1139,10 @@ llvm::Value *CodeGenLLVM::call(SNode *snode,

void CodeGenLLVM::visit(GetRootStmt *stmt) {
llvm_val[stmt] = builder->CreateBitCast(
get_root(), PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), prog->snode_root.get()),
0));
get_root(),
llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), prog->snode_root.get()),
0));
}

void CodeGenLLVM::visit(BitExtractStmt *stmt) {
Expand Down Expand Up @@ -1197,11 +1205,11 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
auto ch = create_call(
stmt->output_snode->get_ch_from_parent_func_name(),
{builder->CreateBitCast(llvm_val[stmt->input_ptr],
PointerType::getInt8PtrTy(*llvm_context))});
llvm::PointerType::getInt8PtrTy(*llvm_context))});
llvm_val[stmt] = builder->CreateBitCast(
ch, PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
0));
ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
0));
}

void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
Expand Down Expand Up @@ -1248,13 +1256,14 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,

task_function_type =
llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
{PointerType::get(context_ty, 0)}, false);
{llvm::PointerType::get(context_ty, 0)}, false);

auto task_kernel_name = fmt::format("{}_{}_{}{}", kernel_name, task_counter,
stmt->task_name(), suffix);
task_counter += 1;
func = Function::Create(task_function_type, Function::ExternalLinkage,
task_kernel_name, module.get());
func = llvm::Function::Create(task_function_type,
llvm::Function::ExternalLinkage,
task_kernel_name, module.get());

current_task = std::make_unique<OffloadedTask>(this);
current_task->begin(task_kernel_name);
Expand All @@ -1268,10 +1277,10 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
func->addParamAttr(0, llvm::Attribute::ByVal);

// entry_block has all the allocas
this->entry_block = BasicBlock::Create(*llvm_context, "entry", func);
this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);

// The real function body
func_body_bb = BasicBlock::Create(*llvm_context, "body", func);
func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
builder->SetInsertPoint(func_body_bb);
return task_kernel_name;
}
Expand All @@ -1288,7 +1297,7 @@ void CodeGenLLVM::finalize_offloaded_task_function() {
"unoptimized LLVM IR (generic)");
writer.write(module.get());
}
TI_ASSERT(!llvm::verifyFunction(*func, &errs()));
TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
// TI_INFO("Kernel function verified.");
}

Expand All @@ -1315,6 +1324,7 @@ std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::get_range_for_bounds(
}

void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {
using namespace llvm;
// TODO: instead of constructing tons of LLVM IR, writing the logic in
// runtime.cpp may be a cleaner solution. See
// CodeGenLLVMCPU::create_offload_range_for as an example.
Expand Down
11 changes: 5 additions & 6 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

TLANG_NAMESPACE_BEGIN

using namespace llvm;

class CodeGenLLVM;

class OffloadedTask {
Expand Down Expand Up @@ -57,7 +55,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
IRNode *ir;
Program *prog;
std::string kernel_name;
std::vector<Value *> kernel_args;
std::vector<llvm::Value *> kernel_args;
llvm::Type *context_ty;
llvm::Type *physical_coordinate_ty;
llvm::Value *current_coordinates;
Expand All @@ -73,7 +71,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
OffloadedStmt *current_offload{nullptr};
std::unique_ptr<OffloadedTask> current_task;
std::vector<OffloadedTask> offloaded_tasks;
BasicBlock *func_body_bb;
llvm::BasicBlock *func_body_bb;

std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

Expand Down Expand Up @@ -131,10 +129,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void emit_gc(OffloadedStmt *stmt);

llvm::Value *create_call(llvm::Value *func, std::vector<Value *> args = {});
llvm::Value *create_call(llvm::Value *func,
std::vector<llvm::Value *> args = {});

llvm::Value *create_call(std::string func_name,
std::vector<Value *> args = {});
std::vector<llvm::Value *> args = {});
llvm::Value *call(SNode *snode,
llvm::Value *node_ptr,
const std::string &method,
Expand Down
4 changes: 1 addition & 3 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,13 @@ Stmt::Stmt() : field_manager(this), fields_registered(false) {
instance_id = instance_id_counter++;
id = instance_id;
erased = false;
is_ptr = false;
}

Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) {
parent = stmt.parent;
instance_id = instance_id_counter++;
id = instance_id;
erased = stmt.erased;
is_ptr = stmt.is_ptr;
tb = stmt.tb;
ret_type = stmt.ret_type;
}
Expand Down Expand Up @@ -240,7 +238,7 @@ std::string Stmt::type_hint() const {
if (ret_type.data_type == PrimitiveType::unknown)
return "";
else
return fmt::format("<{}>{}", ret_type.str(), is_ptr ? "ptr " : " ");
return fmt::format("<{}>", ret_type.str());
}

std::string Stmt::type() {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <unordered_map>
#include <variant>
#include <tuple>

#include "taichi/common/core.h"
#include "taichi/util/bit.h"
#include "taichi/lang_util.h"
Expand Down Expand Up @@ -530,7 +531,6 @@ class Stmt : public IRNode {
bool erased;
bool fields_registered;
std::string tb;
bool is_ptr;
LegacyVectorType ret_type;

Stmt();
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class UnaryOpStmt : public Stmt {
class ArgLoadStmt : public Stmt {
public:
int arg_id;
bool is_ptr;

ArgLoadStmt(int arg_id, DataType dt, bool is_ptr = false) : arg_id(arg_id) {
this->ret_type = LegacyVectorType(1, dt);
Expand Down
45 changes: 45 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "taichi/ir/type.h"
#include "taichi/program/program.h"

TLANG_NAMESPACE_BEGIN

// Note: these primitive types should never be freed. They are supposed to live
// together with the process. This is a temporary solution. Later we should
// manage its ownership more systematically.

// This part doesn't look good, but we will remove it soon anyway.
#define PER_TYPE(x) \
DataType PrimitiveType::x = \
DataType(Program::get_type_factory().get_primitive_type( \
PrimitiveType::primitive_type::x));

#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) {
}

DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
if (false) {
}
#define PER_TYPE(x) else if (t == primitive_type::x) return PrimitiveType::x;
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
else {
TI_NOT_IMPLEMENTED
}
}

std::size_t DataType::hash() const {
if (auto primitive = dynamic_cast<const PrimitiveType *>(ptr_)) {
return (std::size_t)primitive->type;
} else {
TI_NOT_IMPLEMENTED
}
}

std::string PrimitiveType::to_string() const {
return data_type_name(DataType(this));
}

TLANG_NAMESPACE_END
Loading

0 comments on commit dcd5d7d

Please sign in to comment.