Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] Add basic implementations of VectorType and PointerType #1948

Merged
merged 5 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 38 additions & 28 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ FunctionCreationGuard::FunctionCreationGuard(
// emit into loop body function
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes in this file are mostly adding the llvm:: namespace.

mb->func = body;

allocas = BasicBlock::Create(*mb->llvm_context, "allocs", body);
allocas = llvm::BasicBlock::Create(*mb->llvm_context, "allocs", body);
old_entry = mb->entry_block;
mb->entry_block = allocas;

entry = BasicBlock::Create(*mb->llvm_context, "entry", mb->func);
entry = llvm::BasicBlock::Create(*mb->llvm_context, "entry", mb->func);

ip = mb->builder->saveIP();
mb->builder->SetInsertPoint(entry);

auto body_bb =
BasicBlock::Create(*mb->llvm_context, "function_body", mb->func);
llvm::BasicBlock::Create(*mb->llvm_context, "function_body", mb->func);
mb->builder->CreateBr(body_bb);
mb->builder->SetInsertPoint(body_bb);
}
Expand Down Expand Up @@ -336,8 +336,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
llvm_val[stmt] = builder->CreateBitCast(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else if (op == UnaryOpType::rsqrt) {
llvm::Function *sqrt_fn = Intrinsic::getDeclaration(
module.get(), Intrinsic::sqrt, input->getType());
llvm::Function *sqrt_fn = llvm::Intrinsic::getDeclaration(
module.get(), llvm::Intrinsic::sqrt, input->getType());
auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt");
llvm_val[stmt] = builder->CreateFDiv(
tlctx->get_constant(stmt->ret_type.data_type, 1.0), intermediate);
Expand Down Expand Up @@ -607,11 +607,12 @@ void CodeGenLLVM::visit(TernaryOpStmt *stmt) {

void CodeGenLLVM::visit(IfStmt *if_stmt) {
// TODO: take care of vectorized cases
BasicBlock *true_block =
BasicBlock::Create(*llvm_context, "true_block", func);
BasicBlock *false_block =
BasicBlock::Create(*llvm_context, "false_block", func);
BasicBlock *after_if = BasicBlock::Create(*llvm_context, "after_if", func);
llvm::BasicBlock *true_block =
llvm::BasicBlock::Create(*llvm_context, "true_block", func);
llvm::BasicBlock *false_block =
llvm::BasicBlock::Create(*llvm_context, "false_block", func);
llvm::BasicBlock *after_if =
llvm::BasicBlock::Create(*llvm_context, "after_if", func);
builder->CreateCondBr(
builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)),
true_block, false_block);
Expand All @@ -632,7 +633,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag,
DataType dt,
llvm::Value *value) {
TI_ASSERT(arch_use_host_memory(kernel->arch));
std::vector<Value *> args;
std::vector<llvm::Value *> args;
std::string format = data_type_format(dt);
auto runtime_printf = call("LLVMRuntime_get_host_printf", get_runtime());
args.push_back(builder->CreateGlobalStringPtr(
Expand All @@ -647,7 +648,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag,

void CodeGenLLVM::visit(PrintStmt *stmt) {
TI_ASSERT(stmt->width() == 1);
std::vector<Value *> args;
std::vector<llvm::Value *> args;
std::string formats;
for (auto const &content : stmt->contents) {
if (std::holds_alternative<Stmt *>(content)) {
Expand Down Expand Up @@ -700,6 +701,8 @@ void CodeGenLLVM::visit(ConstStmt *stmt) {
}

void CodeGenLLVM::visit(WhileControlStmt *stmt) {
using namespace llvm;

BasicBlock *after_break =
BasicBlock::Create(*llvm_context, "after_break", func);
TI_ASSERT(current_while_after_loop);
Expand All @@ -710,6 +713,7 @@ void CodeGenLLVM::visit(WhileControlStmt *stmt) {
}

void CodeGenLLVM::visit(ContinueStmt *stmt) {
using namespace llvm;
if (stmt->as_return()) {
builder->CreateRetVoid();
} else {
Expand All @@ -724,6 +728,7 @@ void CodeGenLLVM::visit(ContinueStmt *stmt) {
}

void CodeGenLLVM::visit(WhileStmt *stmt) {
using namespace llvm;
BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func);
builder->CreateBr(body);
builder->SetInsertPoint(body);
Expand Down Expand Up @@ -785,6 +790,7 @@ void CodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) {
}

void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
using namespace llvm;
BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func);
BasicBlock *loop_inc =
BasicBlock::Create(*llvm_context, "for_loop_inc", func);
Expand Down Expand Up @@ -856,7 +862,8 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {

llvm::Type *dest_ty = nullptr;
if (stmt->is_ptr) {
dest_ty = PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
dest_ty =
llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
} else {
dest_ty = tlctx->get_data_type(stmt->ret_type.data_type);
Expand All @@ -868,7 +875,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
}

void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
if (stmt->is_ptr) {
if (stmt->ret_type.is_pointer()) {
Comment on lines -871 to +878
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only substantial change in this file. Note that this part of the code is never used yet anyway.

TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits =
Expand Down Expand Up @@ -907,7 +914,7 @@ void CodeGenLLVM::visit(AssertStmt *stmt) {
// TODO: maybe let all asserts in a single offload share a single buffer?
auto arguments = create_entry_block_alloca(argument_buffer_size);

std::vector<Value *> args;
std::vector<llvm::Value *> args;
args.emplace_back(get_runtime());
args.emplace_back(llvm_val[stmt->cond]);
args.emplace_back(builder->CreateGlobalStringPtr(stmt->text));
Expand Down Expand Up @@ -1131,9 +1138,10 @@ llvm::Value *CodeGenLLVM::call(SNode *snode,

void CodeGenLLVM::visit(GetRootStmt *stmt) {
llvm_val[stmt] = builder->CreateBitCast(
get_root(), PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), prog->snode_root.get()),
0));
get_root(),
llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), prog->snode_root.get()),
0));
}

void CodeGenLLVM::visit(BitExtractStmt *stmt) {
Expand Down Expand Up @@ -1196,11 +1204,11 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
auto ch = create_call(
stmt->output_snode->get_ch_from_parent_func_name(),
{builder->CreateBitCast(llvm_val[stmt->input_ptr],
PointerType::getInt8PtrTy(*llvm_context))});
llvm::PointerType::getInt8PtrTy(*llvm_context))});
llvm_val[stmt] = builder->CreateBitCast(
ch, PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
0));
ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
0));
}

void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
Expand Down Expand Up @@ -1247,13 +1255,14 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,

task_function_type =
llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
{PointerType::get(context_ty, 0)}, false);
{llvm::PointerType::get(context_ty, 0)}, false);

auto task_kernel_name = fmt::format("{}_{}_{}{}", kernel_name, task_counter,
stmt->task_name(), suffix);
task_counter += 1;
func = Function::Create(task_function_type, Function::ExternalLinkage,
task_kernel_name, module.get());
func = llvm::Function::Create(task_function_type,
llvm::Function::ExternalLinkage,
task_kernel_name, module.get());

current_task = std::make_unique<OffloadedTask>(this);
current_task->begin(task_kernel_name);
Expand All @@ -1267,10 +1276,10 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
func->addParamAttr(0, llvm::Attribute::ByVal);

// entry_block has all the allocas
this->entry_block = BasicBlock::Create(*llvm_context, "entry", func);
this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);

// The real function body
func_body_bb = BasicBlock::Create(*llvm_context, "body", func);
func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
builder->SetInsertPoint(func_body_bb);
return task_kernel_name;
}
Expand All @@ -1287,7 +1296,7 @@ void CodeGenLLVM::finalize_offloaded_task_function() {
"unoptimized LLVM IR (generic)");
writer.write(module.get());
}
TI_ASSERT(!llvm::verifyFunction(*func, &errs()));
TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
// TI_INFO("Kernel function verified.");
}

Expand All @@ -1314,6 +1323,7 @@ std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::get_range_for_bounds(
}

void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {
using namespace llvm;
// TODO: instead of constructing tons of LLVM IR, writing the logic in
// runtime.cpp may be a cleaner solution. See
// CodeGenLLVMCPU::create_offload_range_for as an example.
Expand Down
11 changes: 5 additions & 6 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

TLANG_NAMESPACE_BEGIN

using namespace llvm;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing since llvm::PointerType conflicts with taichi::lang::PointerType.


class CodeGenLLVM;

class OffloadedTask {
Expand Down Expand Up @@ -57,7 +55,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
IRNode *ir;
Program *prog;
std::string kernel_name;
std::vector<Value *> kernel_args;
std::vector<llvm::Value *> kernel_args;
llvm::Type *context_ty;
llvm::Type *physical_coordinate_ty;
llvm::Value *current_coordinates;
Expand All @@ -73,7 +71,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
OffloadedStmt *current_offload{nullptr};
std::unique_ptr<OffloadedTask> current_task;
std::vector<OffloadedTask> offloaded_tasks;
BasicBlock *func_body_bb;
llvm::BasicBlock *func_body_bb;

std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

Expand Down Expand Up @@ -131,10 +129,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void emit_gc(OffloadedStmt *stmt);

llvm::Value *create_call(llvm::Value *func, std::vector<Value *> args = {});
llvm::Value *create_call(llvm::Value *func,
std::vector<llvm::Value *> args = {});

llvm::Value *create_call(std::string func_name,
std::vector<Value *> args = {});
std::vector<llvm::Value *> args = {});
llvm::Value *call(SNode *snode,
llvm::Value *node_ptr,
const std::string &method,
Expand Down
4 changes: 1 addition & 3 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,13 @@ Stmt::Stmt() : field_manager(this), fields_registered(false) {
instance_id = instance_id_counter++;
id = instance_id;
erased = false;
is_ptr = false;
}

Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) {
parent = stmt.parent;
instance_id = instance_id_counter++;
id = instance_id;
erased = stmt.erased;
is_ptr = stmt.is_ptr;
tb = stmt.tb;
ret_type = stmt.ret_type;
}
Expand Down Expand Up @@ -240,7 +238,7 @@ std::string Stmt::type_hint() const {
if (ret_type.data_type == PrimitiveType::unknown)
return "";
else
return fmt::format("<{}>{}", ret_type.str(), is_ptr ? "ptr " : " ");
return fmt::format("<{}>", ret_type.str());
}

std::string Stmt::type() {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <unordered_map>
#include <variant>
#include <tuple>

#include "taichi/common/core.h"
#include "taichi/util/bit.h"
#include "taichi/lang_util.h"
Expand Down Expand Up @@ -530,7 +531,6 @@ class Stmt : public IRNode {
bool erased;
bool fields_registered;
std::string tb;
bool is_ptr;
LegacyVectorType ret_type;

Stmt();
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class UnaryOpStmt : public Stmt {
class ArgLoadStmt : public Stmt {
public:
int arg_id;
bool is_ptr;

ArgLoadStmt(int arg_id, DataType dt, bool is_ptr = false) : arg_id(arg_id) {
this->ret_type = LegacyVectorType(1, dt);
Expand Down
45 changes: 45 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "taichi/ir/type.h"
#include "taichi/program/program.h"

TLANG_NAMESPACE_BEGIN

// Note: these primitive types should never be freed. They are supposed to live
// together with the process. This is a temporary solution. Later we should
// manage its ownership more systematically.

// This part doesn't look good, but we will remove it soon anyway.
#define PER_TYPE(x) \
DataType PrimitiveType::x = \
DataType(Program::get_type_factory().get_primitive_type( \
PrimitiveType::primitive_type::x));

#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) {
}

DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
if (false) {
}
#define PER_TYPE(x) else if (t == primitive_type::x) return PrimitiveType::x;
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
else {
TI_NOT_IMPLEMENTED
}
}

std::size_t DataType::hash() const {
if (auto primitive = dynamic_cast<const PrimitiveType *>(ptr_)) {
return (std::size_t)primitive->type;
} else {
TI_NOT_IMPLEMENTED
}
}

std::string PrimitiveType::to_string() const {
return data_type_name(DataType(this));
}
Comment on lines +6 to +43
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is completely copy-pasted from lang_util.cpp.


TLANG_NAMESPACE_END
Loading