From 0431eac1ebe14560c0ea409f6113208121f0bb6c Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 19 Oct 2020 16:19:37 -0700 Subject: [PATCH 1/4] Faster pointwise fusion graph pass (#19269) * Faster pointwise fusion graph pass * Fix lint * Fix lint 2 * Fixes * Fixing slice parameter handling in fusion * Fixing the slice fix * Fix the cycle bug * Added test * Fix lint * Fix merging of subgraphs * Fixes from review --- src/executor/exec_pass.h | 16 +- src/executor/pointwise_fusion_pass.cc | 517 ++++++++++++---------- src/executor/simple_partition_pass.cc | 265 ++++++++++++ src/executor/simple_partition_pass.h | 599 +++++++++----------------- src/imperative/cached_op.h | 5 +- src/operator/fusion/fused_op.cu | 51 ++- tests/python/gpu/test_fusion.py | 20 + 7 files changed, 805 insertions(+), 668 deletions(-) create mode 100644 src/executor/simple_partition_pass.cc diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 4552fa173fe4..07ccc0257cda 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -222,22 +222,14 @@ Graph DetectInplaceAddTo(Graph g); Graph EliminateCommonExpr(Graph && g); /*! - * \brief Fuse pointwise operations in the forward pass. + * \brief Fuse pointwise operations in the graph. * * \param g input graph (needs to be entire graph, not just forward part) + * \param num_forward_outputs number of outputs in the graph produced by the forward pass * - * \return graph with fused pointwise operations in the forward pass + * \return copy of the graph with fused pointwise operations */ -Graph FusePointwiseForward(Graph&& g); - -/*! - * \brief Fuse pointwise operations in the backward pass. - * - * \param g input graph (needs to be entire graph, not just forward part) - * - * \return graph with fused pointwise operations in the backward pass - */ -Graph FusePointwiseBackward(Graph&& g); +Graph FusePointwise(const Graph& g, const size_t num_forward_outputs); /*! * \brief Issue a one-time warning that fusion is not possible for this platform or build. diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc index 3203f67e8b68..aa139903339e 100644 --- a/src/executor/pointwise_fusion_pass.cc +++ b/src/executor/pointwise_fusion_pass.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include "./simple_partition_pass.h" #include "../operator/fusion/fused_op-inl.h" #include "../operator/fusion/fused_op.h" @@ -57,281 +58,321 @@ void WarnFusionNotSupported() { #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC namespace { - bool IsFusionCompatible(nnvm::Node* n) { - using namespace mxnet::fusion; - if (n->op() == nullptr) - return false; - std::string op_name = n->op()->name; - if (ops_desc.count(op_name)) - return true; - if (slice_ops.count(op_name)) - return false; - if (std::find(variable_io_ops.begin(), - variable_io_ops.end(), - op_name) != - variable_io_ops.end()) - return true; - if (op_name == "LeakyReLU") { - std::string act_type = n->attrs.dict.at("act_type"); - if (LeakyReLU_ops.count(act_type)) - return true; - else - return false; - } - if (op_name == "_backward_LeakyReLU") { - std::string act_type = n->attrs.dict.at("act_type"); - if (LeakyReLU_bwd_ops.count(act_type)) - return true; - else - return false; - } + +bool IsFusionCompatible(const nnvm::Node* n) { + using namespace mxnet::fusion; + if (n->op() == nullptr) + return false; + const std::string& op_name = n->op()->name; + if (ops_desc.count(op_name)) + return true; + if (slice_ops.count(op_name)) return false; + if (std::find(variable_io_ops.begin(), + variable_io_ops.end(), + op_name) != + variable_io_ops.end()) + return true; + if (op_name == "LeakyReLU") { + std::string act_type = n->attrs.dict.at("act_type"); + if (LeakyReLU_ops.count(act_type)) + return true; + else + return false; } + if (op_name == "_backward_LeakyReLU") { + std::string act_type = n->attrs.dict.at("act_type"); + if (LeakyReLU_bwd_ops.count(act_type)) + return true; + else + return false; + } + return false; +} - bool IsInputsOnlyCompatible(nnvm::Node* n) { - using namespace mxnet::fusion; - if (n->op() == nullptr) - return false; - std::string op_name = n->op()->name; - if (slice_ops.count(op_name)) { - if (op_name == "slice") { - // slice with non-default step attribute is not supported - // currently - if (n->attrs.dict.count("step") && - !(n->attrs.dict.at("step") == "()" || - n->attrs.dict.at("step") == "[]")) { - return false; - } +bool IsInputsOnlyCompatible(const nnvm::Node* n) { + using namespace mxnet::fusion; + if (n->op() == nullptr) + return false; + const std::string& op_name = n->op()->name; + if (slice_ops.count(op_name)) { + if (op_name == "slice") { + // slice with non-default step attribute is not supported + // currently + if (n->attrs.dict.count("step") && + !(n->attrs.dict.at("step") == "()" || + n->attrs.dict.at("step") == "[]")) { + return false; } - return true; } - return false; + return true; } + return false; +} + +void CreateSubgraphNode(const nnvm::Graph& subgraph, + size_t inputs_size, + nnvm::Node* subgraph_node) { + static const Op* fused_op_ptr = Op::Get("_FusedOp"); + subgraph_node->attrs.subgraphs.emplace_back(std::make_shared()); + subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs; + subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size); + subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size()); + subgraph_node->attrs.op = fused_op_ptr; + subgraph_node->op()->attr_parser(&(subgraph_node->attrs)); +} + +struct EntryInfo { + int source_node; + int index; +}; - nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) { - nnvm::Symbol subgraph_sym; - auto node = nnvm::Node::Create(); - subgraph_sym.outputs = subgraph.outputs; - node->attrs.subgraphs.emplace_back(std::make_shared(subgraph_sym)); - node->attrs.name = "FusedOp"; - node->attrs.dict["num_inputs"] = std::to_string(inputs_size); - node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size()); - node->attrs.op = Op::Get("_FusedOp"); - node->op()->attr_parser(&(node->attrs)); - return node; +inline int SetInsert(const EntryInfo& new_elem, + std::vector* elements) { + for (size_t i = 0; i < elements->size(); ++i) { + if ((new_elem.source_node == elements->at(i).source_node) && + (new_elem.index == elements->at(i).index)) { + return i; + } } + elements->emplace_back(new_elem); + return elements->size() - 1; +} + } // namespace -/*! - * \brief Replace a set of nodes by a subgraph node. - * This function is used specifically in pointwise fusion. +/* \brief Create (if necessary) copy of the graph, replacing subgraphs with + * FusedOps. If there are no subgraphs to be replaced, the + * original graph is returned. + * \param g original graph. + * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to + * subgraphs. Values from -1 to num_subgraphs - 1 + * are allowed, -1 means that the node is not in a + * subgraph. + * \param num_subgraphs number of subgraphs. + * \param create_subgraph_node function used to prepare the subgraph node. */ template -Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector& subgraph_sets, - FCreateNode create_subgraph_node) { - for (auto subgraph_set : subgraph_sets) { - // Create MXNet subgraph - Graph subgraph; - const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set); - subgraph.outputs.resize(sub_outputs_in_main.size()); - for (auto p : sub_outputs_in_main) { - subgraph.outputs[p.second] = p.first; - } - // To generate a subgraph an input has to be replaced by data node (no op) - // and it has to be agnostic to the node from which it's an output - // (For example, even if two inputs are two different outputs from the same node, - // they need to be replaced by two completely separate data nodes) - auto inputs = GetSubgraphInputs(subgraph, subgraph_set); - auto subgraph_node = create_subgraph_node(subgraph, inputs.size()); - subgraph_node->inputs = inputs; - // replug inputs of node out of subgraph to be output of the subgraph node - // if it was a node in the subgraph - DFSVisit(g.outputs, - [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) { - if (!subgraph_set.count(node.get())) { - for (auto &e : node->inputs) { - auto it = sub_outputs_in_main.find(e); - if (it != sub_outputs_in_main.end()) { - e.node = subgraph_node; - e.index = it->second; - } - } - } - }); - // replug outputs of the graph to be output of the subgraph node - // if it was a node in the subgraph - for (auto &e : g.outputs) { - auto it = sub_outputs_in_main.find(e); - if (it != sub_outputs_in_main.end()) { - e.node = subgraph_node; - e.index = it->second; - } +Graph CopyAndReplaceSubgraphs(const Graph& g, + const std::vector& subgraph_assignment, + const int num_subgraphs, + FCreateNode create_subgraph_node) { + if (num_subgraphs == 0) { + return g; + } + + Graph ret; + + const auto& idx = g.indexed_graph(); + + CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) << + "Every node in the graph needs to be included in subgraph assignment."; + + std::vector new_nodes; + new_nodes.reserve(idx.num_nodes()); + struct SubgraphInfo { + nnvm::Graph graph; + nnvm::ObjectPtr subgraph_node; + std::vector outputs; + std::vector inputs; + std::vector input_nodes; + }; + + std::vector subgraphs(num_subgraphs); + + for (auto& info : subgraphs) { + info.subgraph_node = nnvm::Node::Create(); + } + + for (size_t i = 0; i < idx.num_nodes(); ++i) { + // First copy the node, it will be used + // either in the new graph or inside a + // subgraph. Variables are not copied. + if (idx[i].source->op() != nullptr) { + new_nodes.emplace_back(nnvm::Node::Create()); + auto& node_copy = new_nodes.back(); + node_copy->attrs = idx[i].source->attrs; + node_copy->info = idx[i].source->info; + } else { + new_nodes.emplace_back(idx[i].weak_ref.lock()); + continue; } - // move control dependencies between nodes of the subgraph and out of the subgraph - // to a dependencies between the subgraph node and the nodes out of the subgraph - DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) { - if (subgraph_set.count(node.get())) { - auto it = node->control_deps.begin(); - static auto& is_fusion = Op::GetAttr("TIsFusionHelper"); - std::vector new_control_deps; - // Use the first control dependency to get the inferattr helper - if (it != node->control_deps.end()) { - if (subgraph_set.count(it->get())) { - new_control_deps.push_back(*it); + auto& node_copy = new_nodes.back(); + const int subgraph_id = subgraph_assignment[i]; + if (subgraph_id != -1) { + auto& info = subgraphs[subgraph_id]; + for (const auto& input : idx[i].inputs) { + const int their_subgraph = subgraph_assignment[input.node_id]; + if (their_subgraph == subgraph_id) { + node_copy->inputs.emplace_back(new_nodes[input.node_id], + input.index, + input.version); + } else { + int input_num; + int output_num; + if (their_subgraph == -1) { + input_num = SetInsert({static_cast(input.node_id), + static_cast(input.index)}, &(info.inputs)); } else { - if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) { - uint32_t node_id = subgraph_node->control_deps.size(); - subgraph_node->control_deps.push_back(*it); - auto helper_node = op::MakeNode("_FusedOpOutHelper", - "FusedOp_" + node->attrs.name + "_outhelper", - nullptr, - nullptr, - nullptr); - helper_node->attrs.parsed = - FusedOpHelperParamPtr(new FusedOpHelperParam( - nnvm::get(subgraph_node->attrs.parsed), - node_id)); - new_control_deps.push_back(helper_node); + auto& their_subgraph_info = subgraphs[their_subgraph]; + output_num = SetInsert({static_cast(input.node_id), + static_cast(input.index)}, + &(their_subgraph_info.outputs)); + input_num = SetInsert({static_cast(idx.num_nodes() + their_subgraph), + output_num}, + &(info.inputs)); + } + if (static_cast(input_num) == info.input_nodes.size()) { + info.input_nodes.emplace_back(nnvm::Node::Create()); + info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num); + if (their_subgraph == -1) { + info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id], + input.index, + input.version); } else { - new_control_deps.push_back(*it); + info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node, + output_num, + input.version); } } - ++it; + node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0); } - node->control_deps = new_control_deps; } - }); - - std::ostringstream name_oss; - // the name of the new node will be the concatenation of all the node names in the subgraph - DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) { - if (n->op() != nullptr) { - name_oss << n->op()->name << "_"; + } else { + for (const auto& input : idx[i].inputs) { + const int subgraph_id = subgraph_assignment[input.node_id]; + if (subgraph_id == -1) { + node_copy->inputs.emplace_back(new_nodes[input.node_id], + input.index, + input.version); + } else { + auto& info = subgraphs[subgraph_id]; + const int output_num = SetInsert({static_cast(input.node_id), + static_cast(input.index)}, + &(info.outputs)); + node_copy->inputs.emplace_back(info.subgraph_node, + output_num, + input.version); + } } - }); - auto subgraph_name = name_oss.str(); - subgraph_name.pop_back(); - subgraph_node->attrs.name = subgraph_name; + } - const auto& index = subgraph.indexed_graph(); - DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) { - for (auto &e : node->control_deps) { - if (subgraph_set.count(e.get())) { - uint32_t node_id = index.node_id(e.get()); - auto helper_node = op::MakeNode("_FusedOpHelper", - subgraph_node->attrs.name + "_" - + node->attrs.name + "_helper", - nullptr, - nullptr, - nullptr); - helper_node->attrs.parsed = - FusedOpHelperParamPtr(new FusedOpHelperParam( - nnvm::get(subgraph_node->attrs.parsed), - node_id)); - e = helper_node; - } + // Control deps + for (const auto& dep : idx[i].control_deps) { + if (subgraph_id == subgraph_assignment[dep]) { + node_copy->control_deps.emplace_back(new_nodes[dep]); } - }); + } } - Graph new_graph; - new_graph.outputs = g.outputs; - return new_graph; -} -/* \brief Add nodes as inputs to the subgraph. This is used for operations - * which are only compatible when they are the first nodes in the - * subgraph. - */ -template -void AddInputsOnlyCompatible(const Graph &g, - std::vector >* subsets, - IsCompatible is_compatible) { - std::unordered_map node2setidx; - size_t subgraphs_fullsize = 0; - for (auto& s : *subsets) { - subgraphs_fullsize += s.size(); - } - node2setidx.reserve(subgraphs_fullsize); - for (size_t i = 0; i < subsets->size(); ++i) { - for (auto& n : (*subsets)[i]) { - node2setidx.insert({n, i}); + ret.outputs.reserve(idx.outputs().size()); + for (const auto& output : idx.outputs()) { + const int subgraph_id = subgraph_assignment[output.node_id]; + if (subgraph_id == -1) { + ret.outputs.emplace_back(new_nodes[output.node_id], + output.index, + output.version); + } else { + const int output_num = SetInsert({static_cast(output.node_id), + static_cast(output.index)}, + &(subgraphs[subgraph_id].outputs)); + ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node, + output_num, + output.version); } } - std::vector > to_add(subsets->size()); - DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) { - const auto& it = node2setidx.find(n.get()); - if (it != node2setidx.end()) { - for (auto& e : n->inputs) { - if (is_compatible(e.node.get())) - to_add[it->second].push_back(e.node.get()); - } + + for (auto& info : subgraphs) { + info.graph.outputs.reserve(info.outputs.size()); + for (const auto& entry_info : info.outputs) { + info.graph.outputs.emplace_back(new_nodes[entry_info.source_node], + entry_info.index, + 0); } - }); + create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get()); + } - // Avoid duplicating the node that is input of two subsets - std::unordered_set added; - for (size_t i = 0; i < subsets->size(); ++i) { - std::vector heads; - for (auto n : subsets->at(i)) { - for (auto e : n->inputs) { - if (!subsets->at(i).count(e.node.get())) - heads.push_back(e); + for (size_t i = 0; i < idx.num_nodes(); ++i) { + // Add _FusedOpHelper nodes + const int subgraph_id = subgraph_assignment[i]; + for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) { + const auto& dep = idx[i].control_deps[dep_num]; + const int their_subgraph_id = subgraph_assignment[dep]; + if (subgraph_id != -1 && their_subgraph_id == -1) { + // Not in any subgraph, use FusedOpOutHelper + auto& info = subgraphs[subgraph_id]; + size_t node_id = info.subgraph_node->control_deps.size(); + info.subgraph_node->control_deps.emplace_back(new_nodes[dep]); + auto helper_node = op::MakeNode("_FusedOpOutHelper", + "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper", + nullptr, + nullptr, + nullptr); + helper_node->attrs.parsed = + FusedOpHelperParamPtr(new FusedOpHelperParam( + nnvm::get(info.subgraph_node->attrs.parsed), + node_id)); + new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num, + std::move(helper_node)); + } else if (their_subgraph_id != subgraph_id && + their_subgraph_id != -1) { + auto& info = subgraphs[their_subgraph_id]; + const auto& subgraph_idx = info.graph.indexed_graph(); + uint32_t node_id = subgraph_idx.node_id(new_nodes[dep].get()); + auto helper_node = op::MakeNode("_FusedOpHelper", + info.subgraph_node->attrs.name + "_" + + idx[i].source->attrs.name + "_helper", + nullptr, + nullptr, + nullptr); + helper_node->attrs.parsed = + FusedOpHelperParamPtr(new FusedOpHelperParam( + nnvm::get(info.subgraph_node->attrs.parsed), + node_id)); + new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num, + std::move(helper_node)); } } - for (size_t j = 0; j < to_add[i].size(); ++j) { - if (!added.count(to_add[i][j])) { - bool make_cycle = false; - const auto& node = to_add[i][j]; - std::vector _heads; - std::copy_if(heads.begin(), heads.end(), std::back_inserter(_heads), - [&node](const nnvm::NodeEntry& n) { - return n.node.get() != node; - }); - DFSVisit(_heads, [&make_cycle, &node](const nnvm::ObjectPtr& n) { - if (n.get() == node) - make_cycle = true; - }); - if (!make_cycle) { - (*subsets)[i].insert(to_add[i][j]); - added.insert(to_add[i][j]); + } + for (auto& info : subgraphs) { + const auto& idx = info.graph.indexed_graph(); + const auto& input_nodes = idx.input_nodes(); + std::vector subgraph_inputs; + subgraph_inputs.reserve(info.subgraph_node->inputs.size()); + for (const int input : input_nodes) { + for (size_t i = 0; i < info.input_nodes.size(); ++i) { + const auto& input_ptr = info.input_nodes[i].get(); + if (input_ptr == idx[input].source) { + subgraph_inputs.emplace_back(info.subgraph_node->inputs[i]); } } } + info.subgraph_node->inputs.swap(subgraph_inputs); + std::string name; + for (size_t i = 0; i < idx.num_nodes(); ++i) { + if (idx[i].source->op() != nullptr) { + name += idx[i].source->op()->name + "_"; + } + } + info.subgraph_node->attrs.name = name; } -} - -Graph FusePointwiseForward(Graph &&g) { - Graph ret; - g.indexed_graph(); - const auto& num_forward_outputs = g.GetAttr("num_forward_outputs"); - Graph fg; - fg.outputs.insert(fg.outputs.begin(), g.outputs.begin(), - g.outputs.begin() + num_forward_outputs); - auto subsets = GetCompatibleSubsets(fg, IsFusionCompatible); - AddInputsOnlyCompatible(fg, &subsets, IsInputsOnlyCompatible); - g = ReplaceSubgraphsPointwise(std::move(g), subsets, CreateSubgraphNode); - ret.outputs = g.outputs; return ret; } -Graph FusePointwiseBackward(Graph &&g) { - Graph ret; - g.indexed_graph(); - const auto& num_forward_outputs = g.GetAttr("num_forward_outputs"); - Graph fg; - fg.outputs.insert(fg.outputs.begin(), g.outputs.begin(), - g.outputs.begin() + num_forward_outputs); - std::unordered_set exclusion_set; - DFSVisit(fg.outputs, [&exclusion_set](const nnvm::ObjectPtr& n) { - exclusion_set.insert(n.get()); - }); - auto subsets = GetCompatibleSubsets(g, [&exclusion_set](nnvm::Node* n) { - if (exclusion_set.count(n)) - return false; - return IsFusionCompatible(n); - }); - g = ReplaceSubgraphsPointwise(std::move(g), subsets, CreateSubgraphNode); - ret.outputs = g.outputs; +Graph FusePointwise(const Graph &g, const size_t num_forward_outputs) { + auto start = std::chrono::steady_clock::now(); + auto [subset_assignment, num_subsets] = GetCompatibleSubsets(g, num_forward_outputs, // NOLINT(*) + IsFusionCompatible, + IsInputsOnlyCompatible); + Graph ret = CopyAndReplaceSubgraphs(g, subset_assignment, num_subsets, + CreateSubgraphNode); + auto end = std::chrono::steady_clock::now(); + if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) { + auto diff = end - start; + LOG(INFO) << "Pointwise fusion graph pass took: " + << std::chrono::duration(diff).count() + << "ms."; + } return ret; } #endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC diff --git a/src/executor/simple_partition_pass.cc b/src/executor/simple_partition_pass.cc new file mode 100644 index 000000000000..941959d4bb45 --- /dev/null +++ b/src/executor/simple_partition_pass.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file simple_partition_pass.cc + * \brief Utilities used in simple partition pass + * \author Przemyslaw Tredak + */ + +#include "./simple_partition_pass.h" +#include +#include + +namespace mxnet { +namespace exec { + +namespace detail { + +const IntervalVec* LargerSet(const IntervalVec* const first, + const IntervalVec* const second) noexcept { + const IntervalVec* ret = nullptr; + auto first_iter = first->begin(); + auto second_iter = second->begin(); + while (first_iter != first->end() && + second_iter != second->end()) { + if (*first_iter == *second_iter) { + ++first_iter; + ++second_iter; + } else { + // Entry in first set not seen in the second set + if (first_iter->second < second_iter->first) { + if (ret == first || ret == nullptr) { + ret = first; + ++first_iter; + } else { + return nullptr; + } + continue; + } + // Entry in second set not seen in the first set + if (second_iter->second < first_iter->first) { + if (ret == second || ret == nullptr) { + ret = second; + ++second_iter; + } else { + return nullptr; + } + continue; + } + // Entry in first set fully encloses the entry in the second set + if (first_iter->first <= second_iter->first && + first_iter->second >= second_iter->second) { + if (ret == first || ret == nullptr) { + ret = first; + ++second_iter; + } else { + return nullptr; + } + continue; + } + // Entry in second set fully encloses the entry in the first set + if (second_iter->first <= first_iter->first && + second_iter->second >= first_iter->second) { + if (ret == second || ret == nullptr) { + ret = second; + ++first_iter; + } else { + return nullptr; + } + continue; + } + // Entries intersect but one is not fully enclosed in the other + return nullptr; + } + } + if (ret == nullptr) { + // The common part is the same + return second_iter == second->end() ? first : second; + } else { + if ((ret == first && second_iter == second->end()) || + (ret == second && first_iter == first->end())) { + return ret; + } + } + return nullptr; +} + +void MergeSets(const IntervalVec** const my_set, + const IntervalVec* const other_set, + std::vector>* const storage) noexcept { + if ((*my_set == nullptr) || (*my_set)->size() == 0) { + *my_set = other_set; + return; + } + if (other_set == nullptr || other_set->size() == 0) { + return; + } + auto* larger_set = LargerSet(*my_set, other_set); + if (larger_set != nullptr) { + *my_set = larger_set; + return; + } + auto my_iter = (*my_set)->cbegin(); + auto other_iter = other_set->cbegin(); + auto new_set = IntervalVec(); + int last_end = -10; // less than -1 + while (my_iter != (*my_set)->cend() && + other_iter != other_set->cend()) { + const auto& mine = *my_iter; + const auto& other = *other_iter; + if (other.second < mine.first - 1) { + // other interval is before ours + if (last_end >= other.first - 1) { + new_set.back().second = other.second; + } else { + new_set.emplace_back(other); + } + last_end = other.second; + ++other_iter; + } else if (other.first > mine.second + 1) { + // other interval is after ours + if (last_end >= mine.first - 1) { + new_set.back().second = mine.second; + } else { + new_set.emplace_back(mine); + } + last_end = mine.second; + ++my_iter; + } else { + // Intervals can be merged together + Interval n(std::min(mine.first, other.first), + std::max(mine.second, other.second)); + if (last_end >= n.first - 1) { + new_set.back().second = n.second; + } else { + new_set.emplace_back(n); + } + last_end = n.second; + if (other.second >= mine.second) { + ++my_iter; + } + if (mine.second >= other.second) { + ++other_iter; + } + } + } + auto remaining_iter = my_iter == (*my_set)->cend() ? other_iter : my_iter; + auto remaining_end = my_iter == (*my_set)->cend() ? other_set->cend() : (*my_set)->cend(); + // Add the rest of entries + for (; remaining_iter != remaining_end; ++remaining_iter) { + auto& mine = new_set.back(); + const auto& other = *remaining_iter; + if (other.second < mine.first - 1) { + // other interval is before ours, should never happen + continue; + } else if (other.first > mine.second + 1) { + // other interval is after ours + new_set.emplace_back(other); + } else { + // Intervals can be merged together + mine.first = std::min(mine.first, other.first); + mine.second = std::max(mine.second, other.second); + } + } + storage->emplace_back(std::make_unique(std::move(new_set))); + *my_set = storage->back().get(); +} + +bool Intersect(const IntervalVec& checked_sets, + const IntervalVec& excluded_sets) noexcept { + size_t current_interval = 0, current_other_interval = 0; + while (current_interval < checked_sets.size() && + current_other_interval < excluded_sets.size()) { + const auto& mine = checked_sets[current_interval]; + const auto& other = excluded_sets[current_other_interval]; + if (other.second < mine.first) { + // other interval is before ours + ++current_other_interval; + } else if (other.first > mine.second) { + // other interval is after ours + ++current_interval; + } else { + // Intervals intersect + return true; + } + } + return false; +} + +void AddSet(const IntervalVec** const sets, const int set_to_add, + std::vector>* const storage) noexcept { + if (*sets != nullptr && (*sets)->size() != 0) { + for (auto& interval : (**sets)) { + if (set_to_add >= interval.first && + set_to_add <= interval.second) { + return; + } + } + } + storage->emplace_back( + std::make_unique(1, std::make_pair(set_to_add, set_to_add))); + MergeSets(sets, storage->back().get(), storage); +} + +int GetSetMapping(const int set, std::vector* const set_mapping) noexcept { + if (set == -1) return -1; + int temp = set; + while ((*set_mapping)[temp] != temp) { + temp = (*set_mapping)[temp]; + } + (*set_mapping)[set] = temp; + return temp; +} + +void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const combined_excluded_sets_ptr, + const IntervalVec* const new_excluded_sets, + std::vector* const excluded_sets_ptr, + const int set_id, + const int first_node_in_set, + const size_t new_node_id, + const std::vector& set_assignment, + std::vector* const set_mapping_ptr, + const IntervalVec& inverse_set_mapping, + std::vector>* const + storage) noexcept { + const auto* previous_excluded_sets = *combined_excluded_sets_ptr; + MergeSets(combined_excluded_sets_ptr, new_excluded_sets, storage); + if (new_excluded_sets != nullptr) { + if (previous_excluded_sets == nullptr || + *previous_excluded_sets != **(combined_excluded_sets_ptr)) { + // Their set's excluded sets list got larger, need to update the descendants + // of their set + auto& excluded_sets = *excluded_sets_ptr; + for (size_t j = first_node_in_set; j < new_node_id; ++j) { + if (GetSetMapping(set_assignment[j], set_mapping_ptr) == set_id || + (excluded_sets[j] != nullptr && + Intersect(inverse_set_mapping, *excluded_sets[j]))) { + MergeSets(&excluded_sets[j], *combined_excluded_sets_ptr, storage); + } + } + } + } +} + +} // namespace detail + +} // namespace exec +} // namespace mxnet diff --git a/src/executor/simple_partition_pass.h b/src/executor/simple_partition_pass.h index 1ca0086dbc53..764486f8e6c1 100644 --- a/src/executor/simple_partition_pass.h +++ b/src/executor/simple_partition_pass.h @@ -18,10 +18,10 @@ */ /*! - * Copyright (c) 2019 by Contributors + * Copyright (c) 2019-2020 by Contributors * \file simple_partition_pass.h * \brief Simple pass for partitioning a graph. - * \author Clement Fuji Tsang + * \author Clement Fuji Tsang, Przemyslaw Tredak */ #ifndef MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_ #define MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_ @@ -34,440 +34,237 @@ #include #include #include +#include #include "exec_pass.h" namespace mxnet { namespace exec { +namespace detail { -/*! - * \brief Custom graph class, which contains bi-directional nodes - * required for traversing in both directions (from outputs to inputs - * and vice versa). It is a non-owning layer on top of NNVM graph, since - * NNVM graph enables traversing only in 1 direction (from outputs to inputs). +using Interval = std::pair; +using IntervalVec = std::vector; + +/* \brief Return the set that fully contains the other set, or nullptr + * if neither set is a subset of another. */ -class BidirectionalGraph { - public: - struct Node { - nnvm::Node* nnvmptr; - std::vector inputs; - std::vector outputs; - }; +const IntervalVec* LargerSet(const IntervalVec* const first, + const IntervalVec* const second) noexcept; - explicit BidirectionalGraph(const Graph &g) { - auto& idx = g.indexed_graph(); - auto num_nodes = idx.num_nodes(); - nodes.reserve(num_nodes); - nnvm2nid.reserve(num_nodes); - outputs.reserve(idx.outputs().size()); - // Create all the nodes in a new graph from - // nodes in the NNVM graph and store them - // in nodes array - DFSVisit(g.outputs, [this](const nnvm::ObjectPtr& n) { - Node new_node; - new_node.nnvmptr = n.get(); - nnvm2nid[n.get()] = static_cast(nodes.size()); - nodes.emplace_back(std::move(new_node)); - }); - // Create all connections between nodes in - // the graph (both directions) - for (const auto& it : nnvm2nid) { - nnvm::Node* nnvmnode = it.first; - uint32_t nid = it.second; - for (auto& n : nnvmnode->inputs) { - uint32_t input_nid = nnvm2nid[n.node.get()]; - nodes[input_nid].outputs.emplace_back(&nodes[nid]); - nodes[nid].inputs.emplace_back(&nodes[input_nid]); - } - } - // Create output connections from the graph - for (auto& e : g.outputs) { - uint32_t nid = nnvm2nid[e.node.get()]; - outputs.emplace_back(&nodes[nid]); - } - } +/* \brief Compute the sum of the 2 sets and store it in my_set. + */ +void MergeSets(const IntervalVec** const my_set, + const IntervalVec* const other_set, + std::vector>* const storage) noexcept; - /* \brief Get all subsets of nodes, where: - * - graph constructed from nodes in each subset is a connected graph - * - every node fulfills a predicate is_compatible - * - if nodes u and v are part of a subset, then for each path between - * u and v in the original directed graph, all nodes on those paths - * are also part of the subset - * \param is_compatible A function taking nnvm::Node* and returning bool - * which identifies which nodes should be included in - * subsets. - */ - template - std::vector> get_subsets(FCompatible is_compatible) { - std::vector> subgraphs; - std::unordered_set incomp_set; - std::vector> separation_sets; - // Check each node for compatibility - // and, if it is incompatible, mark nodes - // on each side of it as not possible to be - // in the same subset - for (Node& node : nodes) { - if (!is_compatible(node.nnvmptr)) { - incomp_set.insert(&node); - } - } - for (Node& node : nodes) { - if (incomp_set.count(&node) != 0) { - // Check if all your inputs are incompatible too. - // If so, then your separation set does not matter, - // because it will covered by the sets of your inputs - bool inside_node = true; - for (Node* input : node.inputs) { - if (incomp_set.count(input) == 0) { - inside_node = false; - } - } - if (!inside_node) { - std::unordered_set in_graph; - std::unordered_set out_graph; - std::vector dummy_head; - dummy_head.emplace_back(&node); - DFS(dummy_head, false, [&out_graph](Node* node) { - out_graph.insert(node); - }); - DFS(dummy_head, true, [&in_graph](Node* node) { - in_graph.insert(node); - }); - separation_sets.push_back(std::make_pair(true, - std::make_pair(in_graph, out_graph))); - } else { - separation_sets.push_back(std::make_pair(false, PairSet())); - } - } else { - separation_sets.push_back(std::make_pair(false, PairSet())); - } - } - IncompMap incomp_map; - // For each node construct the map of nodes that cannot be in - // the same subset - index_t num_nodes = nodes.size(); - for (index_t i = 0; i < num_nodes; ++i) { - const auto n = &(nodes[i]); - if (incomp_set.count(n) == 0) { - for (index_t j = i + 1; j < num_nodes; ++j) { - const auto& sep_set_pair = separation_sets[j]; - if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) { - const auto& p = sep_set_pair.second; - if (p.first.count(n)) { - incomp_map[n].insert(p.second.begin(), p.second.end()); - } else if (p.second.count(n)) { - incomp_map[n].insert(p.first.begin(), p.first.end()); - } - } - } - for (index_t j = i - 1; j >= 0; --j) { - const auto& sep_set_pair = separation_sets[j]; - if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) { - const auto& p = sep_set_pair.second; - if (p.first.count(n)) { - incomp_map[n].insert(p.second.begin(), p.second.end()); - } else if (p.second.count(n)) { - incomp_map[n].insert(p.first.begin(), p.first.end()); - } - } - } - for (Node* incomp_n : incomp_set) { - incomp_map[n].erase(incomp_n); - } - } - } - std::unordered_set unused_set; +/* \brief Returns true if there is non-empty intersection + * between the 2 sets. + */ +bool Intersect(const IntervalVec& checked_sets, + const IntervalVec& excluded_sets) noexcept; - for (auto& n : nodes) { - if (incomp_set.count(&n) == 0) { - unused_set.insert(&n); - } - } - std::unordered_set visited; - std::deque stack(outputs.begin(), outputs.end()); - // Create subsets - while (!stack.empty()) { - Node* vertex = stack.front(); - stack.pop_front(); - if (!visited.count(vertex)) { - visited.insert(vertex); - if (unused_set.count(vertex)) { - subgraphs.emplace_back(naive_grow_subgraph(vertex, &unused_set, &incomp_map)); - } - for (Node* input : vertex->inputs) { - stack.emplace_back(input); - } - } - } - return subgraphs; - } +/* \brief Add a single entry to the sets. + */ +void AddSet(const IntervalVec** const sets, const int set_to_add, + std::vector>* const storage) noexcept; - private: - using PairSet = std::pair, std::unordered_set>; - using PairVec = std::pair, std::vector>; - using IncompMap = std::unordered_map>; +/* \brief Get the true mapping of the set (which could change + * due to merging of multiple sets. + */ +int GetSetMapping(const int set, std::vector* const set_mapping) noexcept; - /* \brief Traverse the graph using DFS in either direction. - * \param heads Starting nodes for the DFS algorithm. - * \param reverse If true, DFS will traverse the graph from - * outputs to inputs. Otherwise, it will - * traverse the graph from inputs to outputs. - * \param fvisit Function to call on each visisted node. - */ - template - void DFS(const std::vector& heads, bool reverse, FVisit fvisit) { - std::unordered_set visited; - std::vector vec(heads.begin(), heads.end()); - visited.reserve(heads.size()); - while (!vec.empty()) { - Node* vertex = vec.back(); - vec.pop_back(); - if (visited.count(vertex) == 0) { - visited.insert(vertex); - fvisit(vertex); - std::vector nexts = reverse ? vertex->inputs : vertex->outputs; - for (Node* node : nexts) { - if (visited.count(node) == 0) { - vec.emplace_back(node); - } - } - } - } - } +/* \brief Check if 2 ids are on the same side of the cutoff + * (so either both on the FWD side or the BWD side). + */ +inline bool IsSamePass(const int my_id, const int their_id, const int cutoff) noexcept { + return (my_id > cutoff && their_id > cutoff) || + (my_id <= cutoff && their_id <= cutoff); +} - /* \brief Get the connected subgraph that contains the head node, - * only previously unused nodes, according to the rules - * from incompatibility map. - * \param head Node which needs to be part of the returned subgraph. - * \param unused_set Only nodes from this set will be considered when - * adding to the growing subgraph. - * \param incomp_map Map containing data on which nodes are incompatible - * to be in the same subgraph. - */ - std::unordered_set naive_grow_subgraph(Node* head, - std::unordered_set* unused_set, - IncompMap* incomp_map) { - std::unordered_set subgraph; - std::unordered_set incomp_set; - std::deque stack; - stack.emplace_back(head); - while (!stack.empty()) { - Node* vertex = stack.back(); - stack.pop_back(); - if (unused_set->count(vertex) && !incomp_set.count(vertex)) { - unused_set->erase(vertex); - subgraph.insert(vertex); - incomp_set.insert((*incomp_map)[vertex].begin(), (*incomp_map)[vertex].end()); - // Traverse the grpah in both directions - for (Node* input : vertex->inputs) { - if (unused_set->count(input) && !incomp_set.count(input)) { - stack.emplace_back(input); - } - } - for (Node* output : vertex->outputs) { - if (unused_set->count(output) && !incomp_set.count(output)) { - stack.emplace_back(output); - } - } - } - } - return subgraph; - } +/* \brief Check if adding a new node to the set changes the excluded set of the future + * fused node. If so, update all descendants of the fused node. + * + * \param combined_excluded_sets_ptr pointer to the set's list of excluded sets + * before adding the new node + * \param new_excluded_sets list of excluded sets of the new node + * \param excluded_sets_ptr pointer to the lists of excluded sets of all the nodes + * \param set_id number of the set, to which the new node is added + * \param first_node_in_set id of the first node in the set, according to topological ordering + * \param new_node_id id of the node added to the set + * \param set_assignment assignment of sets + * \param set_mapping_ptr pointer to the mappings of sets + * \param inverse_set_mapping inverse mapping of the set + * \param storage memory storage + */ +void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const combined_excluded_sets_ptr, + const IntervalVec* const new_excluded_sets, + std::vector* const excluded_sets_ptr, + const int set_id, + const int first_node_in_set, + const size_t new_node_id, + const std::vector& set_assignment, + std::vector* const set_mapping_ptr, + const IntervalVec& inverse_set_mapping, + std::vector>* const + storage) noexcept; - friend class Graph; +} // namespace detail - std::vector nodes; - std::unordered_map nnvm2nid; - std::vector outputs; -}; // class BidirectionalGraph -using NodeEntrySet = std::unordered_set; -using NodeRawPtrSet = std::unordered_set; +/* \brief Get all subsets of nodes, where: + * - graph constructed from nodes in each subset is a connected graph + * - every node fulfills a predicate is_compatible + * - if nodes u and v are part of a subset, then for each path between + * u and v in the original directed graph, all nodes on those paths + * are also part of the subset + * \param g NNVM graph + * \param num_forward_outputs Number of outputs from the graph that come + * from the forward pass + * \param is_compatible A function taking nnvm::Node* and returning bool + * which identifies which nodes could be included in + * subsets. + * \param is_input_only_compatible A function taking nnvm::Node* and + * returning bool which identifies which + * nodes could be included in subsets only + * as the first operations (their inputs + * need to be excluded). + * \return tuple (subset assignment, number of found subsets) + */ +template +std::tuple, int> GetCompatibleSubsets( + const Graph& g, + const size_t num_forward_outputs, + FCompatible is_compatible, + FInputOnlyCompatible is_input_only_compatible) { -/*! - * \brief Get the output nodes of the subgraph in the main graph. - * \return a map between the node in the main graph and the output index of the subgraph node -*/ -nnvm::NodeEntryMap GetSubgraphOutputs(Graph g, NodeRawPtrSet subgraph_set) { - nnvm::NodeEntryMap outputs; - uint32_t count = 0; - for (auto& e : g.outputs) { - if (subgraph_set.count(e.node.get()) && !outputs.count(e)) { - outputs.insert({e, count++}); + using namespace detail; + const auto& idx = g.indexed_graph(); + std::vector set_assignment(idx.num_nodes(), -1); + std::vector*> excluded_sets(idx.num_nodes()); + std::vector set_mapping; + std::vector*> combined_excluded_sets; + std::vector first_node_in_set; + std::vector*> inverse_set_mapping; + std::vector>> storage; + + int last_forward_node = -1; + for (size_t i = 0; i < num_forward_outputs; ++i) { + const int output_id = idx.outputs()[i].node_id; + if (last_forward_node < output_id) { + last_forward_node = output_id; } } - DFSVisit(g.outputs, [&subgraph_set, &outputs, &count](const nnvm::ObjectPtr &node){ - if (!subgraph_set.count(node.get())) { - for (auto& e : node->inputs) { - if (subgraph_set.count(e.node.get()) && !outputs.count(e)) { - outputs.insert({e, count++}); - } - } - } - }); - return outputs; -} -/*! - * \brief Create new input nodes of the subgraph and plug them. - * \return the inputs of the subgraph node in the main graph -*/ -std::vector GetSubgraphInputs(Graph g, NodeRawPtrSet subgraph_set) { - std::vector inputs; - nnvm::NodeEntryMap entry_map; - DFSVisit(g.outputs, [&subgraph_set, &inputs, &entry_map](const nnvm::ObjectPtr &node){ - if (subgraph_set.count(node.get())) { - for (auto &e : node->inputs) { - if (!subgraph_set.count(e.node.get())) { - if (entry_map.count(e)) { - e = entry_map[e]; + int num_sets = 0; + for (size_t i = 0; i < idx.num_nodes(); ++i) { + const auto& node = idx[i]; + auto& my_excluded_sets = excluded_sets[i]; + for (const auto& input : node.inputs) { + MergeSets(&my_excluded_sets, excluded_sets[input.node_id], &storage); + } + if (is_compatible(node.source)) { + int my_set = -1; + for (const auto& input : node.inputs) { + int their_set = GetSetMapping(set_assignment[input.node_id], &set_mapping); + if (their_set != -1 && + their_set != my_set && + IsSamePass(i, input.node_id, last_forward_node) && + (my_excluded_sets == nullptr || + !Intersect(*inverse_set_mapping[their_set], *my_excluded_sets))) { + if (my_set == -1) { + my_set = their_set; + CheckAndUpdateCombinedExcludedSets(&(combined_excluded_sets[their_set]), + my_excluded_sets, + &excluded_sets, + their_set, + first_node_in_set[their_set], + i, + set_assignment, + &set_mapping, + *(inverse_set_mapping[their_set]), + &storage); } else { - auto new_node = nnvm::Node::Create(); - new_node->attrs.name = "input_" + std::to_string(inputs.size()); - entry_map.insert({e, nnvm::NodeEntry{new_node, 0, 0}}); - inputs.push_back(e); - e.node = new_node; - e.index = 0; + MergeSets(&inverse_set_mapping[my_set], + inverse_set_mapping[their_set], + &storage); + set_mapping[their_set] = my_set; + first_node_in_set[my_set] = std::min(first_node_in_set[my_set], + first_node_in_set[their_set]); + CheckAndUpdateCombinedExcludedSets(&(combined_excluded_sets[their_set]), + combined_excluded_sets[my_set], + &excluded_sets, + my_set, + first_node_in_set[my_set], + i, + set_assignment, + &set_mapping, + *(inverse_set_mapping[my_set]), + &storage); } } } + if (my_set == -1) { + set_mapping.emplace_back(num_sets); + combined_excluded_sets.emplace_back(my_excluded_sets); + first_node_in_set.emplace_back(i); + storage.emplace_back(std::make_unique>( + 1, std::make_pair(num_sets, + num_sets))); + inverse_set_mapping.emplace_back(storage.back().get()); + my_set = num_sets++; + } + set_assignment[i] = my_set; + } else { + for (const auto& input : node.inputs) { + int their_set = GetSetMapping(set_assignment[input.node_id], &set_mapping); + if (their_set != -1) { + AddSet(&my_excluded_sets, their_set, &storage); + } + } + if ((is_input_only_compatible != nullptr) && + is_input_only_compatible(node.source)) { + set_mapping.emplace_back(num_sets); + combined_excluded_sets.emplace_back(my_excluded_sets); + first_node_in_set.emplace_back(i); + storage.emplace_back(std::make_unique>( + 1, std::make_pair(num_sets, + num_sets))); + inverse_set_mapping.emplace_back(storage.back().get()); + set_assignment[i] = num_sets++; + } } - }); - // Fix ordering of w.r.t to topology - Graph _g; - _g.outputs = g.outputs; - const auto &idx = _g.indexed_graph(); - std::sort(inputs.begin(), inputs.end(), - [&idx, &entry_map](const nnvm::NodeEntry lhs, const nnvm::NodeEntry rhs) { - return idx.entry_id(entry_map.at(lhs)) < idx.entry_id(entry_map.at(rhs)); - }); - return inputs; -} - -std::unordered_map GetGraphInputsMap(const Graph& g) { - std::unordered_map outputs; - auto& idx = g.indexed_graph(); - outputs.reserve(idx.num_nodes()); - std::vector input_nodes = idx.input_nodes(); - for (size_t i = 0; i < input_nodes.size(); ++i) { - outputs[input_nodes[i]] = static_cast(i); } - return outputs; -} -/*! - * \brief Helper function to display what nodes are in a specific subset. - */ -void dispNodesSet(Graph g, NodeRawPtrSet s) { - DFSVisit(g.outputs, [&s](const nnvm::ObjectPtr n){ - if (s.count(n.get())) { - std::cout << " Y " << n->attrs.name << std::endl; - } else { - std::cout << " N " << n->attrs.name << std::endl; - } - }); -} + for (int& set : set_assignment) { + set = GetSetMapping(set, &set_mapping); + } -/*! - * \brief Replace a set of nodes by a subgraph node. - */ -template -Graph ReplaceSubgraphs(Graph&& g, const std::vector& subgraph_sets, - FCreateNode create_subgraph_node) { - for (auto subgraph_set : subgraph_sets) { - // Create MXNet subgraph - Graph subgraph; - const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set); - subgraph.outputs.resize(sub_outputs_in_main.size()); - for (auto p : sub_outputs_in_main) { - subgraph.outputs[p.second] = p.first; + std::vector set_reorder(num_sets, 0); + // First count the number of elements in each set. + for (int& set : set_assignment) { + if (set != -1) { + ++set_reorder[set]; } - // To generate a subgraph an input has to be replaced by data node (no op) - // and it has to be agnostic to the node from which it's an output - // (For example, even if two inputs are two different outputs from the same node, - // they need to be replaced by two completely separate data nodes) - auto inputs = GetSubgraphInputs(subgraph, subgraph_set); - auto subgraph_node = create_subgraph_node(subgraph); - subgraph_node->inputs = inputs; - // replug inputs of node out of subgraph to be output of the subgraph node - // if it was a node in the subgraph - DFSVisit(g.outputs, - [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) { - if (!subgraph_set.count(node.get())) { - for (auto &e : node->inputs) { - auto it = sub_outputs_in_main.find(e); - if (it != sub_outputs_in_main.end()) { - e.node = subgraph_node; - e.index = it->second; - } - } - } - }); - // replug outputs of the graph to be output of the subgraph node - // if it was a node in the subgraph - for (auto &e : g.outputs) { - auto it = sub_outputs_in_main.find(e); - if (it != sub_outputs_in_main.end()) { - e.node = subgraph_node; - e.index = it->second; - } + } + // Then reorder them, removing sets that have + // only a single element. + int final_num_sets = 0; + for (int& set : set_reorder) { + if (set > 1) { + set = final_num_sets++; + } else { + set = -1; } - // move control dependencies between nodes of the subgraph and out of the subgraph - // to a dependencies between the subgraph node and the nodes out of the subgraph - DFSVisit(g.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) { - for (auto &e : node->control_deps) { - if (subgraph_set.count(e.get())) - e = subgraph_node; - } - }); - DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) { - auto it = node->control_deps.begin(); - while (it != node->control_deps.end()) { - if (subgraph_set.count(it->get())) { - ++it; - } else { - subgraph_node->control_deps.push_back(*it); - it = node->control_deps.erase(it); - } - } - }); } - Graph new_graph; - new_graph.outputs = g.outputs; - return new_graph; -} -/* \brief Get all subsets of nodes, where: - * - graph constructed from nodes in each subset is a connected graph - * - every node fulfills a predicate is_compatible - * - if nodes u and v are part of a subset, then for each path between - * u and v in the original directed graph, all nodes on those paths - * are also part of the subset - * \param g NNVM graph - * \param is_compatible A function taking nnvm::Node* and returning bool - * which identifies which nodes should be included in - * subsets. - */ -template -std::vector GetCompatibleSubsets(const Graph& g, FCompatible is_compatible) { - BidirectionalGraph biG = BidirectionalGraph(g); - std::vector> subsets = - biG.get_subsets(is_compatible); - std::vector nnvm_subsets; - nnvm_subsets.reserve(subsets.size()); - for (auto& subset : subsets) { - if (subset.size() > 1) { - NodeRawPtrSet node_set; - node_set.reserve(subset.size()); - for (auto& n : subset) { - node_set.insert(n->nnvmptr); - } - nnvm_subsets.push_back(node_set); + for (int& set : set_assignment) { + if (set != -1) { + set = set_reorder[set]; } } - return nnvm_subsets; + + return {set_assignment, final_num_sets}; } } // namespace exec diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index c56d8cf198b5..23ab4a068a53 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -238,10 +238,7 @@ void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph * fwd_graph, nnvm::Grap common::CopyGraph(&unoptimized_graph, *full_graph, false); if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) { - full_graph->attrs["num_forward_outputs"] = std::make_shared(num_forward_outputs); - *full_graph = exec::FusePointwiseForward(std::move(*full_graph)); - full_graph->attrs["num_forward_outputs"] = std::make_shared(num_forward_outputs); - *full_graph = exec::FusePointwiseBackward(std::move(*full_graph)); + *full_graph = exec::FusePointwise(*full_graph, num_forward_outputs); // Check the topological order of inputs const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes(); const auto &new_inputs = full_graph->indexed_graph().input_nodes(); diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu index cb13dbf6ce97..1b914ccd43b3 100644 --- a/src/operator/fusion/fused_op.cu +++ b/src/operator/fusion/fused_op.cu @@ -261,12 +261,40 @@ std::string FusedOp::GenerateCode(const std::vector &req, const auto& var_name = g[node_id].source->attrs.name; const auto vec_name = "vec_" + var_name + "_" + std::to_string(i); load_index[node_id] = 0; - auto parse_tuple = [](const std::string& input, const std::string def) { + auto parse_tuple = [ndim](const std::string& input, const std::string& def) { std::string out = input; - replaceString(&out, "(", "{"); - replaceString(&out, ")", "}"); + replaceString(&out, " ", ""); + if (out[0] == '(') { + replaceString(&out, "(", "{"); + replaceString(&out, ")", "}"); + // First check if out is () + int n_entries = out.size() != 2; + for (size_t i = 1; i < out.size() - 1; ++i) { + if (out[i] == ',') { + ++n_entries; + } + } + if (n_entries != ndim) { + out.pop_back(); + for (int i = n_entries; i < ndim; ++i) { + out += "," + def; + } + out += "}"; + } + } else { + out = "{" + std::move(out); + for (int i = 1; i < ndim; ++i) { + out += "," + def; + } + out += "}"; + } replaceString(&out, "None", def); + return out; + }; + auto parse_int = [](const std::string& input, const std::string& def) { + std::string out = input; replaceString(&out, " ", ""); + replaceString(&out, "None", def); return out; }; auto build_tuple = [ndim](int axis, const std::string str, const std::string def) { @@ -279,11 +307,11 @@ std::string FusedOp::GenerateCode(const std::vector &req, } std::string tuple = "{"; for (int i = 0; i < axis; i++) { - tuple = tuple + def + ","; + tuple += def + ","; } tuple += str; for (int i = axis + 1; i < ndim; i++) { - tuple = tuple + "," + def; + tuple += "," + def; } tuple += "}"; return tuple; @@ -295,12 +323,6 @@ std::string FusedOp::GenerateCode(const std::vector &req, } return false; }; - auto build_string_axis = [ndim](int axis) { - if (axis < 0) { - axis = ndim + axis; - } - return std::to_string(axis); - }; auto build_string_end = [i, ndim, var_name](std::string* code) { std::string end_var_name = var_name + "_" + std::to_string(i) + "_end"; *code += "op::Shape<" + std::to_string(ndim) + "> "+ end_var_name + ";\n"; @@ -323,12 +345,15 @@ std::string FusedOp::GenerateCode(const std::vector &req, } end = extra_var_name; } else { - begin = parse_tuple(source->attrs.dict.at("begin"), "0"); - end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX"); if (op_name == "slice_axis") { + begin = parse_int(source->attrs.dict.at("begin"), "0"); + end = parse_int(source->attrs.dict.at("end"), "INT_MAX"); int axis = std::stoi(source->attrs.dict.at("axis")); begin = build_tuple(axis, begin, "0"); end = build_tuple(axis, end, "INT_MAX"); + } else { + begin = parse_tuple(source->attrs.dict.at("begin"), "0"); + end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX"); } if (check_shapes) { if (check_tuple(begin) && check_tuple(end)) { diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py index 1bbf5982f45f..aa0102dcb0c8 100644 --- a/tests/python/gpu/test_fusion.py +++ b/tests/python/gpu/test_fusion.py @@ -337,6 +337,26 @@ def test_fusion_reshape_executor(): out = f.forward(is_train=False, data1=data, data2=data) assert out[0].sum().asscalar() == 150 +@with_seed() +def test_fusion_cycle(): + class Test(gluon.nn.HybridBlock): + def __init__(self, **kwargs): + super(Test, self).__init__(**kwargs) + + def hybrid_forward(self, F, x, y): + x = F.relu(x) + y = F.relu(y) + z1 = F.expand_dims(F.sum_axis(x, axis=1), axis=1) + z2 = F.expand_dims(F.sum_axis(y, axis=1), axis=1) + return x + z2, y + z1 + + t = Test() + a = mx.nd.zeros(shape=(10,1), ctx=mx.gpu()) + b = mx.nd.zeros(shape=(10,1), ctx=mx.gpu()) + t.hybridize(static_alloc=True, static_shape=True) + out = t(a, b) + mx.nd.waitall() + if __name__ == '__main__': import nose nose.runmodule() From d99d3fa93b3c7e42a40061db93d1f5bf7acb57cf Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 22 Oct 2020 17:03:12 -0700 Subject: [PATCH 2/4] Use std::tie instead of C++17 structured binding --- src/executor/pointwise_fusion_pass.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc index aa139903339e..961dade0fa93 100644 --- a/src/executor/pointwise_fusion_pass.cc +++ b/src/executor/pointwise_fusion_pass.cc @@ -361,9 +361,11 @@ Graph CopyAndReplaceSubgraphs(const Graph& g, Graph FusePointwise(const Graph &g, const size_t num_forward_outputs) { auto start = std::chrono::steady_clock::now(); - auto [subset_assignment, num_subsets] = GetCompatibleSubsets(g, num_forward_outputs, // NOLINT(*) - IsFusionCompatible, - IsInputsOnlyCompatible); + std::vector subset_assignment; + int num_subsets; + std::tie(subset_assignment, num_subsets) = GetCompatibleSubsets(g, num_forward_outputs, // NOLINT(*) + IsFusionCompatible, + IsInputsOnlyCompatible); Graph ret = CopyAndReplaceSubgraphs(g, subset_assignment, num_subsets, CreateSubgraphNode); auto end = std::chrono::steady_clock::now(); From 023db4bfbe3601466f2148b926c83d6fc89a16b5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 23 Oct 2020 08:53:25 -0700 Subject: [PATCH 3/4] More fixes for lack of c++17 --- src/executor/graph_executor.cc | 5 +---- src/executor/simple_partition_pass.h | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 28b79aeda1ee..3f2c7c93fb70 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1009,10 +1009,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, common::CopyGraph(&unoptimized_graph, g, false); if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) { - g.attrs["num_forward_outputs"] = std::make_shared(num_forward_outputs_); - g = FusePointwiseForward(std::move(g)); - g.attrs["num_forward_outputs"] = std::make_shared(num_forward_outputs_); - g = FusePointwiseBackward(std::move(g)); + g = exec::FusePointwise(std::move(g), num_forward_outputs_); // Check the topological order of inputs const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes(); const auto &new_inputs = g.indexed_graph().input_nodes(); diff --git a/src/executor/simple_partition_pass.h b/src/executor/simple_partition_pass.h index 764486f8e6c1..2a135b41dca1 100644 --- a/src/executor/simple_partition_pass.h +++ b/src/executor/simple_partition_pass.h @@ -264,7 +264,7 @@ std::tuple, int> GetCompatibleSubsets( } } - return {set_assignment, final_num_sets}; + return std::make_tuple(std::move(set_assignment), final_num_sets); } } // namespace exec From c3e1edad266a177f742786ed77a84e40e19b574a Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 23 Oct 2020 13:38:07 -0700 Subject: [PATCH 4/4] Fix --- tests/python/gpu/test_fusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py index aa0102dcb0c8..943cea4ebbe9 100644 --- a/tests/python/gpu/test_fusion.py +++ b/tests/python/gpu/test_fusion.py @@ -339,7 +339,8 @@ def test_fusion_reshape_executor(): @with_seed() def test_fusion_cycle(): - class Test(gluon.nn.HybridBlock): + from mxnet.gluon import HybridBlock + class Test(HybridBlock): def __init__(self, **kwargs): super(Test, self).__init__(**kwargs)