Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove legacy BitExtractStmt #7221

Merged
merged 2 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions taichi/analysis/arithmetic_interpretor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,6 @@ class EvalVisitor : public IRVisitor {
}
}

void visit(BitExtractStmt *stmt) override {
auto val_opt = context_.maybe_get(stmt->input);
if (!val_opt) {
failed_ = true;
return;
}
const uint64_t mask = (1ULL << (stmt->bit_end - stmt->bit_begin)) - 1;
auto val = val_opt.value().val_int();
val = (val >> stmt->bit_begin) & mask;
insert_to_ctx(stmt, stmt->ret_type, val);
}

void visit(LinearizeStmt *stmt) override {
int64_t val = 0;
for (int i = 0; i < (int)stmt->inputs.size(); ++i) {
Expand Down
6 changes: 0 additions & 6 deletions taichi/codegen/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,6 @@ class CCTransformer : public IRVisitor {
stmt->tb);
}

void visit(BitExtractStmt *stmt) override {
emit("{} = (({} >> {}) & ((1 << {}) - 1));",
define_var("Ti_i32", stmt->raw_name()), stmt->input->raw_name(),
stmt->bit_begin, stmt->bit_end - stmt->bit_begin);
}

std::string define_var(std::string const &type, std::string const &name) {
if (C90_COMPAT) {
emit_header("{} {};", type, name);
Expand Down
7 changes: 0 additions & 7 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1698,13 +1698,6 @@ void TaskCodeGenLLVM::visit(GetRootStmt *stmt) {
0));
}

void TaskCodeGenLLVM::visit(BitExtractStmt *stmt) {
int mask = (1u << (stmt->bit_end - stmt->bit_begin)) - 1;
llvm_val[stmt] = builder->CreateAnd(
builder->CreateLShr(llvm_val[stmt->input], stmt->bit_begin),
tlctx->get_constant(mask));
}

void TaskCodeGenLLVM::visit(LinearizeStmt *stmt) {
llvm::Value *val = tlctx->get_constant(0);
for (int i = 0; i < (int)stmt->inputs.size(); i++) {
Expand Down
2 changes: 0 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,6 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(GetRootStmt *stmt) override;

void visit(BitExtractStmt *stmt) override;

void visit(LinearizeStmt *stmt) override;

void visit(IntegerOffsetStmt *stmt) override;
Expand Down
10 changes: 0 additions & 10 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,16 +471,6 @@ class TaskCodegen : public IRVisitor {
ir_->register_value(stmt->raw_name(), val);
}

void visit(BitExtractStmt *stmt) override {
spirv::Value input_val = ir_->query_value(stmt->input->raw_name());
auto stype = input_val.stype;
spirv::Value tmp0 = ir_->int_immediate_number(stype, stmt->bit_begin);
spirv::Value tmp1 =
ir_->int_immediate_number(stype, stmt->bit_end - stmt->bit_begin);
spirv::Value val = ir_->bit_field_extract(input_val, tmp0, tmp1);
ir_->register_value(stmt->raw_name(), val);
}

void visit(LoopIndexStmt *stmt) override {
const auto stmt_name = stmt->raw_name();
if (stmt->loop->is<OffloadedStmt>()) {
Expand Down
1 change: 0 additions & 1 deletion taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ PER_STATEMENT(AdStackAccAdjointStmt)
// SNode Micro Ops
PER_STATEMENT(GetRootStmt)
PER_STATEMENT(IntegerOffsetStmt)
PER_STATEMENT(BitExtractStmt)
PER_STATEMENT(LinearizeStmt)
PER_STATEMENT(SNodeLookupStmt)
PER_STATEMENT(GetChStmt)
Expand Down
30 changes: 0 additions & 30 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,36 +404,6 @@ ClearListStmt::ClearListStmt(SNode *snode) : snode(snode) {
TI_STMT_REG_FIELDS;
}

int LoopIndexStmt::max_num_bits() const {
if (auto range_for = loop->cast<RangeForStmt>()) {
// Return the max number of bits only if both begin and end are
// non-negative consts.
if (!range_for->begin->is<ConstStmt>() || !range_for->end->is<ConstStmt>())
return -1;
auto begin = range_for->begin->as<ConstStmt>();
if (begin->val.val_int() < 0)
return -1;
auto end = range_for->end->as<ConstStmt>();
return (int)bit::ceil_log2int(end->val.val_int());
} else if (auto struct_for = loop->cast<StructForStmt>()) {
return struct_for->snode->get_num_bits(index);
} else if (auto offload = loop->cast<OffloadedStmt>()) {
if (offload->task_type == OffloadedStmt::TaskType::range_for) {
if (!offload->const_begin || !offload->const_end)
return -1;
if (offload->begin_value < 0)
return -1;
return bit::ceil_log2int(offload->end_value);
} else if (offload->task_type == OffloadedStmt::TaskType::struct_for) {
return offload->snode->get_num_bits(index);
} else {
TI_NOT_IMPLEMENTED
}
} else {
TI_NOT_IMPLEMENTED
}
}

BitStructType *BitStructStoreStmt::get_bit_struct() const {
return ptr->as<SNodeLookupStmt>()->snode->dt->as<BitStructType>();
}
Expand Down
27 changes: 0 additions & 27 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1196,30 +1196,6 @@ class LinearizeStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* Extract an interval of bits from an integral value.
* Equivalent to (|input| >> |bit_begin|) &
* ((1 << (|bit_end| - |bit_begin|)) - 1).
*/
class BitExtractStmt : public Stmt {
public:
Stmt *input;
int bit_begin, bit_end;
bool simplified;
BitExtractStmt(Stmt *input, int bit_begin, int bit_end)
: input(input), bit_begin(bit_begin), bit_end(bit_end) {
simplified = false;
TI_STMT_REG_FIELDS;
}

bool has_global_side_effect() const override {
return false;
}

TI_STMT_DEF_FIELDS(ret_type, input, bit_begin, bit_end, simplified);
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* The SNode root.
*/
Expand Down Expand Up @@ -1440,9 +1416,6 @@ class LoopIndexStmt : public Stmt {
return false;
}

// Return the number of bits of the loop, or -1 if unknown.
int max_num_bits() const;

TI_STMT_DEF_FIELDS(ret_type, loop, index);
TI_DEFINE_ACCEPT_AND_CLONE
};
Expand Down
8 changes: 2 additions & 6 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor {
if (execute_once_)
return;
if (!(stmt->is<UnaryOpStmt>() || stmt->is<BinaryOpStmt>() ||
stmt->is<TernaryOpStmt>() || stmt->is<BitExtractStmt>() ||
stmt->is<GlobalLoadStmt>() || stmt->is<AllocaStmt>())) {
stmt->is<TernaryOpStmt>() || stmt->is<GlobalLoadStmt>() ||
stmt->is<AllocaStmt>())) {
// TODO: this list may be incomplete
return;
}
Expand Down Expand Up @@ -680,10 +680,6 @@ class ADTransform : public IRVisitor {
// do nothing
}

void visit(BitExtractStmt *stmt) override {
// do nothing
}

void visit(IntegerOffsetStmt *stmt) override {
// do nothing
}
Expand Down
19 changes: 0 additions & 19 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,25 +208,6 @@ class ConstantFold : public BasicStmtVisitor {
}
}

void visit(BitExtractStmt *stmt) override {
auto input = stmt->input->cast<ConstStmt>();
if (!input)
return;
std::unique_ptr<Stmt> result_stmt;
if (is_signed(input->val.dt)) {
auto result = (input->val.val_int() >> stmt->bit_begin) &
((1LL << (stmt->bit_end - stmt->bit_begin)) - 1);
result_stmt = Stmt::make<ConstStmt>(TypedConstant(input->val.dt, result));
} else {
auto result = (input->val.val_uint() >> stmt->bit_begin) &
((1LL << (stmt->bit_end - stmt->bit_begin)) - 1);
result_stmt = Stmt::make<ConstStmt>(TypedConstant(input->val.dt, result));
}
stmt->replace_usages_with(result_stmt.get());
modifier.insert_before(stmt, std::move(result_stmt));
modifier.erase(stmt);
}

static bool run(IRNode *node,
Program *program,
const CompileConfig &compile_config) {
Expand Down
5 changes: 0 additions & 5 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,11 +550,6 @@ class IRPrinter : public IRVisitor {
stmt->input->name(), stmt->offset);
}

void visit(BitExtractStmt *stmt) override {
print("{}{} = bit_extract({}) bit_range=[{}, {})", stmt->type_hint(),
stmt->name(), stmt->input->name(), stmt->bit_begin, stmt->bit_end);
}

void visit(GetRootStmt *stmt) override {
if (stmt->root() == nullptr)
print("{}{} = get root nullptr", stmt->type_hint(), stmt->name());
Expand Down
92 changes: 0 additions & 92 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,98 +134,6 @@ class BasicBlockSimplify : public IRVisitor {
}
}

void visit(BitExtractStmt *stmt) override {
if (is_done(stmt))
return;

// step 0: eliminate empty extraction
if (stmt->bit_begin == stmt->bit_end) {
auto zero = Stmt::make<ConstStmt>(TypedConstant(0));
stmt->replace_usages_with(zero.get());
modifier.insert_after(stmt, std::move(zero));
modifier.erase(stmt);
return;
}

// step 1: eliminate useless extraction of another BitExtractStmt
if (stmt->bit_begin == 0 && stmt->input->is<BitExtractStmt>()) {
auto bstmt = stmt->input->as<BitExtractStmt>();
if (stmt->bit_end >= bstmt->bit_end - bstmt->bit_begin) {
stmt->replace_usages_with(bstmt);
modifier.erase(stmt);
return;
}
}

// step 2: eliminate useless extraction of a LoopIndexStmt
if (stmt->bit_begin == 0 && stmt->input->is<LoopIndexStmt>()) {
auto bstmt = stmt->input->as<LoopIndexStmt>();
const int max_num_bits = bstmt->max_num_bits();
if (max_num_bits != -1 && stmt->bit_end >= max_num_bits) {
stmt->replace_usages_with(bstmt);
modifier.erase(stmt);
return;
}
}

// step 3: try weakening when a struct for is used
if (current_struct_for && !stmt->simplified) {
const int num_loop_vars = current_struct_for->snode->num_active_indices;
for (int k = 0; k < num_loop_vars; k++) {
auto diff = irpass::analysis::value_diff_loop_index(
stmt->input, current_struct_for, k);
if (diff.linear_related() && diff.certain()) {
// case 1: last loop var, vectorized, has assumption on vec size
if (k == num_loop_vars - 1) {
auto load = Stmt::make<LoopIndexStmt>(current_struct_for, k);
load->ret_type = PrimitiveType::i32;
stmt->input = load.get();
int64 bound = 1LL << stmt->bit_end;
auto offset = (((int64)diff.low % bound + bound) % bound) &
~((1LL << (stmt->bit_begin)) - 1);
auto load_addr = load.get();
modifier.insert_before(stmt, std::move(load));
offset = diff.low; // TODO: Vectorization
if (stmt->bit_begin == 0 && bound == 1) { // TODO: Vectorization
// TODO: take care of cases where vectorization width != z
// dimension of the block
auto offset_stmt = Stmt::make<IntegerOffsetStmt>(stmt, offset);
stmt->replace_usages_with(offset_stmt.get());
// fix the offset stmt operand
offset_stmt->as<IntegerOffsetStmt>()->input = stmt;
modifier.insert_after(stmt, std::move(offset_stmt));
} else {
if (offset != 0) {
auto offset_const = Stmt::make<ConstStmt>(
TypedConstant(PrimitiveType::i32, offset));
auto sum = Stmt::make<BinaryOpStmt>(
BinaryOpType::add, load_addr, offset_const.get());
stmt->input = sum.get();
modifier.insert_before(stmt, std::move(offset_const));
modifier.insert_before(stmt, std::move(offset_const));
}
}
} else {
// insert constant
auto load = Stmt::make<LoopIndexStmt>(current_struct_for, k);
load->ret_type = PrimitiveType::i32;
auto constant = Stmt::make<ConstStmt>(TypedConstant(diff.low));
auto add = Stmt::make<BinaryOpStmt>(BinaryOpType::add, load.get(),
constant.get());
add->ret_type = PrimitiveType::i32;
stmt->input = add.get();
modifier.insert_before(stmt, std::move(load));
modifier.insert_before(stmt, std::move(constant));
modifier.insert_before(stmt, std::move(add));
}
stmt->simplified = true;
return;
}
}
}

set_done(stmt);
}
template <typename T>
static bool identical_vectors(const std::vector<T> &a,
const std::vector<T> &b) {
Expand Down
4 changes: 0 additions & 4 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,6 @@ class TypeCheck : public IRVisitor {
stmt->all_blocks_accept(this);
}

void visit(BitExtractStmt *stmt) override {
stmt->ret_type = stmt->input->ret_type;
}

void visit(LinearizeStmt *stmt) override {
stmt->ret_type = PrimitiveType::i32;
}
Expand Down