Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type] Adopt the new type system in Stmt #1957

Merged
merged 18 commits into from
Oct 15, 2020
20 changes: 10 additions & 10 deletions python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,32 @@ def is_taichi_class(rhs):

# Real types

float32 = taichi_lang_core.DataType_float32
float32 = taichi_lang_core.DataType_f32
f32 = float32
float64 = taichi_lang_core.DataType_float64
float64 = taichi_lang_core.DataType_f64
f64 = float64

real_types = [f32, f64, float]
real_type_ids = [id(t) for t in real_types]

# Integer types

int8 = taichi_lang_core.DataType_int8
int8 = taichi_lang_core.DataType_i8
i8 = int8
int16 = taichi_lang_core.DataType_int16
int16 = taichi_lang_core.DataType_i16
i16 = int16
int32 = taichi_lang_core.DataType_int32
int32 = taichi_lang_core.DataType_i32
i32 = int32
int64 = taichi_lang_core.DataType_int64
int64 = taichi_lang_core.DataType_i64
i64 = int64

uint8 = taichi_lang_core.DataType_uint8
uint8 = taichi_lang_core.DataType_u8
u8 = uint8
uint16 = taichi_lang_core.DataType_uint16
uint16 = taichi_lang_core.DataType_u16
u16 = uint16
uint32 = taichi_lang_core.DataType_uint32
uint32 = taichi_lang_core.DataType_u32
u32 = uint32
uint64 = taichi_lang_core.DataType_uint64
uint64 = taichi_lang_core.DataType_u64
u64 = uint64

integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int]
Expand Down
17 changes: 10 additions & 7 deletions taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ class CCTransformer : public IRVisitor {

void visit(GlobalTemporaryStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto ptr_type = cc_data_type_name(stmt->element_type()) + " *";
auto ptr_type =
cc_data_type_name(stmt->element_type().ptr_removed()) + " *";
auto var = define_var(ptr_type, stmt->raw_name());
emit("{} = ({}) (ti_ctx->gtmp + {});", var, ptr_type, stmt->offset);
}
Expand All @@ -161,17 +162,19 @@ class CCTransformer : public IRVisitor {
offset = fmt::format("({} * {} + {})", offset, stride,
stmt->indices[i]->raw_name());
}
auto var = define_var(cc_data_type_name(stmt->element_type()) + " *",
stmt->raw_name());
auto var =
define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
stmt->raw_name());
emit("{} = {} + {};", var, stmt->base_ptrs[0]->raw_name(), offset);
}

void visit(ArgLoadStmt *stmt) override {
if (stmt->is_ptr) {
auto var = define_var(cc_data_type_name(stmt->element_type()) + " *",
stmt->raw_name());
auto var = define_var(
cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
stmt->raw_name());
emit("{} = ti_ctx->args[{}].ptr_{};", var, stmt->arg_id,
data_type_short_name(stmt->element_type()));
data_type_short_name(stmt->element_type().ptr_removed()));
} else {
auto var =
define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name());
Expand Down Expand Up @@ -377,7 +380,7 @@ class CCTransformer : public IRVisitor {
const auto dest_ptr = stmt->dest->raw_name();
const auto src_name = stmt->val->raw_name();
const auto op = cc_atomic_op_type_symbol(stmt->op_type);
const auto type = stmt->element_type();
const auto type = stmt->dest->element_type().ptr_removed();
auto var = define_var(cc_data_type_name(type), stmt->raw_name());
emit("{} = *{};", var, dest_ptr);
if (stmt->op_type == AtomicOpType::max ||
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ class KernelGen : public IRVisitor {

void visit(AtomicOpStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->dest->element_type();
auto dt = stmt->dest->element_type().ptr_removed();
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
if (dt == PrimitiveType::i32 ||
(TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_int64) &&
dt == PrimitiveType::i64) ||
Expand Down
9 changes: 7 additions & 2 deletions taichi/backends/opengl/opengl_data_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace opengl {

inline std::string opengl_data_type_name(DataType dt) {
// https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)
dt.set_is_pointer(false);
if (dt == PrimitiveType::f32)
return "float";
else if (dt == PrimitiveType::f64)
Expand All @@ -16,8 +17,9 @@ inline std::string opengl_data_type_name(DataType dt) {
return "int";
else if (dt == PrimitiveType::i64)
return "int64_t";
else
TI_NOT_IMPLEMENTED;
else {
TI_ERROR("Type {} not supported.", dt->to_string());
}
}

inline bool is_opengl_binary_op_infix(BinaryOpType type) {
Expand All @@ -32,6 +34,9 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) {
}

inline int opengl_data_address_shifter(DataType type) {
// TODO: fail loudly when feeding a pointer type to this function after type
// system upgrade.
type.set_is_pointer(false);
if (type == PrimitiveType::f32 || type == PrimitiveType::i32)
return 2;
else if (type == PrimitiveType::f64 || type == PrimitiveType::i64) {
Expand Down
15 changes: 8 additions & 7 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
} else {
TI_ASSERT(!stmt->ret_type.data_type->is<PointerType>());
dest_ty = tlctx->get_data_type(stmt->ret_type.data_type);
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
auto truncated = builder->CreateTrunc(
Expand Down Expand Up @@ -1227,7 +1228,7 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
sizes[i] = raw_arg;
}

auto dt = stmt->ret_type.data_type;
auto dt = stmt->ret_type.data_type.ptr_removed();
auto base = builder->CreateBitCast(
llvm_val[stmt->base_ptrs[0]],
llvm::PointerType::get(tlctx->get_data_type(dt), 0));
Expand Down Expand Up @@ -1619,17 +1620,17 @@ void CodeGenLLVM::visit(GlobalTemporaryStmt *stmt) {
tlctx->get_constant((int64)stmt->offset));

TI_ASSERT(stmt->width() == 1);
auto ptr_type =
llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0);
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type);
}

void CodeGenLLVM::visit(ThreadLocalPtrStmt *stmt) {
auto base = get_tls_base_ptr();
TI_ASSERT(stmt->width() == 1);
auto ptr = builder->CreateGEP(base, tlctx->get_constant(stmt->offset));
auto ptr_type =
llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0);
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type);
}

Expand All @@ -1639,8 +1640,8 @@ void CodeGenLLVM::visit(BlockLocalPtrStmt *stmt) {
TI_ASSERT(stmt->width() == 1);
auto ptr = builder->CreateGEP(
base, {tlctx->get_constant(0), llvm_val[stmt->offset]});
auto ptr_type =
llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0);
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type);
}

Expand Down
19 changes: 1 addition & 18 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,6 @@ IRBuilder &current_ast_builder() {
return context->builder();
}

std::string LegacyVectorType::pointer_suffix() const {
if (is_pointer()) {
return "*";
} else {
return "";
}
}

std::string LegacyVectorType::element_type_name() const {
return fmt::format("{}{}", data_type_short_name(data_type), pointer_suffix());
}

std::string LegacyVectorType::str() const {
auto ename = element_type_name();
return fmt::format("{:4}x{}", ename, width);
}

void DecoratorRecorder::reset() {
vectorize = -1;
parallelize = 0;
Expand Down Expand Up @@ -238,7 +221,7 @@ std::string Stmt::type_hint() const {
if (ret_type.data_type == PrimitiveType::unknown)
return "";
else
return fmt::format("<{}>", ret_type.str());
return fmt::format("<{}> ", ret_type.to_string());
}

std::string Stmt::type() {
Expand Down
41 changes: 2 additions & 39 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,43 +34,6 @@ using ScratchPadOptions = std::vector<std::pair<int, SNode *>>;

IRBuilder &current_ast_builder();

struct LegacyVectorType {
private:
bool _is_pointer;

public:
int width;
DataType data_type;

LegacyVectorType(int width, DataType data_type, bool is_pointer = false)
: _is_pointer(is_pointer), width(width), data_type(data_type) {
}

LegacyVectorType()
: _is_pointer(false), width(1), data_type(PrimitiveType::unknown) {
}

bool operator==(const LegacyVectorType &o) const {
return width == o.width && data_type == o.data_type;
}

bool operator!=(const LegacyVectorType &o) const {
return !(*this == o);
}

std::string pointer_suffix() const;
std::string element_type_name() const;
std::string str() const;

bool is_pointer() const {
return _is_pointer;
}

void set_is_pointer(bool v) {
_is_pointer = v;
}
};

class DecoratorRecorder {
public:
int vectorize;
Expand Down Expand Up @@ -531,7 +494,7 @@ class Stmt : public IRNode {
bool erased;
bool fields_registered;
std::string tb;
LegacyVectorType ret_type;
DataType ret_type;
Comment on lines -534 to +497
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The centric change. Every other change in this PR is around this line.


Stmt();
Stmt(const Stmt &stmt);
Expand All @@ -553,7 +516,7 @@ class Stmt : public IRNode {
}

std::string ret_data_type_name() const {
return ret_type.str();
return ret_type->to_string();
}

std::string type_hint() const;
Expand Down
8 changes: 3 additions & 5 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -934,8 +934,7 @@ class GlobalTemporaryStmt : public Stmt {
public:
std::size_t offset;

GlobalTemporaryStmt(std::size_t offset, LegacyVectorType ret_type)
: offset(offset) {
GlobalTemporaryStmt(std::size_t offset, DataType ret_type) : offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
}
Expand All @@ -952,8 +951,7 @@ class ThreadLocalPtrStmt : public Stmt {
public:
std::size_t offset;

ThreadLocalPtrStmt(std::size_t offset, LegacyVectorType ret_type)
: offset(offset) {
ThreadLocalPtrStmt(std::size_t offset, DataType ret_type) : offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
}
Expand All @@ -970,7 +968,7 @@ class BlockLocalPtrStmt : public Stmt {
public:
Stmt *offset;

BlockLocalPtrStmt(Stmt *offset, LegacyVectorType ret_type) : offset(offset) {
BlockLocalPtrStmt(Stmt *offset, DataType ret_type) : offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
}
Expand Down
40 changes: 37 additions & 3 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ TLANG_NAMESPACE_BEGIN
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) {
DataType::DataType() : data_type(*this), ptr_(PrimitiveType::unknown.ptr_) {
}

DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
Expand All @@ -31,15 +31,49 @@ DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
}

std::size_t DataType::hash() const {
if (auto primitive = dynamic_cast<const PrimitiveType *>(ptr_)) {
if (auto primitive = ptr_->cast<PrimitiveType>()) {
return (std::size_t)primitive->type;
} else if (auto pointer = ptr_->cast<PointerType>()) {
return 10007 + DataType(pointer->get_pointee_type()).hash();
} else {
TI_NOT_IMPLEMENTED
}
}

bool DataType::is_pointer() const {
return ptr_->is<PointerType>();
}

void DataType::set_is_pointer(bool is_ptr) {
if (is_ptr && !ptr_->is<PointerType>()) {
ptr_ = Program::get_type_factory().get_pointer_type(ptr_);
}
if (!is_ptr && ptr_->is<PointerType>()) {
ptr_ = ptr_->cast<PointerType>()->get_pointee_type();
}
}

DataType DataType::ptr_removed() const {
auto t = ptr_;
auto ptr_type = t->cast<PointerType>();
if (ptr_type) {
return DataType(ptr_type->get_pointee_type());
} else {
return *this;
}
}

std::string PrimitiveType::to_string() const {
return data_type_name(DataType(this));
return data_type_name(DataType(const_cast<PrimitiveType *>(this)));
}

DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) {
if (is_pointer) {
return Program::get_type_factory().get_pointer_type(data_type.get_ptr());
} else {
return data_type;
}
TI_ASSERT(width == 1);
}

TLANG_NAMESPACE_END
Loading