From 0a1ec46e942a86d60347c3a73f7d85411250c7a6 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 25 May 2020 20:28:08 -0400 Subject: [PATCH 1/4] [opt] [refactor] Avoid throwing exceptions in alg_simp and let it support more types --- taichi/ir/ir.cpp | 1 - taichi/ir/transforms.h | 2 +- taichi/lang_util.cpp | 149 +++++++++++++++++++++++++++++++++ taichi/lang_util.h | 89 ++++---------------- taichi/transforms/alg_simp.cpp | 91 +++++++++----------- 5 files changed, 204 insertions(+), 128 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index f734f838bdf4e..109088aa1d045 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -262,7 +262,6 @@ Stmt *Stmt::insert_after_me(std::unique_ptr &&new_stmt) { void Stmt::replace_with(Stmt *new_stmt) { auto root = get_ir_root(); irpass::replace_all_usages_with(root, this, new_stmt); - // Note: the current structure should have been destroyed now.. } void Stmt::replace_with(VecStatement &&new_statements, bool replace_usages) { diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 80577cbd964b1..48a69e76bb047 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -14,7 +14,7 @@ void re_id(IRNode *root); void flag_access(IRNode *root); void die(IRNode *root); void simplify(IRNode *root, Kernel *kernel = nullptr); -void alg_simp(IRNode *root, const CompileConfig &config); +bool alg_simp(IRNode *root, const CompileConfig &config); void whole_kernel_cse(IRNode *root); void variable_optimization(IRNode *root, bool after_lower_access); void extract_constant(IRNode *root); diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index e94ce8564bd74..e3b058e82531a 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -288,6 +288,155 @@ DataType promoted_type(DataType a, DataType b) { } return mapping[std::make_pair(a, b)]; } + +std::string TypedConstant::stringify() const { + if (dt == DataType::f32) { + return fmt::format("{}", val_f32); + } else if (dt == DataType::i32) { + return fmt::format("{}", val_i32); + } else if (dt == DataType::i64) { + return fmt::format("{}", val_i64); + } else if (dt == DataType::f64) { + return fmt::format("{}", val_f64); + } else if (dt == DataType::i8) { + return fmt::format("{}", val_i8); + } else if (dt == DataType::i16) { + return fmt::format("{}", val_i16); + } else if (dt == DataType::u8) { + return fmt::format("{}", val_u8); + } else if (dt == DataType::u16) { + return fmt::format("{}", val_u16); + } else if (dt == DataType::u32) { + return fmt::format("{}", val_u32); + } else if (dt == DataType::u64) { + return fmt::format("{}", val_u64); + } else { + TI_P(data_type_name(dt)); + TI_NOT_IMPLEMENTED + return ""; + } +} + +bool TypedConstant::equal_type_and_value(const TypedConstant &o) const { + if (dt != o.dt) + return false; + if (dt == DataType::f32) { + return val_f32 == o.val_f32; + } else if (dt == DataType::i32) { + return val_i32 == o.val_i32; + } else if (dt == DataType::i64) { + return val_i64 == o.val_i64; + } else if (dt == DataType::f64) { + return val_f64 == o.val_f64; + } else if (dt == DataType::i8) { + return val_i8 == o.val_i8; + } else if (dt == DataType::i16) { + return val_i16 == o.val_i16; + } else if (dt == DataType::u8) { + return val_u8 == o.val_u8; + } else if (dt == DataType::u16) { + return val_u16 == o.val_u16; + } else if (dt == DataType::u32) { + return val_u32 == o.val_u32; + } else if (dt == DataType::u64) { + return val_u64 == o.val_u64; + } else { + TI_NOT_IMPLEMENTED + return false; + } +} + +int32 &TypedConstant::val_int32() { + TI_ASSERT(get_data_type() == dt); + return val_i32; +} + +float32 &TypedConstant::val_float32() { + TI_ASSERT(get_data_type() == dt); + return val_f32; +} + +int64 &TypedConstant::val_int64() { + TI_ASSERT(get_data_type() == dt); + return val_i64; +} + +float64 &TypedConstant::val_float64() { + TI_ASSERT(get_data_type() == dt); + return val_f64; +} + +int8 &TypedConstant::val_int8() { + TI_ASSERT(get_data_type() == dt); + return val_i8; +} + +int16 &TypedConstant::val_int16() { + TI_ASSERT(get_data_type() == dt); + return val_i16; +} + +uint8 &TypedConstant::val_uint8() { + TI_ASSERT(get_data_type() == dt); + return val_u8; +} + +uint16 &TypedConstant::val_uint16() { + TI_ASSERT(get_data_type() == dt); + return val_u16; +} + +uint32 &TypedConstant::val_uint32() { + TI_ASSERT(get_data_type() == dt); + return val_u32; +} + +uint64 &TypedConstant::val_uint64() { + TI_ASSERT(get_data_type() == dt); + return val_u64; +} + +int64 TypedConstant::val_int() const { + TI_ASSERT(is_signed(dt)); + if (dt == DataType::i32) { + return val_i32; + } else if (dt == DataType::i64) { + return val_i64; + } else if (dt == DataType::i8) { + return val_i8; + } else if (dt == DataType::i16) { + return val_i16; + } else { + TI_NOT_IMPLEMENTED + } +} + +uint64 TypedConstant::val_uint() const { + TI_ASSERT(is_unsigned(dt)); + if (dt == DataType::u32) { + return val_u32; + } else if (dt == DataType::u64) { + return val_u64; + } else if (dt == DataType::u8) { + return val_u8; + } else if (dt == DataType::u16) { + return val_u16; + } else { + TI_NOT_IMPLEMENTED + } +} + +float64 TypedConstant::val_float() const { + TI_ASSERT(is_real(dt)); + if (dt == DataType::f32) { + return val_f32; + } else if (dt == DataType::f64) { + return val_f64; + } else { + TI_NOT_IMPLEMENTED + } +} + } // namespace lang void initialize_benchmark() { diff --git a/taichi/lang_util.h b/taichi/lang_util.h index 4c80b1dc53067..f53fca8c03af5 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -199,86 +199,27 @@ class TypedConstant { TypedConstant(float64 x) : dt(DataType::f64), val_f64(x) { } - std::string stringify() const { - if (dt == DataType::f32) { - return fmt::format("{}", val_f32); - } else if (dt == DataType::i32) { - return fmt::format("{}", val_i32); - } else if (dt == DataType::i64) { - return fmt::format("{}", val_i64); - } else if (dt == DataType::f64) { - return fmt::format("{}", val_f64); - } else if (dt == DataType::i8) { - return fmt::format("{}", val_i8); - } else if (dt == DataType::i16) { - return fmt::format("{}", val_i16); - } else if (dt == DataType::u8) { - return fmt::format("{}", val_u8); - } else if (dt == DataType::u16) { - return fmt::format("{}", val_u16); - } else if (dt == DataType::u32) { - return fmt::format("{}", val_u32); - } else if (dt == DataType::u64) { - return fmt::format("{}", val_u64); - } else { - TI_P(data_type_name(dt)); - TI_NOT_IMPLEMENTED - return ""; - } - } + std::string stringify() const; - bool equal_type_and_value(const TypedConstant &o) const { - if (dt != o.dt) - return false; - if (dt == DataType::f32) { - return val_f32 == o.val_f32; - } else if (dt == DataType::i32) { - return val_i32 == o.val_i32; - } else if (dt == DataType::i64) { - return val_i64 == o.val_i64; - } else if (dt == DataType::f64) { - return val_f64 == o.val_f64; - } else if (dt == DataType::i8) { - return val_i8 == o.val_i8; - } else if (dt == DataType::i16) { - return val_i16 == o.val_i16; - } else if (dt == DataType::u8) { - return val_u8 == o.val_u8; - } else if (dt == DataType::u16) { - return val_u16 == o.val_u16; - } else if (dt == DataType::u32) { - return val_u32 == o.val_u32; - } else if (dt == DataType::u64) { - return val_u64 == o.val_u64; - } else { - TI_NOT_IMPLEMENTED - return false; - } - } + bool equal_type_and_value(const TypedConstant &o) const; bool operator==(const TypedConstant &o) const { return equal_type_and_value(o); } - int32 &val_int32() { - TI_ASSERT(get_data_type() == dt); - return val_i32; - } - - float32 &val_float32() { - TI_ASSERT(get_data_type() == dt); - return val_f32; - } - - int64 &val_int64() { - TI_ASSERT(get_data_type() == dt); - return val_i64; - } - - float64 &val_float64() { - TI_ASSERT(get_data_type() == dt); - return val_f64; - } + int32 &val_int32(); + float32 &val_float32(); + int64 &val_int64(); + float64 &val_float64(); + int8 &val_int8(); + int16 &val_int16(); + uint8 &val_uint8(); + uint16 &val_uint16(); + uint32 &val_uint32(); + uint64 &val_uint64(); + int64 val_int() const; + uint64 val_uint() const; + float64 val_float() const; }; inline std::string make_list(const std::vector &data, diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 41a53aab9c52c..2570c6b4ee66d 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -10,6 +10,7 @@ class AlgSimp : public BasicStmtVisitor { public: using BasicStmtVisitor::visit; bool fast_math; + std::vector to_erase; explicit AlgSimp(bool fast_math_) : BasicStmtVisitor(), fast_math(fast_math_) { @@ -30,26 +31,22 @@ class AlgSimp : public BasicStmtVisitor { if (alg_is_zero(rhs)) { // a +-|^ 0 -> a stmt->replace_with(stmt->lhs); - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs)) { // 0 +|^ a -> a stmt->replace_with(stmt->rhs); - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } } else if (stmt->op_type == BinaryOpType::mul || stmt->op_type == BinaryOpType::div) { if (alg_is_one(rhs)) { // a */ 1 -> a stmt->replace_with(stmt->lhs); - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } else if (stmt->op_type == BinaryOpType::mul && alg_is_one(lhs)) { // 1 * a -> a stmt->replace_with(stmt->rhs); - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } else if ((fast_math || is_integral(stmt->ret_type.data_type)) && stmt->op_type == BinaryOpType::mul && (alg_is_zero(lhs) || alg_is_zero(rhs))) { @@ -58,20 +55,17 @@ class AlgSimp : public BasicStmtVisitor { LaneAttribute(stmt->ret_type.data_type)); stmt->replace_with(zero.get()); stmt->parent->insert_before(stmt, VecStatement(std::move(zero))); - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } } else if (stmt->op_type == BinaryOpType::bit_and) { if (alg_is_minus_one(rhs)) { // a & -1 -> a stmt->replace_with(stmt->lhs); - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } else if (alg_is_minus_one(lhs)) { // -1 & a -> a stmt->replace_with(stmt->rhs); - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } } } @@ -82,8 +76,7 @@ class AlgSimp : public BasicStmtVisitor { return; if (!alg_is_zero(cond)) { // this statement has no effect - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } } @@ -93,8 +86,7 @@ class AlgSimp : public BasicStmtVisitor { return; if (!alg_is_zero(cond)) { // this statement has no effect - stmt->parent->erase(stmt); - throw IRModified(); + to_erase.push_back(stmt); } } @@ -105,14 +97,12 @@ class AlgSimp : public BasicStmtVisitor { return false; auto val = stmt->val[0]; auto data_type = stmt->ret_type.data_type; - if (data_type == DataType::i32) - return val.val_int32() == 0; - else if (data_type == DataType::f32) - return val.val_float32() == 0; - else if (data_type == DataType::i64) - return val.val_int64() == 0; - else if (data_type == DataType::f64) - return val.val_float64() == 0; + if (is_real(data_type)) + return val.val_float() == 0; + else if (is_signed(data_type)) + return val.val_int() == 0; + else if (is_unsigned(data_type)) + return val.val_uint() == 0; else { TI_NOT_IMPLEMENTED return false; @@ -126,14 +116,12 @@ class AlgSimp : public BasicStmtVisitor { return false; auto val = stmt->val[0]; auto data_type = stmt->ret_type.data_type; - if (data_type == DataType::i32) - return val.val_int32() == 1; - else if (data_type == DataType::f32) - return val.val_float32() == 1; - else if (data_type == DataType::i64) - return val.val_int64() == 1; - else if (data_type == DataType::f64) - return val.val_float64() == 1; + if (is_real(data_type)) + return val.val_float() == 1; + else if (is_signed(data_type)) + return val.val_int() == 1; + else if (is_unsigned(data_type)) + return val.val_uint() == 1; else { TI_NOT_IMPLEMENTED return false; @@ -147,38 +135,37 @@ class AlgSimp : public BasicStmtVisitor { return false; auto val = stmt->val[0]; auto data_type = stmt->ret_type.data_type; - if (data_type == DataType::i32) - return val.val_int32() == -1; - else if (data_type == DataType::f32) - return val.val_float32() == -1; - else if (data_type == DataType::i64) - return val.val_int64() == -1; - else if (data_type == DataType::f64) - return val.val_float64() == -1; + if (is_real(data_type)) + return val.val_float() == -1; + else if (is_signed(data_type)) + return val.val_int() == -1; + else if (is_unsigned(data_type)) + return false; else { TI_NOT_IMPLEMENTED - return false; } } - static void run(IRNode *node, bool fast_math) { + static bool run(IRNode *node, bool fast_math) { AlgSimp simplifier(fast_math); + bool modified = false; while (true) { - bool modified = false; - try { - node->accept(&simplifier); - } catch (IRModified) { - modified = true; - } - if (!modified) + node->accept(&simplifier); + if (simplifier.to_erase.empty()) break; + modified = true; + for (auto &stmt : simplifier.to_erase) { + stmt->parent->erase(stmt); + } + simplifier.to_erase.clear(); } + return modified; } }; namespace irpass { -void alg_simp(IRNode *root, const CompileConfig &config) { +bool alg_simp(IRNode *root, const CompileConfig &config) { return AlgSimp::run(root, config.fast_math); } From 5b9b8c795373de103372875180c071ca49dfa95d Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 25 May 2020 20:48:26 -0400 Subject: [PATCH 2/4] fix tests --- taichi/transforms/alg_simp.cpp | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 2570c6b4ee66d..49191040bd592 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -11,6 +11,7 @@ class AlgSimp : public BasicStmtVisitor { using BasicStmtVisitor::visit; bool fast_math; std::vector to_erase; + std::vector, Stmt *>> to_insert_before; explicit AlgSimp(bool fast_math_) : BasicStmtVisitor(), fast_math(fast_math_) { @@ -51,11 +52,21 @@ class AlgSimp : public BasicStmtVisitor { stmt->op_type == BinaryOpType::mul && (alg_is_zero(lhs) || alg_is_zero(rhs))) { // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 - auto zero = Stmt::make( - LaneAttribute(stmt->ret_type.data_type)); - stmt->replace_with(zero.get()); - stmt->parent->insert_before(stmt, VecStatement(std::move(zero))); - to_erase.push_back(stmt); + if (alg_is_zero(lhs) && lhs->ret_type.data_type == + stmt->ret_type.data_type) { + stmt->replace_with(stmt->lhs); + to_erase.push_back(stmt); + } else if (alg_is_zero(rhs) && rhs->ret_type.data_type == + stmt->ret_type.data_type) { + stmt->replace_with(stmt->rhs); + to_erase.push_back(stmt); + } else { + auto zero = Stmt::make( + LaneAttribute(stmt->ret_type.data_type)); + stmt->replace_with(zero.get()); + to_insert_before.emplace_back(std::move(zero), stmt); + to_erase.push_back(stmt); + } } } else if (stmt->op_type == BinaryOpType::bit_and) { if (alg_is_minus_one(rhs)) { @@ -151,12 +162,16 @@ class AlgSimp : public BasicStmtVisitor { bool modified = false; while (true) { node->accept(&simplifier); - if (simplifier.to_erase.empty()) + if (simplifier.to_erase.empty() && simplifier.to_insert_before.empty()) break; modified = true; + for (auto &i : simplifier.to_insert_before) { + i.second->insert_before_me(std::move(i.first)); + } for (auto &stmt : simplifier.to_erase) { stmt->parent->erase(stmt); } + simplifier.to_insert_before.clear(); simplifier.to_erase.clear(); } return modified; From 7edf47b85b3aac4ec376f401f05c85409bf476fd Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 25 May 2020 20:51:03 -0400 Subject: [PATCH 3/4] [skip ci] code format --- taichi/transforms/alg_simp.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 49191040bd592..916758a37ebfe 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -52,12 +52,12 @@ class AlgSimp : public BasicStmtVisitor { stmt->op_type == BinaryOpType::mul && (alg_is_zero(lhs) || alg_is_zero(rhs))) { // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 - if (alg_is_zero(lhs) && lhs->ret_type.data_type == - stmt->ret_type.data_type) { + if (alg_is_zero(lhs) && + lhs->ret_type.data_type == stmt->ret_type.data_type) { stmt->replace_with(stmt->lhs); to_erase.push_back(stmt); - } else if (alg_is_zero(rhs) && rhs->ret_type.data_type == - stmt->ret_type.data_type) { + } else if (alg_is_zero(rhs) && + rhs->ret_type.data_type == stmt->ret_type.data_type) { stmt->replace_with(stmt->rhs); to_erase.push_back(stmt); } else { From ded7b77829596b1d64660476246af7459158e71d Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 25 May 2020 20:58:51 -0400 Subject: [PATCH 4/4] No need to return after TI_NOT_IMPLEMENTED --- taichi/transforms/alg_simp.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 916758a37ebfe..f663a6c006a05 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -116,7 +116,6 @@ class AlgSimp : public BasicStmtVisitor { return val.val_uint() == 0; else { TI_NOT_IMPLEMENTED - return false; } } @@ -135,7 +134,6 @@ class AlgSimp : public BasicStmtVisitor { return val.val_uint() == 1; else { TI_NOT_IMPLEMENTED - return false; } }