Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type] [refactor] Make PrimitiveTypeID a public enum #1965

Merged
merged 2 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,20 @@ TLANG_NAMESPACE_BEGIN
// manage its ownership more systematically.

// This part doesn't look good, but we will remove it soon anyway.
#define PER_TYPE(x) \
DataType PrimitiveType::x = \
DataType(TypeFactory::get_instance().get_primitive_type( \
PrimitiveType::primitive_type::x));
#define PER_TYPE(x) \
DataType PrimitiveType::x = DataType( \
TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::x));

#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) {
}

DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
DataType PrimitiveType::get(PrimitiveTypeID t) {
if (false) {
}
#define PER_TYPE(x) else if (t == primitive_type::x) return PrimitiveType::x;
#define PER_TYPE(x) else if (t == PrimitiveTypeID::x) return PrimitiveType::x;
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
else {
Expand Down
18 changes: 9 additions & 9 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

TLANG_NAMESPACE_BEGIN

enum class PrimitiveTypeID : int {
#define PER_TYPE(x) x,
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
};

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The core change, everything else is to pass the build.

class Type {
public:
virtual std::string to_string() const = 0;
Expand Down Expand Up @@ -90,24 +96,18 @@ class DataType {

class PrimitiveType : public Type {
public:
enum class primitive_type : int {
#define PER_TYPE(x) x,
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
};

#define PER_TYPE(x) static DataType x;
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

primitive_type type;
PrimitiveTypeID type;

PrimitiveType(primitive_type type) : type(type) {
PrimitiveType(PrimitiveTypeID type) : type(type) {
}

std::string to_string() const override;

static DataType get(primitive_type type);
static DataType get(PrimitiveTypeID type);
};

class PointerType : public Type {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ TypeFactory &TypeFactory::get_instance() {
return *type_factory;
}

Type *TypeFactory::get_primitive_type(PrimitiveType::primitive_type id) {
Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
std::lock_guard<std::mutex> _(mut_);

if (primitive_types_.find(id) == primitive_types_.end()) {
Expand Down
5 changes: 2 additions & 3 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TypeFactory {
public:
static TypeFactory &get_instance();

Type *get_primitive_type(PrimitiveType::primitive_type id);
Type *get_primitive_type(PrimitiveTypeID id);

Type *get_vector_type(int num_elements, Type *element);

Expand All @@ -19,8 +19,7 @@ class TypeFactory {
private:
TypeFactory();

std::unordered_map<PrimitiveType::primitive_type, std::unique_ptr<Type>>
primitive_types_;
std::unordered_map<PrimitiveTypeID, std::unique_ptr<Type>> primitive_types_;

// TODO: use unordered map
std::map<std::pair<int, Type *>, std::unique_ptr<Type>> vector_types_;
Expand Down
6 changes: 3 additions & 3 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,10 @@ class TypePromotionMapping {

private:
std::map<
std::pair<PrimitiveType::primitive_type, PrimitiveType::primitive_type>,
PrimitiveType::primitive_type>
std::pair<PrimitiveTypeID, PrimitiveTypeID>,
PrimitiveTypeID>
mapping;
static PrimitiveType::primitive_type to_primitive_type(const DataType d_) {
static PrimitiveTypeID to_primitive_type(const DataType d_) {
Type *d = d_.get_ptr();
if (d->is<PointerType>()) {
d = d->as<PointerType>()->get_pointee_type();
Expand Down
24 changes: 12 additions & 12 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,29 @@ inline DataType get_data_type() {
}

template <typename T>
inline PrimitiveType::primitive_type get_primitive_data_type() {
inline PrimitiveTypeID get_primitive_data_type() {
if (std::is_same<T, float32>()) {
return PrimitiveType::primitive_type::f32;
return PrimitiveTypeID::f32;
} else if (std::is_same<T, float64>()) {
return PrimitiveType::primitive_type::f64;
return PrimitiveTypeID::f64;
} else if (std::is_same<T, bool>()) {
return PrimitiveType::primitive_type::u1;
return PrimitiveTypeID::u1;
} else if (std::is_same<T, int8>()) {
return PrimitiveType::primitive_type::i8;
return PrimitiveTypeID::i8;
} else if (std::is_same<T, int16>()) {
return PrimitiveType::primitive_type::i16;
return PrimitiveTypeID::i16;
} else if (std::is_same<T, int32>()) {
return PrimitiveType::primitive_type::i32;
return PrimitiveTypeID::i32;
} else if (std::is_same<T, int64>()) {
return PrimitiveType::primitive_type::i64;
return PrimitiveTypeID::i64;
} else if (std::is_same<T, uint8>()) {
return PrimitiveType::primitive_type::u8;
return PrimitiveTypeID::u8;
} else if (std::is_same<T, uint16>()) {
return PrimitiveType::primitive_type::u16;
return PrimitiveTypeID::u16;
} else if (std::is_same<T, uint32>()) {
return PrimitiveType::primitive_type::u32;
return PrimitiveTypeID::u32;
} else if (std::is_same<T, uint64>()) {
return PrimitiveType::primitive_type::u64;
return PrimitiveTypeID::u64;
} else {
TI_NOT_IMPLEMENTED;
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ void export_lang(py::module &m) {
if (t.size() != 1)
throw std::runtime_error("Invalid state!");

DataType dt = PrimitiveType::get(
(PrimitiveType::primitive_type)(t[0].cast<std::size_t>()));
DataType dt =
PrimitiveType::get((PrimitiveTypeID)(t[0].cast<std::size_t>()));

return dt;
}));
Expand Down