Skip to content

Commit

Permalink
Add disable attr to subgraph property (apache#15926)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhennanQin authored and pengzhao-intel committed Aug 17, 2019
1 parent 5a4c01b commit a8b9728
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 16 deletions.
6 changes: 4 additions & 2 deletions docs/tutorials/c++/subgraphAPI.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ class SgProperty : public SubgraphProperty {
};
```
`SetAttr` is optional and developer can define their own attributes to control property behavior.
There're 2 built-in attributes that used by MXNet executor.
There're some built-in attributes that used by MXNet executor.

`property_name` : std::string, name of this property.
`property_name` : std::string, name of this property, used for diagnose.

`disable` : bool, whther to disable this property.

`inference_only` : bool, apply this property only for inference. Property will be skiped when need_grad=True. Default `false` if this attribute isn't defined.

Expand Down
8 changes: 8 additions & 0 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,14 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name,
auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto property : subgraph_prop_list) {
if (property->HasAttr("disable") && property->GetAttr<bool>("disable") == true) {
auto full_name = property->HasAttr("property_name")
? property->GetAttr<std::string>("property_name")
: std::string();
LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
<< " is disabled.";
continue;
}
nnvm::Graph g = Symbol2Graph(*s);
property->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
Expand Down
9 changes: 7 additions & 2 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1659,10 +1659,15 @@ static bool SubgraphBackendCheck(const op::SubgraphBackendPtr& backend,
static bool SubgraphPropertyCheck(const std::string& backend_name,
const op::SubgraphPropertyPtr& prop, bool need_grad,
bool verbose = false) {
auto full_name =
prop->HasAttr("property_name") ? prop->GetAttr<std::string>("property_name") : std::string();
if (prop->HasAttr("disable") && prop->GetAttr<bool>("disable") == true) {
LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
<< " is disabled.";
return false;
}
if (prop->HasAttr("inference_only") && prop->GetAttr<bool>("inference_only") == true) {
if (need_grad) {
auto full_name = prop->HasAttr("property_name") ? prop->GetAttr<std::string>("property_name")
: std::string();
if (verbose) {
LOG(INFO) << "skip partitioning graph with subgraph property " << full_name
<< " from backend " << backend_name << " as it requires `grad_req=null`.";
Expand Down
7 changes: 3 additions & 4 deletions src/operator/subgraph/mkldnn/mkldnn_conv_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,12 @@ class SgMKLDNNConvProperty : public SubgraphProperty {
}
static SubgraphPropertyPtr Create() {
static const std::string &name = "MKLDNN convolution optimization pass";
if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) {
LOG(INFO) << name << " is disabled.";
return nullptr;
}
auto property = std::make_shared<SgMKLDNNConvProperty>();
property->SetAttr<std::string>("property_name", name);
property->SetAttr<bool>("inference_only", true);
if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) {
property->SetAttr<bool>("disable", true);
}
return property;
}
nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
Expand Down
7 changes: 3 additions & 4 deletions src/operator/subgraph/mkldnn/mkldnn_fc_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,12 @@ class SgMKLDNNFCProperty : public SubgraphProperty {

static SubgraphPropertyPtr Create() {
static const std::string &name = "MKLDNN FullyConnected optimization pass";
if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FC_OPT", 0)) {
LOG(INFO) << name << " is disabled.";
return nullptr;
}
auto property = std::make_shared<SgMKLDNNFCProperty>();
property->SetAttr<std::string>("property_name", name);
property->SetAttr<bool>("inference_only", true);
if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FC_OPT", 0)) {
property->SetAttr<bool>("disable", true);
}
return property;
}

Expand Down
11 changes: 7 additions & 4 deletions src/operator/subgraph/subgraph_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class SubgraphPropertyEntry {

template<typename T>
SubgraphPropertyEntry set_attr(const std::string& name, const T value) const {
entry_->SetAttr<T>(name, value);
if (entry_) entry_->SetAttr<T>(name, value);
return *this;
}

Expand Down Expand Up @@ -403,9 +403,12 @@ class SubgraphBackend {
}
}

SubgraphPropertyPtr& RegisterSubgraphProperty(const SubgraphPropertyPtr prop) {
prop_ptr_.push_back(prop);
return prop_ptr_.back();
SubgraphPropertyPtr RegisterSubgraphProperty(SubgraphPropertyPtr prop) {
if (prop) {
prop_ptr_.push_back(prop);
return prop_ptr_.back();
}
return prop;
}

const std::string& GetName() const { return name_; }
Expand Down

0 comments on commit a8b9728

Please sign in to comment.