Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Make lower_scalar_pointer testable #2407

Merged
merged 3 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions taichi/analysis/arithmetic_interpretor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,35 @@ class EvalVisitor : public IRVisitor {
}
}

void visit(BitExtractStmt *stmt) override {
auto val_opt = context_.maybe_get(stmt->input);
if (!val_opt) {
failed_ = true;
return;
}
const uint64_t mask = (1ULL << (stmt->bit_end - stmt->bit_begin)) - 1;
auto val = val_opt.value().val_int();
val = (val >> stmt->bit_begin) & mask;
insert_to_ctx(stmt, stmt->ret_type, val);
}

void visit(LinearizeStmt *stmt) override {
int64_t val = 0;
for (int i = 0; i < (int)stmt->inputs.size(); ++i) {
auto idx_opt = context_.maybe_get(stmt->inputs[i]);
if (!idx_opt) {
failed_ = true;
return;
}
val = (val * stmt->strides[i]) + idx_opt.value().val_int();
}
insert_to_ctx(stmt, stmt->ret_type, val);
}

void visit(Stmt *stmt) override {
if (context_.should_ignore(stmt)) {
return;
}
failed_ = (context_.maybe_get(stmt) == std::nullopt);
}

Expand Down Expand Up @@ -135,6 +163,11 @@ class EvalVisitor : public IRVisitor {
context_.insert(stmt, TypedConstant(dt, val_opt.value()));
}

template <typename T>
void insert_to_ctx(const Stmt *stmt, DataType dt, const T &val) {
context_.insert(stmt, TypedConstant(dt, val));
}

EvalContext context_;
bool failed_{false};
};
Expand Down
37 changes: 36 additions & 1 deletion taichi/analysis/arithmetic_interpretor.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <optional>
#include <unordered_set>
#include <unordered_map>

#include "taichi/ir/statements.h"
Expand All @@ -20,11 +21,23 @@ class ArithmeticInterpretor {
*/
class EvalContext {
public:
/**
* Pre-defines a value for statement @param s.
*
* @param s: Statement to be evaluated
* @param c: Predefined value
*/
EvalContext &insert(const Stmt *s, TypedConstant c) {
map_[s] = c;
return *this;
}

/**
* Tries to get the evaluated value for statement @param s.
*
* @param s: Statement to get
* @return: The evaluated value, empty if not found.
*/
std::optional<TypedConstant> maybe_get(const Stmt *s) const {
auto itr = map_.find(s);
if (itr == map_.end()) {
Expand All @@ -33,8 +46,30 @@ class ArithmeticInterpretor {
return itr->second;
}

/**
* Tells the interpretor to ignore statement @param s.
*
* This is effective only for statements that are not supported by
* ArithmeticInterpretor.
*
* @param s: Statemet to ignore
k-ye marked this conversation as resolved.
Show resolved Hide resolved
*/
void ignore(const Stmt *s) {
ignored_.insert(s);
}

/**
* Checks if statement @param s is ignored.
*
* @return: True if ignored
*/
bool should_ignore(const Stmt *s) {
return ignored_.count(s) > 0;
}

private:
std::unordered_map<const Stmt *, TypedConstant> map_;
std::unordered_set<const Stmt *> ignored_;
};

/**
Expand All @@ -55,7 +90,7 @@ class ArithmeticInterpretor {
* Evaluates the sequence of CHI as defined in |region|.
* @param region: A sequence of CHI statements to be evaluated
* @param init_ctx: This context can mock the result for certain types of
* statements that are not supported, or cannot be evaluated statically.
* statements that are not supported, or cannot be evaluated statically.
*/
std::optional<TypedConstant> evaluate(const CodeRegion &region,
const EvalContext &init_ctx) const;
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void IRBuilder::reset() {
insert_point_.position = 0;
}

std::unique_ptr<IRNode> IRBuilder::extract_ir() {
std::unique_ptr<Block> IRBuilder::extract_ir() {
auto result = std::move(root_);
reset();
return result;
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class IRBuilder {
void reset();

// Extract the IR.
std::unique_ptr<IRNode> extract_ir();
std::unique_ptr<Block> extract_ir();

// General inserter. Returns stmt.get().
template <typename XStmt>
Expand Down Expand Up @@ -235,7 +235,7 @@ class IRBuilder {
}

private:
std::unique_ptr<IRNode> root_{nullptr};
std::unique_ptr<Block> root_{nullptr};
InsertPoint insert_point_;
};

Expand Down
Loading