Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] Draw nodes as record shape, allow embedding states into nodes #1876

Merged
merged 1 commit into from
Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/taichi/misc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ def veci(*args, **kwargs):
return core_veci(*args, **kwargs)


def dump_dot(filepath=None, rankdir=None):
def dump_dot(filepath=None, rankdir=None, embed_states_threshold=0):
from taichi.core import ti_core
d = ti_core.dump_dot(rankdir)
d = ti_core.dump_dot(rankdir, embed_states_threshold)
if filepath is not None:
with open(filepath, 'w') as fh:
fh.write(d)
Expand Down
12 changes: 0 additions & 12 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,6 @@ uint64 hash(IRNode *stmt) {
return ret;
}

inline const SNode *get_snode_in_clear_list_task(const OffloadedStmt *task) {
TI_ASSERT(is_clear_list_task(task));
return task->body->back()->as<ClearListStmt>()->snode;
}

inline SNode *get_snode_in_clear_list_task(OffloadedStmt *task) {
// Avoid duplication: https://stackoverflow.com/a/123995/12003165
const auto *sn =
get_snode_in_clear_list_task(static_cast<const OffloadedStmt *>(task));
return const_cast<SNode *>(sn);
}

} // namespace

uint64 IRBank::get_hash(IRNode *ir) {
Expand Down
91 changes: 77 additions & 14 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,8 @@ void StateFlowGraph::print() {
fmt::print("=======================\n");
}

std::string StateFlowGraph::dump_dot(
const std::optional<std::string> &rankdir) {
std::string StateFlowGraph::dump_dot(const std::optional<std::string> &rankdir,
int embed_states_threshold) {
using SFGNode = StateFlowGraph::Node;
using TaskType = OffloadedStmt::TaskType;
std::stringstream ss;
Expand All @@ -443,31 +443,78 @@ std::string StateFlowGraph::dump_dot(
// https://graphviz.org/doc/info/lang.html ID naming
return fmt::format("n_{}_{}", n->meta->name, n->launch_id);
};

auto escaped_label = [](const std::string &s) {
std::stringstream ss;
for (char c : s) {
// Braces, vertical bars ,angle brackets and spaces needs to be escaped.
// Just escape whitespaces for now...
if (c == ' ') {
ss << '\\';
}
ss << c;
}
return ss.str();
};
// Graph level configuration.
if (rankdir) {
ss << " rankdir=" << *rankdir << "\n";
}
ss << "\n";
// Specify the node styles
std::unordered_set<const SFGNode *> latest_state_nodes;
std::unordered_set<const SFGNode *> nodes_with_embedded_states;
// TODO: make this configurable
for (const auto &p : latest_state_owner_) {
latest_state_nodes.insert(p.second);
}
std::vector<const SFGNode *> nodes_with_no_inputs;
for (const auto &nd : nodes_) {
const auto *n = nd.get();
ss << " " << fmt::format("{} [label=\"{}\"", node_id(n), n->string());
if (nd->is_initial_node) {
ss << ",shape=box";
} else if (latest_state_nodes.find(n) != latest_state_nodes.end()) {
ss << ",peripheries=2";

std::stringstream labels;
if (!n->is_initial_node && !n->output_edges.empty() &&
(n->output_edges.size() < embed_states_threshold)) {
// Example:
//
// |-----------------------|
// | node foo |
// |-----------------------|
// | X_mask | X_value |
// |-----------------------|
//
// label={ node\ foo | { <X_mask> X_mask | <X_value> X_value } }
// See DOT node port...
labels << "{ " << escaped_label(n->string()) << " | { ";
const auto &edges = n->output_edges;
for (auto it = edges.begin(); it != edges.end(); ++it) {
if (it != edges.begin()) {
labels << " | ";
}
const auto name = it->first.name();
// Each state corresponds to one port
// "<port> displayed\ text"
labels << "<" << name << "> " << escaped_label(name);
}
labels << " } }";

nodes_with_embedded_states.insert(n);
} else {
// No states embedded.
labels << escaped_label(n->string());
}
ss << " "
<< fmt::format("{} [label=\"{}\" shape=record", node_id(n),
labels.str());
if (latest_state_nodes.find(n) != latest_state_nodes.end()) {
ss << " peripheries=2";
}
// Highlight user-defined tasks
const auto tt = nd->meta->type;
if (!nd->is_initial_node &&
(tt == TaskType::range_for || tt == TaskType::struct_for ||
tt == TaskType::serial)) {
ss << ",style=filled,fillcolor=lightgray";
ss << " style=filled fillcolor=lightgray";
}
ss << "]\n";
if (nd->input_edges.empty())
Expand All @@ -486,15 +533,29 @@ std::string StateFlowGraph::dump_dot(
for (const auto &p : from->output_edges) {
for (const auto *to : p.second) {
stack.push_back(to);
std::string style;

const bool states_embedded =
(nodes_with_embedded_states.find(from) !=
nodes_with_embedded_states.end());
std::string from_node_port = node_id(from);
std::stringstream attribs;
if (states_embedded) {
// The state is embedded inside the node. We draw the edge from
// the port corresponding to this state.
// Format is "{node}:{port}"
from_node_port += fmt::format(":{}", p.first.name());
} else {
// Show the state on the edge label
attribs << fmt::format("label=\"{}\"", p.first.name());
}

if (!from->has_state_flow(p.first, to)) {
style = "style=dotted";
attribs << " style=dotted";
}

ss << " "
<< fmt::format("{} -> {} [label=\"{}\" {}]", node_id(from),
node_id(to), p.first.name(), style)
<< fmt::format("{} -> {} [{}]", from_node_port, node_id(to),
attribs.str())
<< '\n';
}
}
Expand Down Expand Up @@ -780,9 +841,11 @@ void async_print_sfg() {
get_current_program().async_engine->sfg->print();
}

std::string async_dump_dot(std::optional<std::string> rankdir) {
std::string async_dump_dot(std::optional<std::string> rankdir,
int embed_states_threshold) {
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#allow-prohibiting-none-arguments
return get_current_program().async_engine->sfg->dump_dot(rankdir);
return get_current_program().async_engine->sfg->dump_dot(
rankdir, embed_states_threshold);
}

TLANG_NAMESPACE_END
11 changes: 9 additions & 2 deletions taichi/program/state_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,16 @@ class StateFlowGraph {

void print();

// Returns a string representing a DOT graph
// Returns a string representing a DOT graph.
//
// |embed_states_threshold|: We can choose to embed the states into the task
// node itself, if there aren't too many output states. This defines the
// maximum number of output states a task can have for the states to be
// embedded in the node.
//
// TODO: In case we add more and more DOT configs, create a struct?
std::string dump_dot(const std::optional<std::string> &rankdir);
std::string dump_dot(const std::optional<std::string> &rankdir,
int embed_states_threshold = 0);

void insert_task(const TaskLaunchRecord &rec);

Expand Down
6 changes: 4 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ TLANG_NAMESPACE_BEGIN

void async_print_sfg();

std::string async_dump_dot(std::optional<std::string> rankdir);
std::string async_dump_dot(std::optional<std::string> rankdir,
int embed_states_threshold);

std::string compiled_lib_dir;
std::string runtime_tmp_dir;
Expand Down Expand Up @@ -664,7 +665,8 @@ void export_lang(py::module &m) {
});

m.def("print_sfg", async_print_sfg);
m.def("dump_dot", async_dump_dot, py::arg("rankdir").none(true));
m.def("dump_dot", async_dump_dot, py::arg("rankdir").none(true),
py::arg("embed_states_threshold"));
}

TI_NAMESPACE_END