Skip to content

Commit

Permalink
[ir] Make lower_scalar_pointer testable (#2407)
Browse files Browse the repository at this point in the history
* [ir] Make lower_scalar_pointer testable

* comments

* Update taichi/analysis/arithmetic_interpretor.h

Co-authored-by: xumingkuan <xumingkuan0721@126.com>

Co-authored-by: xumingkuan <xumingkuan0721@126.com>
  • Loading branch information
k-ye and xumingkuan authored Jun 8, 2021
1 parent 15465fb commit 1aa4ed1
Show file tree
Hide file tree
Showing 8 changed files with 451 additions and 119 deletions.
33 changes: 33 additions & 0 deletions taichi/analysis/arithmetic_interpretor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,35 @@ class EvalVisitor : public IRVisitor {
}
}

void visit(BitExtractStmt *stmt) override {
auto val_opt = context_.maybe_get(stmt->input);
if (!val_opt) {
failed_ = true;
return;
}
const uint64_t mask = (1ULL << (stmt->bit_end - stmt->bit_begin)) - 1;
auto val = val_opt.value().val_int();
val = (val >> stmt->bit_begin) & mask;
insert_to_ctx(stmt, stmt->ret_type, val);
}

void visit(LinearizeStmt *stmt) override {
int64_t val = 0;
for (int i = 0; i < (int)stmt->inputs.size(); ++i) {
auto idx_opt = context_.maybe_get(stmt->inputs[i]);
if (!idx_opt) {
failed_ = true;
return;
}
val = (val * stmt->strides[i]) + idx_opt.value().val_int();
}
insert_to_ctx(stmt, stmt->ret_type, val);
}

void visit(Stmt *stmt) override {
if (context_.should_ignore(stmt)) {
return;
}
failed_ = (context_.maybe_get(stmt) == std::nullopt);
}

Expand Down Expand Up @@ -135,6 +163,11 @@ class EvalVisitor : public IRVisitor {
context_.insert(stmt, TypedConstant(dt, val_opt.value()));
}

template <typename T>
void insert_to_ctx(const Stmt *stmt, DataType dt, const T &val) {
context_.insert(stmt, TypedConstant(dt, val));
}

EvalContext context_;
bool failed_{false};
};
Expand Down
37 changes: 36 additions & 1 deletion taichi/analysis/arithmetic_interpretor.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <optional>
#include <unordered_set>
#include <unordered_map>

#include "taichi/ir/statements.h"
Expand All @@ -20,11 +21,23 @@ class ArithmeticInterpretor {
*/
class EvalContext {
public:
/**
* Pre-defines a value for statement @param s.
*
* @param s: Statement to be evaluated
* @param c: Predefined value
*/
EvalContext &insert(const Stmt *s, TypedConstant c) {
map_[s] = c;
return *this;
}

/**
* Tries to get the evaluated value for statement @param s.
*
* @param s: Statement to get
* @return: The evaluated value, empty if not found.
*/
std::optional<TypedConstant> maybe_get(const Stmt *s) const {
auto itr = map_.find(s);
if (itr == map_.end()) {
Expand All @@ -33,8 +46,30 @@ class ArithmeticInterpretor {
return itr->second;
}

/**
* Tells the interpretor to ignore statement @param s.
*
* This is effective only for statements that are not supported by
* ArithmeticInterpretor.
*
* @param s: Statement to ignore
*/
void ignore(const Stmt *s) {
ignored_.insert(s);
}

/**
* Checks if statement @param s is ignored.
*
* @return: True if ignored
*/
bool should_ignore(const Stmt *s) {
return ignored_.count(s) > 0;
}

private:
std::unordered_map<const Stmt *, TypedConstant> map_;
std::unordered_set<const Stmt *> ignored_;
};

/**
Expand All @@ -55,7 +90,7 @@ class ArithmeticInterpretor {
* Evaluates the sequence of CHI as defined in |region|.
* @param region: A sequence of CHI statements to be evaluated
* @param init_ctx: This context can mock the result for certain types of
* statements that are not supported, or cannot be evaluated statically.
* statements that are not supported, or cannot be evaluated statically.
*/
std::optional<TypedConstant> evaluate(const CodeRegion &region,
const EvalContext &init_ctx) const;
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void IRBuilder::reset() {
insert_point_.position = 0;
}

std::unique_ptr<IRNode> IRBuilder::extract_ir() {
std::unique_ptr<Block> IRBuilder::extract_ir() {
auto result = std::move(root_);
reset();
return result;
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class IRBuilder {
void reset();

// Extract the IR.
std::unique_ptr<IRNode> extract_ir();
std::unique_ptr<Block> extract_ir();

// General inserter. Returns stmt.get().
template <typename XStmt>
Expand Down Expand Up @@ -247,7 +247,7 @@ class IRBuilder {
}

private:
std::unique_ptr<IRNode> root_{nullptr};
std::unique_ptr<Block> root_{nullptr};
InsertPoint insert_point_;
};

Expand Down
Loading

0 comments on commit 1aa4ed1

Please sign in to comment.