Skip to content

Commit

Permalink
[spirv] Generate OpBitFieldUExtract for BitExtractStmt (#6208)
Browse files Browse the repository at this point in the history
Issue:
#6141 (comment),
#6134

### Brief Summary

In SPIR-V there's actually `OpBitFieldUExtract` which does the job of
`BitExtractStmt` perfectly, so let's stop demoting `BitExtractStmt` and
let codegen handle it with the best instruction(s).
  • Loading branch information
strongoier authored Oct 8, 2022
1 parent 3dbce29 commit 74a3072
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 26 deletions.
8 changes: 1 addition & 7 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,7 @@ class TaskCodegen : public IRVisitor {
spirv::Value tmp0 = ir_->int_immediate_number(stype, stmt->bit_begin);
spirv::Value tmp1 =
ir_->int_immediate_number(stype, stmt->bit_end - stmt->bit_begin);
spirv::Value tmp2 =
ir_->make_value(spv::OpShiftRightArithmetic, stype, input_val, tmp0);
spirv::Value tmp3 =
ir_->make_value(spv::OpShiftLeftLogical, stype,
ir_->int_immediate_number(stype, 1), tmp1);
spirv::Value tmp4 = ir_->sub(tmp3, ir_->int_immediate_number(stype, 1));
spirv::Value val = ir_->make_value(spv::OpBitwiseAnd, stype, tmp2, tmp4);
spirv::Value val = ir_->bit_field_extract(input_val, tmp0, tmp1);
ir_->register_value(stmt->raw_name(), val);
}

Expand Down
7 changes: 7 additions & 0 deletions taichi/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,13 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual);
DEFINE_BUILDER_CMP_UOP(eq, Equal);
DEFINE_BUILDER_CMP_UOP(ne, NotEqual);

Value IRBuilder::bit_field_extract(Value base, Value offset, Value count) {
TI_ASSERT(is_integral(base.stype.dt));
TI_ASSERT(is_integral(offset.stype.dt));
TI_ASSERT(is_integral(count.stype.dt));
return make_value(spv::OpBitFieldUExtract, base.stype, base, offset, count);
}

Value IRBuilder::select(Value cond, Value a, Value b) {
TI_ASSERT(a.stype.id == b.stype.id);
TI_ASSERT(cond.stype.id == t_bool_.id);
Expand Down
1 change: 1 addition & 0 deletions taichi/codegen/spirv/spirv_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ class IRBuilder {
Value le(Value a, Value b);
Value gt(Value a, Value b);
Value ge(Value a, Value b);
Value bit_field_extract(Value base, Value offset, Value count);
Value select(Value cond, Value a, Value b);

// Create a cast that cast value to dst_type
Expand Down
19 changes: 0 additions & 19 deletions taichi/transforms/demote_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,6 @@ class DemoteOperations : public BasicStmtVisitor {
return floor;
}

void visit(BitExtractStmt *stmt) override {
// @ti.func
// def bit_extract(input, begin, end):
// return (input >> begin) & ((1 << (end - begin)) - 1)
VecStatement statements;
auto begin = statements.push_back<ConstStmt>(
TypedConstant(stmt->input->ret_type, stmt->bit_begin));
auto input_sar_begin = statements.push_back<BinaryOpStmt>(
BinaryOpType::bit_sar, stmt->input, begin);
auto mask = statements.push_back<ConstStmt>(TypedConstant(
stmt->input->ret_type, (1LL << (stmt->bit_end - stmt->bit_begin)) - 1));
auto ret = statements.push_back<BinaryOpStmt>(BinaryOpType::bit_and,
input_sar_begin, mask);
ret->ret_type = stmt->ret_type;
stmt->replace_usages_with(ret);
modifier.insert_before(stmt, std::move(statements));
modifier.erase(stmt);
}

void visit(BinaryOpStmt *stmt) override {
auto lhs = stmt->lhs;
auto rhs = stmt->rhs;
Expand Down

0 comments on commit 74a3072

Please sign in to comment.