From 59cba6fbc993b5d81ac4609ca0586bd5904ec989 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 15 Sep 2020 11:17:30 +0800 Subject: [PATCH 1/6] [IR] [lang] Support SHR operator: ti.bit_shr(x, y) --- python/taichi/lang/ops.py | 3 +++ taichi/codegen/codegen_llvm.cpp | 3 +++ taichi/inc/binary_op.inc.h | 1 + taichi/ir/expression_ops.h | 1 + taichi/python/export_lang.cpp | 1 + tests/python/test_bit_operations.py | 15 +++++++++++++++ 6 files changed, 24 insertions(+) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 00e25bc1fcd01..a0ecb98df8df9 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -419,6 +419,9 @@ def bit_shl(a, b): def bit_sar(a, b): return _binary_operation(ti_core.expr_bit_sar, ops.rshift, a, b) +@binary +def bit_shr(a, b): + return _binary_operation(ti_core.expr_bit_shr, ops.rshift, a, b) # We don't have logic_and/or instructions yet: logical_or = bit_or diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index ea18b14b29e7f..a16a2d29e9958 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -415,6 +415,9 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } else if (op == BinaryOpType::bit_shl) { llvm_val[stmt] = builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + } else if (op == BinaryOpType::bit_shr) { + llvm_val[stmt] = + builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (op == BinaryOpType::bit_sar) { llvm_val[stmt] = builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); diff --git a/taichi/inc/binary_op.inc.h b/taichi/inc/binary_op.inc.h index d2c946fd56d49..cff268e4fed8e 100644 --- a/taichi/inc/binary_op.inc.h +++ b/taichi/inc/binary_op.inc.h @@ -11,6 +11,7 @@ PER_BINARY_OP(bit_and) PER_BINARY_OP(bit_or) PER_BINARY_OP(bit_xor) PER_BINARY_OP(bit_shl) +PER_BINARY_OP(bit_shr) PER_BINARY_OP(bit_sar) PER_BINARY_OP(cmp_lt) PER_BINARY_OP(cmp_le) diff --git a/taichi/ir/expression_ops.h b/taichi/ir/expression_ops.h index d0a0f76035f53..3c349e65dbcdb 100644 --- a/taichi/ir/expression_ops.h +++ b/taichi/ir/expression_ops.h @@ -100,6 +100,7 @@ DEFINE_EXPRESSION_FUNC(atan2); DEFINE_EXPRESSION_FUNC(pow); DEFINE_EXPRESSION_FUNC(truediv); DEFINE_EXPRESSION_FUNC(floordiv); +DEFINE_EXPRESSION_FUNC(bit_shr) #undef DEFINE_EXPRESSION_OP_UNARY #undef DEFINE_EXPRESSION_OP_BINARY diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 64fb5b587fe39..d0757484e6715 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -430,6 +430,7 @@ void export_lang(py::module &m) { m.def("expr_bit_or", expr_bit_or); m.def("expr_bit_xor", expr_bit_xor); m.def("expr_bit_shl", expr_bit_shl); + m.def("expr_bit_shr", expr_bit_shr); m.def("expr_bit_sar", expr_bit_sar); m.def("expr_bit_not", expr_bit_not); m.def("expr_logic_not", expr_logic_not); diff --git a/tests/python/test_bit_operations.py b/tests/python/test_bit_operations.py index faa250e51f635..2fd094a779026 100644 --- a/tests/python/test_bit_operations.py +++ b/tests/python/test_bit_operations.py @@ -29,3 +29,18 @@ def sar(a: ti.i32, b: ti.i32) -> ti.i32: # for negative number for i in range(n): assert sar(neg_test_num, i) == -2**(n - i) + +@ti.test() +def test_bit_shr(): + @ti.kernel + def shr(a: ti.i32, b: ti.i32) -> ti.i32: + return ti.bit_shr(a, b) + + n = 8 + test_num = 2**n + neg_test_num = -test_num + for i in range(n): + assert shr(test_num, i) == 2**(n - i) + for i in range(n): + offset = 0x100000000 if i > 0 else 0 + assert shr(neg_test_num, i) == (neg_test_num + offset) >> i From 98c1b7e27a79d3e60c2e53fe7f3506d08fe0501e Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Mon, 14 Sep 2020 23:20:07 -0400 Subject: [PATCH 2/6] [skip ci] enforce code format --- python/taichi/lang/ops.py | 2 ++ tests/python/test_bit_operations.py | 1 + 2 files changed, 3 insertions(+) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index a0ecb98df8df9..15fab5905f9f1 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -419,10 +419,12 @@ def bit_shl(a, b): def bit_sar(a, b): return _binary_operation(ti_core.expr_bit_sar, ops.rshift, a, b) + @binary def bit_shr(a, b): return _binary_operation(ti_core.expr_bit_shr, ops.rshift, a, b) + # We don't have logic_and/or instructions yet: logical_or = bit_or logical_and = bit_and diff --git a/tests/python/test_bit_operations.py b/tests/python/test_bit_operations.py index 2fd094a779026..f069547a63c39 100644 --- a/tests/python/test_bit_operations.py +++ b/tests/python/test_bit_operations.py @@ -30,6 +30,7 @@ def sar(a: ti.i32, b: ti.i32) -> ti.i32: for i in range(n): assert sar(neg_test_num, i) == -2**(n - i) + @ti.test() def test_bit_shr(): @ti.kernel From 8eaea7767768056673fd410aaecb70092dc6dfd1 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 22 Sep 2020 16:58:10 +0800 Subject: [PATCH 3/6] demote bit_shr, change bit_sar's LLVM implementation --- python/taichi/lang/ops.py | 1 + taichi/codegen/codegen_llvm.cpp | 10 ++++++---- taichi/lang_util.h | 16 ++++++++++++++++ taichi/transforms/demote_operations.cpp | 20 ++++++++++++++++++++ 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index a0ecb98df8df9..852144fc394d2 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -419,6 +419,7 @@ def bit_shl(a, b): def bit_sar(a, b): return _binary_operation(ti_core.expr_bit_sar, ops.rshift, a, b) +@taichi_scope @binary def bit_shr(a, b): return _binary_operation(ti_core.expr_bit_shr, ops.rshift, a, b) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index a16a2d29e9958..0a45c54fa587f 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -415,12 +415,14 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } else if (op == BinaryOpType::bit_shl) { llvm_val[stmt] = builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); - } else if (op == BinaryOpType::bit_shr) { - llvm_val[stmt] = - builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (op == BinaryOpType::bit_sar) { - llvm_val[stmt] = + if (is_signed(stmt->lhs->element_type())) { + llvm_val[stmt] = builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + } else { + llvm_val[stmt] = + builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + } } else if (op == BinaryOpType::max) { if (is_real(ret_type)) { llvm_val[stmt] = diff --git a/taichi/lang_util.h b/taichi/lang_util.h index c5c02d16d09b9..2dea235b17be3 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -109,6 +109,22 @@ inline bool constexpr is_unsigned(DataType dt) { return !is_signed(dt); } +inline DataType to_unsigned(DataType dt) { + TI_ASSERT(is_signed(dt)); + switch (dt) { + case DataType::i8: + return DataType::u8; + case DataType::i16: + return DataType::u16; + case DataType::i32: + return DataType::u32; + case DataType::i64: + return DataType::u32; + default: + return DataType::unknown; + } +} + inline bool needs_grad(DataType dt) { return is_real(dt); } diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index 80b80187745fb..fd3b478ed5104 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -73,6 +73,26 @@ class DemoteOperations : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(floor)); modifier.erase(stmt); } + } else if (stmt->op_type == BinaryOpType::bit_shr && + is_integral(lhs->element_type()) && + is_integral(rhs->element_type()) && + is_signed(lhs->element_type())) { + // @ti.func + // def bit_shr(a, b): + // signed_a = ti.cast(a, ti.uXX) + // shifted = ti.bit_sar(a, b) + // ret = ti.cast(a, ti.iXX) + // return ret + auto unsigned_cast = Stmt::make(UnaryOpType::cast_bits, lhs); + unsigned_cast->as()->cast_type = to_unsigned(lhs->element_type()); + auto shift = Stmt::make(BinaryOpType::bit_sar, unsigned_cast.get(), rhs); + auto signed_cast = Stmt::make(UnaryOpType::cast_bits, shift.get()); + signed_cast->as()->cast_type = lhs->element_type(); + stmt->replace_with(signed_cast.get()); + modifier.insert_before(stmt, std::move(unsigned_cast)); + modifier.insert_before(stmt, std::move(shift)); + modifier.insert_before(stmt, std::move(signed_cast)); + modifier.erase(stmt); } } From cc216bfbae0103535b59da50f77edbf4bb54dec6 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Tue, 22 Sep 2020 04:59:38 -0400 Subject: [PATCH 4/6] [skip ci] enforce code format --- python/taichi/lang/ops.py | 1 + taichi/codegen/codegen_llvm.cpp | 4 ++-- taichi/transforms/demote_operations.cpp | 9 ++++++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index b60aff2c65875..8421d89540062 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -419,6 +419,7 @@ def bit_shl(a, b): def bit_sar(a, b): return _binary_operation(ti_core.expr_bit_sar, ops.rshift, a, b) + @taichi_scope @binary def bit_shr(a, b): diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 0a45c54fa587f..01b1058a497c5 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -418,10 +418,10 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } else if (op == BinaryOpType::bit_sar) { if (is_signed(stmt->lhs->element_type())) { llvm_val[stmt] = - builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { llvm_val[stmt] = - builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::max) { if (is_real(ret_type)) { diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index fd3b478ed5104..1179a68d307dc 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -84,9 +84,12 @@ class DemoteOperations : public BasicStmtVisitor { // ret = ti.cast(a, ti.iXX) // return ret auto unsigned_cast = Stmt::make(UnaryOpType::cast_bits, lhs); - unsigned_cast->as()->cast_type = to_unsigned(lhs->element_type()); - auto shift = Stmt::make(BinaryOpType::bit_sar, unsigned_cast.get(), rhs); - auto signed_cast = Stmt::make(UnaryOpType::cast_bits, shift.get()); + unsigned_cast->as()->cast_type = + to_unsigned(lhs->element_type()); + auto shift = Stmt::make(BinaryOpType::bit_sar, + unsigned_cast.get(), rhs); + auto signed_cast = + Stmt::make(UnaryOpType::cast_bits, shift.get()); signed_cast->as()->cast_type = lhs->element_type(); stmt->replace_with(signed_cast.get()); modifier.insert_before(stmt, std::move(unsigned_cast)); From fae7a3b6571ec5dda56646f1573a521ff22a3ea7 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Wed, 23 Sep 2020 17:29:57 +0800 Subject: [PATCH 5/6] Update taichi/lang_util.h Co-authored-by: Jiafeng Liu --- taichi/lang_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/lang_util.h b/taichi/lang_util.h index 2dea235b17be3..2b8344da61a3c 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -119,7 +119,7 @@ inline DataType to_unsigned(DataType dt) { case DataType::i32: return DataType::u32; case DataType::i64: - return DataType::u32; + return DataType::u64; default: return DataType::unknown; } From 8fc51d9d275ca04f26ef23d4ec3bea8c3f3f7107 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Wed, 23 Sep 2020 17:31:13 +0800 Subject: [PATCH 6/6] Update taichi/transforms/demote_operations.cpp Co-authored-by: Jiafeng Liu --- taichi/transforms/demote_operations.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index 1179a68d307dc..e63e4ad5f8f3b 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -80,8 +80,8 @@ class DemoteOperations : public BasicStmtVisitor { // @ti.func // def bit_shr(a, b): // signed_a = ti.cast(a, ti.uXX) - // shifted = ti.bit_sar(a, b) - // ret = ti.cast(a, ti.iXX) + // shifted = ti.bit_sar(signed_a, b) + // ret = ti.cast(shifted, ti.iXX) // return ret auto unsigned_cast = Stmt::make(UnaryOpType::cast_bits, lhs); unsigned_cast->as()->cast_type =