From 74a3072638be900ea9511b2529b8b322d2c85f80 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Sat, 8 Oct 2022 15:03:03 +0800 Subject: [PATCH] [spirv] Generate OpBitFieldUExtract for BitExtractStmt (#6208) Issue: https://github.com/taichi-dev/taichi/pull/6141#issuecomment-1255339288, #6134 ### Brief Summary In SPIR-V there's actually `OpBitFieldUExtract` which does the job of `BitExtractStmt` perfectly, so let's stop demoting `BitExtractStmt` and let codegen handle it with the best instruction(s). --- taichi/codegen/spirv/spirv_codegen.cpp | 8 +------- taichi/codegen/spirv/spirv_ir_builder.cpp | 7 +++++++ taichi/codegen/spirv/spirv_ir_builder.h | 1 + taichi/transforms/demote_operations.cpp | 19 ------------------- 4 files changed, 9 insertions(+), 26 deletions(-) diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 6fb86dbb6091e..bcd83fc8c4043 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -465,13 +465,7 @@ class TaskCodegen : public IRVisitor { spirv::Value tmp0 = ir_->int_immediate_number(stype, stmt->bit_begin); spirv::Value tmp1 = ir_->int_immediate_number(stype, stmt->bit_end - stmt->bit_begin); - spirv::Value tmp2 = - ir_->make_value(spv::OpShiftRightArithmetic, stype, input_val, tmp0); - spirv::Value tmp3 = - ir_->make_value(spv::OpShiftLeftLogical, stype, - ir_->int_immediate_number(stype, 1), tmp1); - spirv::Value tmp4 = ir_->sub(tmp3, ir_->int_immediate_number(stype, 1)); - spirv::Value val = ir_->make_value(spv::OpBitwiseAnd, stype, tmp2, tmp4); + spirv::Value val = ir_->bit_field_extract(input_val, tmp0, tmp1); ir_->register_value(stmt->raw_name(), val); } diff --git a/taichi/codegen/spirv/spirv_ir_builder.cpp b/taichi/codegen/spirv/spirv_ir_builder.cpp index 2fa13b459c6b3..f419fc52736bd 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.cpp +++ b/taichi/codegen/spirv/spirv_ir_builder.cpp @@ -1037,6 +1037,13 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual); DEFINE_BUILDER_CMP_UOP(eq, Equal); DEFINE_BUILDER_CMP_UOP(ne, NotEqual); +Value IRBuilder::bit_field_extract(Value base, Value offset, Value count) { + TI_ASSERT(is_integral(base.stype.dt)); + TI_ASSERT(is_integral(offset.stype.dt)); + TI_ASSERT(is_integral(count.stype.dt)); + return make_value(spv::OpBitFieldUExtract, base.stype, base, offset, count); +} + Value IRBuilder::select(Value cond, Value a, Value b) { TI_ASSERT(a.stype.id == b.stype.id); TI_ASSERT(cond.stype.id == t_bool_.id); diff --git a/taichi/codegen/spirv/spirv_ir_builder.h b/taichi/codegen/spirv/spirv_ir_builder.h index 3f5c280037cf3..c1bbabbc68814 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.h +++ b/taichi/codegen/spirv/spirv_ir_builder.h @@ -445,6 +445,7 @@ class IRBuilder { Value le(Value a, Value b); Value gt(Value a, Value b); Value ge(Value a, Value b); + Value bit_field_extract(Value base, Value offset, Value count); Value select(Value cond, Value a, Value b); // Create a cast that cast value to dst_type diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index e72054a5ba64b..4ab8ae3fa2143 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -61,25 +61,6 @@ class DemoteOperations : public BasicStmtVisitor { return floor; } - void visit(BitExtractStmt *stmt) override { - // @ti.func - // def bit_extract(input, begin, end): - // return (input >> begin) & ((1 << (end - begin)) - 1) - VecStatement statements; - auto begin = statements.push_back( - TypedConstant(stmt->input->ret_type, stmt->bit_begin)); - auto input_sar_begin = statements.push_back( - BinaryOpType::bit_sar, stmt->input, begin); - auto mask = statements.push_back(TypedConstant( - stmt->input->ret_type, (1LL << (stmt->bit_end - stmt->bit_begin)) - 1)); - auto ret = statements.push_back(BinaryOpType::bit_and, - input_sar_begin, mask); - ret->ret_type = stmt->ret_type; - stmt->replace_usages_with(ret); - modifier.insert_before(stmt, std::move(statements)); - modifier.erase(stmt); - } - void visit(BinaryOpStmt *stmt) override { auto lhs = stmt->lhs; auto rhs = stmt->rhs;