Skip to content

Commit

Permalink
[Type] Adopt the new type system in Stmt (#1957)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu authored Oct 15, 2020
1 parent 706c519 commit 6b3db04
Show file tree
Hide file tree
Showing 23 changed files with 233 additions and 166 deletions.
20 changes: 10 additions & 10 deletions python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,32 @@ def is_taichi_class(rhs):

# Real types

float32 = taichi_lang_core.DataType_float32
float32 = taichi_lang_core.DataType_f32
f32 = float32
float64 = taichi_lang_core.DataType_float64
float64 = taichi_lang_core.DataType_f64
f64 = float64

real_types = [f32, f64, float]
real_type_ids = [id(t) for t in real_types]

# Integer types

int8 = taichi_lang_core.DataType_int8
int8 = taichi_lang_core.DataType_i8
i8 = int8
int16 = taichi_lang_core.DataType_int16
int16 = taichi_lang_core.DataType_i16
i16 = int16
int32 = taichi_lang_core.DataType_int32
int32 = taichi_lang_core.DataType_i32
i32 = int32
int64 = taichi_lang_core.DataType_int64
int64 = taichi_lang_core.DataType_i64
i64 = int64

uint8 = taichi_lang_core.DataType_uint8
uint8 = taichi_lang_core.DataType_u8
u8 = uint8
uint16 = taichi_lang_core.DataType_uint16
uint16 = taichi_lang_core.DataType_u16
u16 = uint16
uint32 = taichi_lang_core.DataType_uint32
uint32 = taichi_lang_core.DataType_u32
u32 = uint32
uint64 = taichi_lang_core.DataType_uint64
uint64 = taichi_lang_core.DataType_u64
u64 = uint64

integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int]
Expand Down
17 changes: 10 additions & 7 deletions taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ class CCTransformer : public IRVisitor {

void visit(GlobalTemporaryStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto ptr_type = cc_data_type_name(stmt->element_type()) + " *";
auto ptr_type =
cc_data_type_name(stmt->element_type().ptr_removed()) + " *";
auto var = define_var(ptr_type, stmt->raw_name());
emit("{} = ({}) (ti_ctx->gtmp + {});", var, ptr_type, stmt->offset);
}
Expand All @@ -161,17 +162,19 @@ class CCTransformer : public IRVisitor {
offset = fmt::format("({} * {} + {})", offset, stride,
stmt->indices[i]->raw_name());
}
auto var = define_var(cc_data_type_name(stmt->element_type()) + " *",
stmt->raw_name());
auto var =
define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
stmt->raw_name());
emit("{} = {} + {};", var, stmt->base_ptrs[0]->raw_name(), offset);
}

void visit(ArgLoadStmt *stmt) override {
if (stmt->is_ptr) {
auto var = define_var(cc_data_type_name(stmt->element_type()) + " *",
stmt->raw_name());
auto var = define_var(
cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
stmt->raw_name());
emit("{} = ti_ctx->args[{}].ptr_{};", var, stmt->arg_id,
data_type_short_name(stmt->element_type()));
data_type_short_name(stmt->element_type().ptr_removed()));
} else {
auto var =
define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name());
Expand Down Expand Up @@ -377,7 +380,7 @@ class CCTransformer : public IRVisitor {
const auto dest_ptr = stmt->dest->raw_name();
const auto src_name = stmt->val->raw_name();
const auto op = cc_atomic_op_type_symbol(stmt->op_type);
const auto type = stmt->element_type();
const auto type = stmt->dest->element_type().ptr_removed();
auto var = define_var(cc_data_type_name(type), stmt->raw_name());
emit("{} = *{};", var, dest_ptr);
if (stmt->op_type == AtomicOpType::max ||
Expand Down
5 changes: 3 additions & 2 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,15 +349,16 @@ class KernelCodegen : public IRVisitor {

void visit(GlobalTemporaryStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
const auto dt = metal_data_type_name(stmt->element_type());
const auto dt = metal_data_type_name(stmt->element_type().ptr_removed());
emit("device {}* {} = reinterpret_cast<device {}*>({} + {});", dt,
stmt->raw_name(), dt, kGlobalTmpsBufferName, stmt->offset);
}

void visit(ThreadLocalPtrStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
emit("thread auto* {} = reinterpret_cast<thread {}*>({} + {});",
stmt->raw_name(), metal_data_type_name(stmt->element_type()),
stmt->raw_name(),
metal_data_type_name(stmt->element_type().ptr_removed()),
kTlsBufferName, stmt->offset);
}

Expand Down
1 change: 1 addition & 0 deletions taichi/backends/metal/data_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ TLANG_NAMESPACE_BEGIN
namespace metal {

MetalDataType to_metal_type(DataType dt) {
dt.set_is_pointer(false);
#define METAL_CASE(x) else if (dt == PrimitiveType::x) return MetalDataType::x
if (false) {
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ class KernelGen : public IRVisitor {

void visit(AtomicOpStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->dest->element_type();
auto dt = stmt->dest->element_type().ptr_removed();
if (dt == PrimitiveType::i32 ||
(TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_int64) &&
dt == PrimitiveType::i64) ||
Expand Down
9 changes: 7 additions & 2 deletions taichi/backends/opengl/opengl_data_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace opengl {

inline std::string opengl_data_type_name(DataType dt) {
// https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)
dt.set_is_pointer(false);
if (dt == PrimitiveType::f32)
return "float";
else if (dt == PrimitiveType::f64)
Expand All @@ -16,8 +17,9 @@ inline std::string opengl_data_type_name(DataType dt) {
return "int";
else if (dt == PrimitiveType::i64)
return "int64_t";
else
TI_NOT_IMPLEMENTED;
else {
TI_ERROR("Type {} not supported.", dt->to_string());
}
}

inline bool is_opengl_binary_op_infix(BinaryOpType type) {
Expand All @@ -32,6 +34,9 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) {
}

inline int opengl_data_address_shifter(DataType type) {
// TODO: fail loudly when feeding a pointer type to this function, after type
// system upgrade.
type.set_is_pointer(false);
if (type == PrimitiveType::f32 || type == PrimitiveType::i32)
return 2;
else if (type == PrimitiveType::f64 || type == PrimitiveType::i64) {
Expand Down
15 changes: 8 additions & 7 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
} else {
TI_ASSERT(!stmt->ret_type.data_type->is<PointerType>());
dest_ty = tlctx->get_data_type(stmt->ret_type.data_type);
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
auto truncated = builder->CreateTrunc(
Expand Down Expand Up @@ -1227,7 +1228,7 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
sizes[i] = raw_arg;
}

auto dt = stmt->ret_type.data_type;
auto dt = stmt->ret_type.data_type.ptr_removed();
auto base = builder->CreateBitCast(
llvm_val[stmt->base_ptrs[0]],
llvm::PointerType::get(tlctx->get_data_type(dt), 0));
Expand Down Expand Up @@ -1619,17 +1620,17 @@ void CodeGenLLVM::visit(GlobalTemporaryStmt *stmt) {
tlctx->get_constant((int64)stmt->offset));

TI_ASSERT(stmt->width() == 1);
auto ptr_type =
llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0);
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type);
}

void CodeGenLLVM::visit(ThreadLocalPtrStmt *stmt) {
auto base = get_tls_base_ptr();
TI_ASSERT(stmt->width() == 1);
auto ptr = builder->CreateGEP(base, tlctx->get_constant(stmt->offset));
auto ptr_type =
llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0);
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type);
}

Expand All @@ -1639,8 +1640,8 @@ void CodeGenLLVM::visit(BlockLocalPtrStmt *stmt) {
TI_ASSERT(stmt->width() == 1);
auto ptr = builder->CreateGEP(
base, {tlctx->get_constant(0), llvm_val[stmt->offset]});
auto ptr_type =
llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0);
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type);
}

Expand Down
19 changes: 1 addition & 18 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,6 @@ IRBuilder &current_ast_builder() {
return context->builder();
}

std::string LegacyVectorType::pointer_suffix() const {
if (is_pointer()) {
return "*";
} else {
return "";
}
}

std::string LegacyVectorType::element_type_name() const {
return fmt::format("{}{}", data_type_short_name(data_type), pointer_suffix());
}

std::string LegacyVectorType::str() const {
auto ename = element_type_name();
return fmt::format("{:4}x{}", ename, width);
}

void DecoratorRecorder::reset() {
vectorize = -1;
parallelize = 0;
Expand Down Expand Up @@ -238,7 +221,7 @@ std::string Stmt::type_hint() const {
if (ret_type.data_type == PrimitiveType::unknown)
return "";
else
return fmt::format("<{}>", ret_type.str());
return fmt::format("<{}> ", ret_type.to_string());
}

std::string Stmt::type() {
Expand Down
41 changes: 2 additions & 39 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,43 +34,6 @@ using ScratchPadOptions = std::vector<std::pair<int, SNode *>>;

IRBuilder &current_ast_builder();

struct LegacyVectorType {
private:
bool _is_pointer;

public:
int width;
DataType data_type;

LegacyVectorType(int width, DataType data_type, bool is_pointer = false)
: _is_pointer(is_pointer), width(width), data_type(data_type) {
}

LegacyVectorType()
: _is_pointer(false), width(1), data_type(PrimitiveType::unknown) {
}

bool operator==(const LegacyVectorType &o) const {
return width == o.width && data_type == o.data_type;
}

bool operator!=(const LegacyVectorType &o) const {
return !(*this == o);
}

std::string pointer_suffix() const;
std::string element_type_name() const;
std::string str() const;

bool is_pointer() const {
return _is_pointer;
}

void set_is_pointer(bool v) {
_is_pointer = v;
}
};

class DecoratorRecorder {
public:
int vectorize;
Expand Down Expand Up @@ -531,7 +494,7 @@ class Stmt : public IRNode {
bool erased;
bool fields_registered;
std::string tb;
LegacyVectorType ret_type;
DataType ret_type;

Stmt();
Stmt(const Stmt &stmt);
Expand All @@ -553,7 +516,7 @@ class Stmt : public IRNode {
}

std::string ret_data_type_name() const {
return ret_type.str();
return ret_type->to_string();
}

std::string type_hint() const;
Expand Down
8 changes: 3 additions & 5 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -934,8 +934,7 @@ class GlobalTemporaryStmt : public Stmt {
public:
std::size_t offset;

GlobalTemporaryStmt(std::size_t offset, LegacyVectorType ret_type)
: offset(offset) {
GlobalTemporaryStmt(std::size_t offset, DataType ret_type) : offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
}
Expand All @@ -952,8 +951,7 @@ class ThreadLocalPtrStmt : public Stmt {
public:
std::size_t offset;

ThreadLocalPtrStmt(std::size_t offset, LegacyVectorType ret_type)
: offset(offset) {
ThreadLocalPtrStmt(std::size_t offset, DataType ret_type) : offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
}
Expand All @@ -970,7 +968,7 @@ class BlockLocalPtrStmt : public Stmt {
public:
Stmt *offset;

BlockLocalPtrStmt(Stmt *offset, LegacyVectorType ret_type) : offset(offset) {
BlockLocalPtrStmt(Stmt *offset, DataType ret_type) : offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
}
Expand Down
40 changes: 37 additions & 3 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ TLANG_NAMESPACE_BEGIN
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) {
DataType::DataType() : data_type(*this), ptr_(PrimitiveType::unknown.ptr_) {
}

DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
Expand All @@ -31,15 +31,49 @@ DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
}

std::size_t DataType::hash() const {
if (auto primitive = dynamic_cast<const PrimitiveType *>(ptr_)) {
if (auto primitive = ptr_->cast<PrimitiveType>()) {
return (std::size_t)primitive->type;
} else if (auto pointer = ptr_->cast<PointerType>()) {
return 10007 + DataType(pointer->get_pointee_type()).hash();
} else {
TI_NOT_IMPLEMENTED
}
}

bool DataType::is_pointer() const {
return ptr_->is<PointerType>();
}

void DataType::set_is_pointer(bool is_ptr) {
if (is_ptr && !ptr_->is<PointerType>()) {
ptr_ = Program::get_type_factory().get_pointer_type(ptr_);
}
if (!is_ptr && ptr_->is<PointerType>()) {
ptr_ = ptr_->cast<PointerType>()->get_pointee_type();
}
}

DataType DataType::ptr_removed() const {
auto t = ptr_;
auto ptr_type = t->cast<PointerType>();
if (ptr_type) {
return DataType(ptr_type->get_pointee_type());
} else {
return *this;
}
}

std::string PrimitiveType::to_string() const {
return data_type_name(DataType(this));
return data_type_name(DataType(const_cast<PrimitiveType *>(this)));
}

DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) {
TI_ASSERT(width == 1);
if (is_pointer) {
return Program::get_type_factory().get_pointer_type(data_type.get_ptr());
} else {
return data_type;
}
}

TLANG_NAMESPACE_END
Loading

0 comments on commit 6b3db04

Please sign in to comment.