Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
move InferForwardAttrs to common/
Browse files Browse the repository at this point in the history
  • Loading branch information
mseth10 committed Aug 17, 2019
1 parent 57e276a commit dfeeacf
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 35 deletions.
4 changes: 3 additions & 1 deletion src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "nnvm/pass_functions.h"
#include "nnvm/symbolic.h"
#include "./c_api_common.h"
#include "../common/exec_utils.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
#include "../operator/subgraph/subgraph_property.h"
Expand Down Expand Up @@ -1213,7 +1214,8 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
arg_dtypes.push_back(in_arg.dtype());
arg_stypes.push_back(in_arg.storage_type());
in_arg_ctxes[i] = in_arg.ctx();
orig_g = InferForwardAttrs(orig_g, arg_shapes, arg_dtypes, arg_stypes,
}
orig_g = common::InferForwardAttrs(orig_g, arg_shapes, arg_dtypes, arg_stypes,
default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes, true);
}
std::vector<std::pair<std::string, std::string>> options_map;
Expand Down
36 changes: 36 additions & 0 deletions src/common/exec_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,42 @@ inline nnvm::Graph AssignContext(nnvm::Graph g,
return g;
}

/*!
* \brief infers shapes, dtypes, stypes, contexts for the forward graph
*/
inline nnvm::Graph InferForwardAttrs(nnvm::Graph g,
mxnet::ShapeVector arg_shapes,
nnvm::DTypeVector arg_dtypes,
StorageTypeVector arg_stypes,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& aux_state_ctxes,
bool partial_shape = false) {
const auto& indexed_graph = g.indexed_graph();
const auto num_forward_inputs = indexed_graph.input_nodes().size();
g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
if (!partial_shape) {
HandleInferShapeError(num_forward_inputs, indexed_graph,
g.GetAttr<mxnet::ShapeVector>("shape"));
}
}
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
HandleInferTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<nnvm::DTypeVector>("dtype"));
}
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<StorageTypeVector>("storage_type"));
}
return g;
}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_EXEC_UTILS_H_
Expand Down
34 changes: 0 additions & 34 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1602,40 +1602,6 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
return ret;
}

// Infer shapes, dtypes, stypes, contexts for the forward graph
static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
mxnet::ShapeVector arg_shapes,
nnvm::DTypeVector arg_dtypes,
StorageTypeVector arg_stypes,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& aux_state_ctxes,
bool partial_shape = false) {
const auto& indexed_graph = g.indexed_graph();
const auto num_forward_inputs = indexed_graph.input_nodes().size();
g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
if (!partial_shape) {
HandleInferShapeError(num_forward_inputs, indexed_graph,
g.GetAttr<mxnet::ShapeVector>("shape"));
}
}
g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
HandleInferTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<nnvm::DTypeVector>("dtype"));
}
g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<StorageTypeVector>("storage_type"));
}
return g;
}

static bool SubgraphBackendCheck(const op::SubgraphBackendPtr& backend,
const Context& default_ctx,
bool verbose = false) {
Expand Down
1 change: 1 addition & 0 deletions src/operator/subgraph/subgraph_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <unordered_map>
#include <vector>
#include <string>
#include <utility>

namespace mxnet {
namespace op {
Expand Down

0 comments on commit dfeeacf

Please sign in to comment.