Skip to content

Commit

Permalink
identity fuse (apache#20884)
Browse files Browse the repository at this point in the history
rewrite test

fix sanity

remove clang warning

Co-authored-by: Bartlomiej Gawrych <barlomiej.gawrych@intel.com>
  • Loading branch information
bgawrych and Bartlomiej Gawrych authored Feb 15, 2022
1 parent 1c4776f commit 453ccb8
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 1 deletion.
173 changes: 173 additions & 0 deletions src/operator/subgraph/mkldnn/mkldnn_identity_property.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_identity_property.cc
* \brief Graph property for removing identity operators
*/

#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_IDENTITY_PROPERTY_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_IDENTITY_PROPERTY_H_
#if MXNET_USE_MKLDNN == 1

#include <map>
#include <string>
#include <vector>

#include "../common.h"
#include "../../nn/dropout-inl.h"
#include "mkldnn_subgraph_base-inl.h"

namespace mxnet {
namespace op {

class SgMKLDNNIdentitySelector : public SubgraphSelectorV2 {
private:
std::vector<const BiDirectedNode*> matched_list_;
bool pattern_found = false;

public:
bool Select(const BiDirectedNode& seed_node,
const std::shared_ptr<NodeAttr>& node_attr) override {
bool status = false;
if (seed_node.node->op() == Op::Get("_copy")) {
status = true;
}

if (seed_node.node->op() == Op::Get("Dropout")) {
auto const& dropout_param = nnvm::get<DropoutParam>(seed_node.node->attrs.parsed);
if (dropout_param.mode == dropout::kTraining) {
status = true;
}
}

if (status) {
matched_list_.clear();
matched_list_.emplace_back(&seed_node);
return true;
}
return false;
}

bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
if (pattern_found || input_node.node->is_variable()) {
return false;
} else if (input_node.node->op()) {
matched_list_.emplace_back(&input_node);
pattern_found = true;
return true;
}
return false;
}

bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
return false;
}

std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
// candidates should contain only two nodes - custom node and identity node
if (pattern_found && candidates.size() == matched_list_.size()) {
CHECK_EQ(candidates.size(), 2);
return candidates;
} else {
return std::vector<BiDirectedNode*>(0);
}
}

void Reset() override {
CHECK_GE(matched_list_.size(), 1);
auto new_selector = SgMKLDNNIdentitySelector();
new_selector.Select(*matched_list_[0], nullptr);
*this = new_selector;
}
};

inline bool IsIdentityNode(const nnvm::ObjectPtr node) {
return node->op() && (node->op() == Op::Get("_copy") || node->op() == Op::Get("Dropout"));
}

class SgMKLDNNIdentityProperty : public SubgraphProperty {
public:
SgMKLDNNIdentityProperty() {}

static SubgraphPropertyPtr Create() {
static const std::string& name = "MKLDNN Identity optimization passs";
auto property = std::make_shared<SgMKLDNNIdentityProperty>();
property->SetAttr<std::string>("property_name", name);
property->SetAttr<bool>("inference_only", true);
return property;
}

nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
const int subgraph_id = 0) const override {
nnvm::NodeEntry identity_node_entry;
for (auto entry : sym.outputs) {
if (IsIdentityNode(entry.node)) {
identity_node_entry = entry;
}
}

auto last_node = identity_node_entry.node;
nnvm::Symbol new_sym;
new_sym.outputs.emplace_back(last_node);

nnvm::ObjectPtr org_node;
DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
if (!IsIdentityNode(node)) {
org_node = node;
}
});

// Create copy of original node
nnvm::ObjectPtr n = nnvm::Node::Create();
n->attrs = org_node->attrs;
if (n->op() && n->op()->attr_parser) {
n->op()->attr_parser(&(n->attrs));
}

return n;
}

void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
std::vector<nnvm::NodeEntry*>* output_entries) const override {
// output of identity must be connected as output of operator before identity
// e.g. for: /--index 0--> custom_op
// (n) slice
// \--index 1--> Dropout --index 0--> OUT_NODE
// for OUT_NODE index 0 must be changed to index 1
for (size_t i = 0; i < output_entries->size(); ++i) {
auto out_node = output_entries->at(i)->node;
if (IsIdentityNode(out_node)) {
output_entries->at(i)->index = out_node->inputs[0].index;
}
output_entries->at(i)->node = n;
}
}

SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
auto selector = std::make_shared<SgMKLDNNIdentitySelector>();
return selector;
}
};

} // namespace op
} // namespace mxnet

#endif // if MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_IDENTITY_PROPERTY_H_
2 changes: 1 addition & 1 deletion src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ static inline bool SupportMKLDNNAttr(const std::shared_ptr<NodeAttr>& node_attr)
return (node_attr->dispatch_mode == DispatchMode::kFComputeEx) &&
(node_attr->itype[0] == mshadow::kFloat32 ||
node_attr->itype[0] == mshadow::kBfloat16) &&
(ndim == 1 || ndim == 2 || ndim == 4 || ndim == 5);
(ndim >= 1 && ndim <= 5);
} else {
return true;
}
Expand Down
6 changes: 6 additions & 0 deletions src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "mkldnn_conv_property.h"
#include "mkldnn_fc_property.h"
#include "mkldnn_identity_property.h"
#include "mkldnn_post_quantize_property.h"
#include "mkldnn_fc_post_quantize_property.h"
#include "mkldnn_elemwisemul_post_quantize_property.h"
Expand All @@ -35,6 +36,8 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
.set_attr("enable", MKLDNNEnvSet())
.set_attr("context", Context::CPU());

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNIdentityProperty);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
Expand All @@ -44,12 +47,15 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerProperty);
MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
.set_attr("context", Context::CPU());

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNIdentityProperty);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
.set_attr("quantize", true);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
.set_attr("quantize", true);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNIdentityProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty);
Expand Down
28 changes: 28 additions & 0 deletions tests/python/mkl/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,28 @@ def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'):

return sym, attr

def fc_identity_eltwise(data_shape, identity_node):
attrs = {'sg_mkldnn_fully_connected_eltwise_0' : {'with_eltwise': 'true'},
'sg_mkldnn_fully_connected_eltwise_1' : {'with_eltwise': 'true'}}
data, fc1_weight = head_symbol(data_shape)
fc2_weight = mx.symbol.Variable('fc2_weight', dtype='float32')

sym = mx.symbol.FullyConnected(name='fc1', data=data, weight=fc1_weight, num_hidden=64,
no_bias=True, flatten=True)
if identity_node == 'copy':
sym = mx.symbol.identity(sym)
else:
sym = mx.symbol.Dropout(sym)
sym = mx.symbol.Activation(sym, act_type='relu')
sym = mx.symbol.FullyConnected(name='fc2', data=sym, weight=fc2_weight, num_hidden=64,
no_bias=True, flatten=True)
if identity_node == 'copy':
sym = mx.symbol.identity(sym)
else:
sym = mx.symbol.Dropout(sym)
sym = mx.symbol.Activation(sym, act_type='relu')
return sym, attrs

def single_selfatt_qk(data_shape, nheads=16):
attr = {'selfatt_qk': {}}
data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
Expand Down Expand Up @@ -876,6 +898,12 @@ def test_single_fc():
else:
check_fusion(syms, dshape, attrs, check_quantization=False)

@with_seed()
def test_fc_eltwise_identity():
for dshape, identity_node in itertools.product(DATA_SHAPE, ['copy', 'dropout']):
syms, attrs = fc_identity_eltwise(dshape, identity_node)
check_fusion(syms, dshape, attrs, check_quantization=False)

@with_seed()
def test_fc_eltwise():
for dshape, no_bias, flatten, alg in itertools.product(DATA_SHAPE,
Expand Down

0 comments on commit 453ccb8

Please sign in to comment.