Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] [lang] Support SHR operator: ti.bit_shr(x, y) #1871

Merged
merged 7 commits into from
Sep 23, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,12 @@ def bit_sar(a, b):
return _binary_operation(ti_core.expr_bit_sar, ops.rshift, a, b)


@taichi_scope
@binary
def bit_shr(a, b):
return _binary_operation(ti_core.expr_bit_shr, ops.rshift, a, b)
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved


# We don't have logic_and/or instructions yet:
logical_or = bit_or
logical_and = bit_and
Expand Down
9 changes: 7 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,13 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
llvm_val[stmt] =
builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (op == BinaryOpType::bit_sar) {
llvm_val[stmt] =
builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
if (is_signed(stmt->lhs->element_type())) {
llvm_val[stmt] =
builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else {
llvm_val[stmt] =
builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
}
} else if (op == BinaryOpType::max) {
if (is_real(ret_type)) {
llvm_val[stmt] =
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/binary_op.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ PER_BINARY_OP(bit_and)
PER_BINARY_OP(bit_or)
PER_BINARY_OP(bit_xor)
PER_BINARY_OP(bit_shl)
PER_BINARY_OP(bit_shr)
PER_BINARY_OP(bit_sar)
PER_BINARY_OP(cmp_lt)
PER_BINARY_OP(cmp_le)
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/expression_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ DEFINE_EXPRESSION_FUNC(atan2);
DEFINE_EXPRESSION_FUNC(pow);
DEFINE_EXPRESSION_FUNC(truediv);
DEFINE_EXPRESSION_FUNC(floordiv);
DEFINE_EXPRESSION_FUNC(bit_shr)

#undef DEFINE_EXPRESSION_OP_UNARY
#undef DEFINE_EXPRESSION_OP_BINARY
Expand Down
16 changes: 16 additions & 0 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ inline bool constexpr is_unsigned(DataType dt) {
return !is_signed(dt);
}

inline DataType to_unsigned(DataType dt) {
TI_ASSERT(is_signed(dt));
switch (dt) {
case DataType::i8:
return DataType::u8;
case DataType::i16:
return DataType::u16;
case DataType::i32:
return DataType::u32;
case DataType::i64:
return DataType::u32;
TH3CHARLie marked this conversation as resolved.
Show resolved Hide resolved
default:
return DataType::unknown;
}
}

inline bool needs_grad(DataType dt) {
return is_real(dt);
}
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ void export_lang(py::module &m) {
m.def("expr_bit_or", expr_bit_or);
m.def("expr_bit_xor", expr_bit_xor);
m.def("expr_bit_shl", expr_bit_shl);
m.def("expr_bit_shr", expr_bit_shr);
m.def("expr_bit_sar", expr_bit_sar);
m.def("expr_bit_not", expr_bit_not);
m.def("expr_logic_not", expr_logic_not);
Expand Down
23 changes: 23 additions & 0 deletions taichi/transforms/demote_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,29 @@ class DemoteOperations : public BasicStmtVisitor {
modifier.insert_before(stmt, std::move(floor));
modifier.erase(stmt);
}
} else if (stmt->op_type == BinaryOpType::bit_shr &&
is_integral(lhs->element_type()) &&
is_integral(rhs->element_type()) &&
is_signed(lhs->element_type())) {
// @ti.func
// def bit_shr(a, b):
// signed_a = ti.cast(a, ti.uXX)
// shifted = ti.bit_sar(a, b)
// ret = ti.cast(a, ti.iXX)
TH3CHARLie marked this conversation as resolved.
Show resolved Hide resolved
// return ret
auto unsigned_cast = Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, lhs);
unsigned_cast->as<UnaryOpStmt>()->cast_type =
to_unsigned(lhs->element_type());
auto shift = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar,
unsigned_cast.get(), rhs);
auto signed_cast =
Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, shift.get());
signed_cast->as<UnaryOpStmt>()->cast_type = lhs->element_type();
stmt->replace_with(signed_cast.get());
modifier.insert_before(stmt, std::move(unsigned_cast));
modifier.insert_before(stmt, std::move(shift));
modifier.insert_before(stmt, std::move(signed_cast));
modifier.erase(stmt);
}
}

Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_bit_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,19 @@ def sar(a: ti.i32, b: ti.i32) -> ti.i32:
# for negative number
for i in range(n):
assert sar(neg_test_num, i) == -2**(n - i)


@ti.test()
def test_bit_shr():
@ti.kernel
def shr(a: ti.i32, b: ti.i32) -> ti.i32:
return ti.bit_shr(a, b)

n = 8
test_num = 2**n
neg_test_num = -test_num
for i in range(n):
assert shr(test_num, i) == 2**(n - i)
for i in range(n):
offset = 0x100000000 if i > 0 else 0
assert shr(neg_test_num, i) == (neg_test_num + offset) >> i
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test will fail on backends other than LLVM due to unimplemented op.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This includes(but not limited to...) cc, metal and opengl.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch... I believe on these backends SHR is not implemented (but SAR is). How about this: let's demote bit_shr into a series of three operations

  • UnaryOpStmt bit_cast into unsigned
  • BinaryOpStmt bit_sar
  • UnaryOpStmt bit_cast into signed

in this pass

void visit(BinaryOpStmt *stmt) override {

https://github.com/yuanming-hu/taichi/blob/2cc50cfaeb84c6ba4f3df7d3a0caa028daadb51f/taichi/transforms/demote_operations.cpp#L17

So that backend developers don't need to worry about SAR. This is a very late pass in the compilation process - we still need SAR in the IR for certain domain-specific optimizations.

Copy link
Collaborator

@archibate archibate Sep 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that backend developers don't need to worry about SAR. This is a very late pass in the compilation process - we still need SAR in the IR for certain domain-specific optimizations.

Good idea! I'll later add uint support to OpenGL so that SHR works.

EDIT: So bit_sar should act as SHR when operand is uint to make this method work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@archibate That's what I'd also like to discuss. I did some experiments, in the LLVM backend, bit_sar is implemented using CreateAShr which simply copies the MSB without considering the type. While on other backends, bit_sar is expressed with >> which will behave differently according to unsigned/signed information. I think we need to decide which kind of SAR operation we truly want, either a pure one copying the MSB or a one that is more similar to the >> in C.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the discussions. My two cents:

  • In the frontend >> always translates to BinaryOpType::bit_sar and ti.bit_shr always translates to BinaryOpType::bit_shr.
  • In demote_operations.cpp we convert bit_shr on signed integers into three sub-operations as discussed above. Then we only have bit_sar for the backend.
  • In the backend, since we only have bit_sar, its behavior is determined by the type of its operands.

Copy link
Collaborator Author

@TH3CHARLie TH3CHARLie Sep 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yuanming-hu Thanks for the clarification.

The point(or problem) I'd like to discuss is the behavior of bit_sar which bit_shr will eventually rely on. As I wrote in the above comment, in the LLVM backend, bit_sar is implemented with a CreateAShr which directly maps to a sar instruction and simply copies the MSB, that is, it will ignore the type information so casting signed to unsigned will not work since the low-level bits are not changed by type casts. In other backends (at least Metal), bit_sar is implemented using the operator >>, which will consider the type information, that is, it will map to a zext and sar when working on unsigned integers. This is creating different behaviors across backends. We should first decide what kind of behavior we want for bit_sar before we move on.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's follow the Metal behavior bit_sar in backends. And provide bit_shr using demote_operations as described above.
The LLVM backend should do some branches to match the behavior of Metal, e.g.:

// on bit_sar
if (dtype == unsigned) {
  CreateAShr();
} else {
  CreateLShr();
}

Copy link
Member

@yuanming-hu yuanming-hu Sep 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps one source of confusion is that LLVM doesn't distinguish signed/unsigned integers (so does hardware such x64). So SHR always shifts everything and SAR copies the MSB in LLVM.

In Taichi, on unsigned integers bit_shr=bit_sar and they both map to SHR in LLVM. On signed integers, bit_shr maps to SHR in LLVM and bit_sar maps to SAR in LLVM.