Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Opt] [refactor] Avoid throwing exceptions in alg_simp and let it support more types #1060

Merged
merged 4 commits into from
May 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ Stmt *Stmt::insert_after_me(std::unique_ptr<Stmt> &&new_stmt) {
void Stmt::replace_with(Stmt *new_stmt) {
auto root = get_ir_root();
irpass::replace_all_usages_with(root, this, new_stmt);
// Note: the current structure should have been destroyed now..
}

void Stmt::replace_with(VecStatement &&new_statements, bool replace_usages) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void re_id(IRNode *root);
void flag_access(IRNode *root);
void die(IRNode *root);
void simplify(IRNode *root, Kernel *kernel = nullptr);
void alg_simp(IRNode *root, const CompileConfig &config);
bool alg_simp(IRNode *root, const CompileConfig &config);
void whole_kernel_cse(IRNode *root);
void variable_optimization(IRNode *root, bool after_lower_access);
void extract_constant(IRNode *root);
Expand Down
149 changes: 149 additions & 0 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,155 @@ DataType promoted_type(DataType a, DataType b) {
}
return mapping[std::make_pair(a, b)];
}

std::string TypedConstant::stringify() const {
if (dt == DataType::f32) {
return fmt::format("{}", val_f32);
} else if (dt == DataType::i32) {
return fmt::format("{}", val_i32);
} else if (dt == DataType::i64) {
return fmt::format("{}", val_i64);
} else if (dt == DataType::f64) {
return fmt::format("{}", val_f64);
} else if (dt == DataType::i8) {
return fmt::format("{}", val_i8);
} else if (dt == DataType::i16) {
return fmt::format("{}", val_i16);
} else if (dt == DataType::u8) {
return fmt::format("{}", val_u8);
} else if (dt == DataType::u16) {
return fmt::format("{}", val_u16);
} else if (dt == DataType::u32) {
return fmt::format("{}", val_u32);
} else if (dt == DataType::u64) {
return fmt::format("{}", val_u64);
} else {
TI_P(data_type_name(dt));
TI_NOT_IMPLEMENTED
return "";
}
}

bool TypedConstant::equal_type_and_value(const TypedConstant &o) const {
if (dt != o.dt)
return false;
if (dt == DataType::f32) {
return val_f32 == o.val_f32;
} else if (dt == DataType::i32) {
return val_i32 == o.val_i32;
} else if (dt == DataType::i64) {
return val_i64 == o.val_i64;
} else if (dt == DataType::f64) {
return val_f64 == o.val_f64;
} else if (dt == DataType::i8) {
return val_i8 == o.val_i8;
} else if (dt == DataType::i16) {
return val_i16 == o.val_i16;
} else if (dt == DataType::u8) {
return val_u8 == o.val_u8;
} else if (dt == DataType::u16) {
return val_u16 == o.val_u16;
} else if (dt == DataType::u32) {
return val_u32 == o.val_u32;
} else if (dt == DataType::u64) {
return val_u64 == o.val_u64;
} else {
TI_NOT_IMPLEMENTED
return false;
}
}

int32 &TypedConstant::val_int32() {
TI_ASSERT(get_data_type<int32>() == dt);
return val_i32;
}

float32 &TypedConstant::val_float32() {
TI_ASSERT(get_data_type<float32>() == dt);
return val_f32;
}

int64 &TypedConstant::val_int64() {
TI_ASSERT(get_data_type<int64>() == dt);
return val_i64;
}

float64 &TypedConstant::val_float64() {
TI_ASSERT(get_data_type<float64>() == dt);
return val_f64;
}

int8 &TypedConstant::val_int8() {
TI_ASSERT(get_data_type<int8>() == dt);
return val_i8;
}

int16 &TypedConstant::val_int16() {
TI_ASSERT(get_data_type<int16>() == dt);
return val_i16;
}

uint8 &TypedConstant::val_uint8() {
TI_ASSERT(get_data_type<uint8>() == dt);
return val_u8;
}

uint16 &TypedConstant::val_uint16() {
TI_ASSERT(get_data_type<uint16>() == dt);
return val_u16;
}

uint32 &TypedConstant::val_uint32() {
TI_ASSERT(get_data_type<uint32>() == dt);
return val_u32;
}

uint64 &TypedConstant::val_uint64() {
TI_ASSERT(get_data_type<uint64>() == dt);
return val_u64;
}

int64 TypedConstant::val_int() const {
TI_ASSERT(is_signed(dt));
if (dt == DataType::i32) {
return val_i32;
} else if (dt == DataType::i64) {
return val_i64;
} else if (dt == DataType::i8) {
return val_i8;
} else if (dt == DataType::i16) {
return val_i16;
} else {
TI_NOT_IMPLEMENTED
}
}

uint64 TypedConstant::val_uint() const {
TI_ASSERT(is_unsigned(dt));
if (dt == DataType::u32) {
return val_u32;
} else if (dt == DataType::u64) {
return val_u64;
} else if (dt == DataType::u8) {
return val_u8;
} else if (dt == DataType::u16) {
return val_u16;
} else {
TI_NOT_IMPLEMENTED
}
}

float64 TypedConstant::val_float() const {
TI_ASSERT(is_real(dt));
if (dt == DataType::f32) {
return val_f32;
} else if (dt == DataType::f64) {
return val_f64;
} else {
TI_NOT_IMPLEMENTED
}
}

} // namespace lang

void initialize_benchmark() {
Expand Down
89 changes: 15 additions & 74 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,86 +199,27 @@ class TypedConstant {
TypedConstant(float64 x) : dt(DataType::f64), val_f64(x) {
}

std::string stringify() const {
if (dt == DataType::f32) {
return fmt::format("{}", val_f32);
} else if (dt == DataType::i32) {
return fmt::format("{}", val_i32);
} else if (dt == DataType::i64) {
return fmt::format("{}", val_i64);
} else if (dt == DataType::f64) {
return fmt::format("{}", val_f64);
} else if (dt == DataType::i8) {
return fmt::format("{}", val_i8);
} else if (dt == DataType::i16) {
return fmt::format("{}", val_i16);
} else if (dt == DataType::u8) {
return fmt::format("{}", val_u8);
} else if (dt == DataType::u16) {
return fmt::format("{}", val_u16);
} else if (dt == DataType::u32) {
return fmt::format("{}", val_u32);
} else if (dt == DataType::u64) {
return fmt::format("{}", val_u64);
} else {
TI_P(data_type_name(dt));
TI_NOT_IMPLEMENTED
return "";
}
}
std::string stringify() const;

bool equal_type_and_value(const TypedConstant &o) const {
if (dt != o.dt)
return false;
if (dt == DataType::f32) {
return val_f32 == o.val_f32;
} else if (dt == DataType::i32) {
return val_i32 == o.val_i32;
} else if (dt == DataType::i64) {
return val_i64 == o.val_i64;
} else if (dt == DataType::f64) {
return val_f64 == o.val_f64;
} else if (dt == DataType::i8) {
return val_i8 == o.val_i8;
} else if (dt == DataType::i16) {
return val_i16 == o.val_i16;
} else if (dt == DataType::u8) {
return val_u8 == o.val_u8;
} else if (dt == DataType::u16) {
return val_u16 == o.val_u16;
} else if (dt == DataType::u32) {
return val_u32 == o.val_u32;
} else if (dt == DataType::u64) {
return val_u64 == o.val_u64;
} else {
TI_NOT_IMPLEMENTED
return false;
}
}
bool equal_type_and_value(const TypedConstant &o) const;

bool operator==(const TypedConstant &o) const {
return equal_type_and_value(o);
}

int32 &val_int32() {
TI_ASSERT(get_data_type<int32>() == dt);
return val_i32;
}

float32 &val_float32() {
TI_ASSERT(get_data_type<float32>() == dt);
return val_f32;
}

int64 &val_int64() {
TI_ASSERT(get_data_type<int64>() == dt);
return val_i64;
}

float64 &val_float64() {
TI_ASSERT(get_data_type<float64>() == dt);
return val_f64;
}
int32 &val_int32();
float32 &val_float32();
int64 &val_int64();
float64 &val_float64();
int8 &val_int8();
int16 &val_int16();
uint8 &val_uint8();
uint16 &val_uint16();
uint32 &val_uint32();
uint64 &val_uint64();
int64 val_int() const;
uint64 val_uint() const;
float64 val_float() const;
};

inline std::string make_list(const std::vector<std::string> &data,
Expand Down
Loading