Skip to content

Commit

Permalink
[ir] Add ArithmeticInterpretor to evaluate a subset of CHI IR (#2342)
Browse files Browse the repository at this point in the history
* [ir] Add ArithmeticInterpretor to evaluate a subset of CHI IR

* fix
  • Loading branch information
k-ye authored May 14, 2021
1 parent 0a25bf5 commit c824a8b
Showing 3 changed files with 218 additions and 0 deletions.
152 changes: 152 additions & 0 deletions taichi/analysis/arithmetic_interpretor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#include "taichi/analysis/arithmetic_interpretor.h"

#include <algorithm>
#include <type_traits>
#include <vector>

#include "taichi/ir/type_utils.h"
#include "taichi/ir/visitors.h"

namespace taichi {
namespace lang {
namespace {

using CodeRegion = ArithmeticInterpretor::CodeRegion;
using EvalContext = ArithmeticInterpretor::EvalContext;

std::vector<Stmt *> get_raw_statements(const Block *block) {
const auto &stmts = block->statements;
std::vector<Stmt *> res(stmts.size());
std::transform(stmts.begin(), stmts.end(), res.begin(),
[](const std::unique_ptr<Stmt> &s) { return s.get(); });
return res;
}

class EvalVisitor : public IRVisitor {
public:
explicit EvalVisitor() {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}

std::optional<TypedConstant> run(const CodeRegion &region,
const EvalContext &init_ctx) {
context_ = init_ctx;
failed_ = false;

auto stmts = get_raw_statements(region.block);
if (stmts.empty()) {
return std::nullopt;
}
auto *begin_stmt = (region.begin == nullptr) ? stmts.front() : region.begin;
auto *end_stmt = (region.end == nullptr) ? stmts.back() : region.end;

auto cur_iter = std::find(stmts.begin(), stmts.end(), begin_stmt);
auto end_iter = std::find(stmts.begin(), stmts.end(), end_stmt);
if ((cur_iter == stmts.end()) || (end_iter == stmts.end())) {
return std::nullopt;
}
Stmt *cur_stmt = nullptr;
while (cur_iter != end_iter) {
cur_stmt = *cur_iter;
cur_stmt->accept(this);
if (failed_) {
return std::nullopt;
}
++cur_iter;
}
return context_.maybe_get(cur_stmt);
}

void visit(ConstStmt *stmt) override {
TI_ASSERT(stmt->val.size() == 1);
context_.insert(stmt, stmt->val.data[0]);
}

void visit(BinaryOpStmt *stmt) override {
auto lhs_opt = context_.maybe_get(stmt->lhs);
auto rhs_opt = context_.maybe_get(stmt->rhs);
if (!lhs_opt || !rhs_opt) {
failed_ = true;
return;
}
auto lhs = lhs_opt.value();
auto rhs = rhs_opt.value();
if (lhs.dt != rhs.dt) {
failed_ = true;
return;
}

const auto op = stmt->op_type;
const auto dt = lhs.dt;
// TODO: Consider using macros to avoid duplication
if (is_real(dt)) {
// Put floating point numbers first because is_signed/unsigned asserts
// that the data type being integral.
auto res_opt = eval_bin_op(lhs.val_float(), rhs.val_float(), op);
insert_or_failed(stmt, dt, res_opt);
} else if (is_signed(dt)) {
auto res_opt = eval_bin_op(lhs.val_int(), rhs.val_int(), op);
insert_or_failed(stmt, dt, res_opt);
} else if (is_unsigned(dt)) {
auto res_opt = eval_bin_op(lhs.val_uint(), rhs.val_uint(), op);
insert_or_failed(stmt, dt, res_opt);
} else {
TI_NOT_IMPLEMENTED;
failed_ = true;
}
}

void visit(Stmt *stmt) override {
failed_ = (context_.maybe_get(stmt) == std::nullopt);
}

private:
template <typename T>
static std::optional<T> eval_bin_op(T lhs, T rhs, BinaryOpType op) {
if (op == BinaryOpType::add) {
return lhs + rhs;
}
if (op == BinaryOpType::sub) {
return lhs - rhs;
}
if (op == BinaryOpType::mul) {
return lhs * rhs;
}
if (op == BinaryOpType::div) {
return lhs / rhs;
}
if constexpr (std::is_integral_v<T>) {
if (op == BinaryOpType::mod) {
return lhs % rhs;
}
}
return std::nullopt;
}

template <typename T>
void insert_or_failed(const Stmt *stmt,
DataType dt,
std::optional<T> val_opt) {
if (!val_opt) {
failed_ = true;
return;
}
context_.insert(stmt, TypedConstant(dt, val_opt.value()));
}

EvalContext context_;
bool failed_{false};
};

} // namespace

std::optional<TypedConstant> ArithmeticInterpretor::evaluate(
const CodeRegion &region,
const EvalContext &init_ctx) const {
EvalVisitor ev;
return ev.run(region, init_ctx);
}

} // namespace lang
} // namespace taichi
65 changes: 65 additions & 0 deletions taichi/analysis/arithmetic_interpretor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#pragma once

#include <optional>
#include <unordered_map>

#include "taichi/ir/statements.h"
#include "taichi/ir/type.h"

namespace taichi {
namespace lang {

/**
* Interprets a sequence of CHI IR statements within a block (acts like a
* VM based on CHI).
*/
class ArithmeticInterpretor {
public:
/**
* Evaluation context that maps from a Stmt to a constant value.
*/
class EvalContext {
public:
EvalContext &insert(const Stmt *s, TypedConstant c) {
map_[s] = c;
return *this;
}

std::optional<TypedConstant> maybe_get(const Stmt *s) const {
auto itr = map_.find(s);
if (itr == map_.end()) {
return std::nullopt;
}
return itr->second;
}

private:
std::unordered_map<const Stmt *, TypedConstant> map_;
};

/**
* Defines the region of CHI statments to be evaluated.
*/
struct CodeRegion {
// Defines the sequence of CHI statements.
Block *block{nullptr};
// The beginning statement within |block| to be evaluated. If nullptr,
// evaluates from the beginning of |block|.
Stmt *begin{nullptr};
// The ending statement (exclusive) within |block| to be evaluated. If
// nullptr, evaluates to the end of |block|.
Stmt *end{nullptr};
};

/**
* Evaluates the sequence of CHI as defined in |region|.
* @param region: A sequence of CHI statements to be evaluated
* @param init_ctx: This context can mock the result for certain types of
* statements that are not supported, or cannot be evaluated statically.
*/
std::optional<TypedConstant> evaluate(const CodeRegion &region,
const EvalContext &init_ctx) const;
};

} // namespace lang
} // namespace taichi
1 change: 1 addition & 0 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
@@ -91,6 +91,7 @@ inline bool is_integral(DataType dt) {
}

inline bool is_signed(DataType dt) {
// Shall we return false if is_integral returns false?
TI_ASSERT(is_integral(dt));
if (auto t = dt->cast<CustomIntType>())
return t->get_is_signed();

0 comments on commit c824a8b

Please sign in to comment.