
Commit

Add quantized version of reshape with DNNL reorder primitive.
agrabow committed Jan 20, 2022
1 parent 50a8ee8 commit a3870eb
Showing 7 changed files with 352 additions and 0 deletions.
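For orientation before the per-file diffs: the commit adds a _contrib_quantized_reshape operator (alias _npx_quantized_reshape) that reshapes int8/uint8 data via a DNNL reorder and passes the quantization range through unchanged. A minimal usage sketch, based on the test_quantized_reshape test added below (assumes an MXNet build with MXNET_USE_ONEDNN=1):

    import mxnet as mx

    # Quantized payload plus its float32 min/max range.
    qdata = mx.np.random.uniform(low=-127.0, high=127.0, size=(2, 3, 5, 5)).astype('int8')
    min_data = mx.np.array([-1023.343], dtype='float32')
    max_data = mx.np.array([2343.324275], dtype='float32')

    # Reshape the quantized data; the range is forwarded untouched.
    qout, min_out, max_out = mx.npx.quantized_reshape(
        qdata, min_data, max_data, newshape=(-2, -1), reverse=False)
    assert qout.shape == (2, 75)  # -2 copies dim 0, -1 infers the rest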
1 change: 1 addition & 0 deletions src/operator/nn/dnnl/dnnl_reshape.cc
@@ -111,6 +111,7 @@ DNNLReshapeFwd& GetReshapeForward(const OpReqType& req,
  DNNLReshapeSignature key;
  key.AddSign(req);
  key.AddSign(input);
  key.AddSign(output);

  auto it = fwds.find(key);
  if (it == fwds.end()) {
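Why the new key.AddSign(output) matters: GetReshapeForward caches forward primitives by this signature, so without the output in the key, two reshapes of the same input to different target shapes would collide on one cached entry. A self-contained illustration of the collision using a simplified stand-in key (hedged: the real DNNLReshapeSignature signs NDArray descriptors, not plain ints):

    #include <cassert>
    #include <map>
    #include <vector>

    int main() {
      using Key = std::vector<int>;              // stand-in for the op signature
      std::map<Key, const char*> cache;

      // Key built from (req, input) only: both target shapes yield the same key.
      Key key_without_output = {/*req=*/1, /*input=*/8, 8};
      cache[key_without_output] = "primitive for newshape=(64,)";
      // A later reshape of the same input to (4, 16) finds -- and would wrongly
      // reuse -- the entry cached for (64,):
      assert(cache.count(key_without_output) == 1);

      // With the output signed into the key (key.AddSign(output) in this patch),
      // the two lookups stay distinct:
      Key to_64   = {1, 8, 8, /*output=*/64};
      Key to_4x16 = {1, 8, 8, /*output=*/4, 16};
      assert(!(to_64 == to_4x16));
      return 0;
    }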
4 changes: 4 additions & 0 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -144,6 +144,10 @@ struct NumpyXReshapeParam : public dmlc::Parameter<NumpyXReshapeParam> {
  }
};

bool NumpyXReshapeShape(const nnvm::NodeAttrs& attrs,
                        mxnet::ShapeVector* in_attrs,
                        mxnet::ShapeVector* out_attrs);

template <typename xpu>
void NumpyTranspose(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
154 changes: 154 additions & 0 deletions src/operator/quantization/dnnl/dnnl_quantized_reshape-inl.h
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file dnnl_quantized_reshape-inl.h
* \author: Adam Grabowski, adam.grabowski@intel.com
*/

#ifndef MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RESHAPE_INL_H_
#define MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RESHAPE_INL_H_

#if MXNET_USE_ONEDNN == 1

#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "../../nn/dnnl/dnnl_ops-inl.h"
#include "../../numpy/np_matrix_op-inl.h"
#include "../../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {

struct QuantizedReshapeParam : public dmlc::Parameter<QuantizedReshapeParam> {
  mxnet::TShape newshape;
  mxnet::Tuple<int> shape;
  bool reverse, keep_highest, is_numpy_op;
  std::string order;

  DMLC_DECLARE_PARAMETER(QuantizedReshapeParam) {
    DMLC_DECLARE_FIELD(newshape).set_default(mxnet::TShape(0, -1));
    DMLC_DECLARE_FIELD(shape).set_default(mxnet::Tuple<int>());
    DMLC_DECLARE_FIELD(reverse).set_default(false);
    DMLC_DECLARE_FIELD(order).set_default("C");
    DMLC_DECLARE_FIELD(keep_highest).set_default(false);
    DMLC_DECLARE_FIELD(is_numpy_op).set_default(true);
  }

  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
    std::ostringstream newshape_s, shape_s, reverse_s, order_s, keep_highest_s, is_numpy_op_s;
    newshape_s << newshape;
    shape_s << shape;
    reverse_s << reverse;
    order_s << order;
    keep_highest_s << keep_highest;
    is_numpy_op_s << is_numpy_op;
    (*dict)["newshape"]     = newshape_s.str();
    (*dict)["shape"]        = shape_s.str();
    (*dict)["reverse"]      = reverse_s.str();
    (*dict)["order"]        = order_s.str();
    (*dict)["keep_highest"] = keep_highest_s.str();
    (*dict)["is_numpy_op"]  = is_numpy_op_s.str();
  }
};

bool QuantizedReshapeInferShape(const nnvm::NodeAttrs& attrs,
                                mxnet::ShapeVector* in_attrs,
                                mxnet::ShapeVector* out_attrs) {
  const QuantizedReshapeParam& param = nnvm::get<QuantizedReshapeParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 3U);
  mxnet::ShapeVector input  = {in_attrs->at(0)};
  mxnet::ShapeVector output = {out_attrs->at(0)};
  nnvm::NodeAttrs _attrs;
  bool ret;

  // Delegate data-shape inference to the FP32 reshape flavor (NumPy or legacy)
  // that this quantized op was created from.
  if (param.is_numpy_op) {
    NumpyXReshapeParam _param;
    _param.newshape = param.newshape;
    _param.reverse  = param.reverse;
    _param.order    = param.order;
    _attrs.parsed   = _param;
    ret             = NumpyXReshapeShape(_attrs, &input, &output);
  } else {
    ReshapeParam _param;
    _param.shape        = param.shape;
    _param.keep_highest = param.keep_highest;
    _param.reverse      = param.reverse;
    _attrs.parsed       = _param;
    ret                 = ReshapeShape(_attrs, &input, &output);
  }
  // The min/max range inputs and outputs are scalars of shape {1}.
  SHAPE_ASSIGN_CHECK(*in_attrs, 1, mxnet::TShape{1});
  SHAPE_ASSIGN_CHECK(*in_attrs, 2, mxnet::TShape{1});
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, output[0]);
  SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape{1});
  SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape{1});

  return ret;
}

bool QuantizedReshapeStorageType(const nnvm::NodeAttrs& attrs,
                                 const int dev_mask,
                                 DispatchMode* dispatch_mode,
                                 std::vector<int>* in_attrs,
                                 std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 3U);
  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}

bool QuantizedReshapeType(const nnvm::NodeAttrs& attrs,
                          std::vector<int>* in_attrs,
                          std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 3U);
  // The data keeps its quantized dtype; the min/max range is always float32.
  TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32);
  TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kFloat32);
  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
  TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32);
  TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32);
  return (*in_attrs)[0] != -1;
}

static void DNNLQuantizedReshapeForward(const nnvm::NodeAttrs& attrs,
                                        const OpContext& ctx,
                                        const std::vector<NDArray>& inputs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
  CHECK(inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == mshadow::kInt8)
      << "dnnl_quantized_reshape op only supports uint8 and int8 as input type";

  if (SupportDNNLReshape(inputs[0], outputs[0])) {
    // An in-place request can only be honored when input and output share a
    // data handle; otherwise force a full write of the output buffer.
    OpReqType reqType;
    if (inputs[0].GetDNNLData()->get_data_handle() != outputs[0].GetDNNLData()->get_data_handle())
      reqType = kWriteTo;
    else
      reqType = req[0];
    DNNLRun(DNNLReshapeForward, attrs, ctx, inputs[0], reqType, outputs[0]);
  } else {
    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
  }
  // Reshape does not change the quantization range: pass min/max through.
  *outputs[1].data().dptr<float>() = *inputs[1].data().dptr<float>();
  *outputs[2].data().dptr<float>() = *inputs[2].data().dptr<float>();
}

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_ONEDNN == 1
#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RESHAPE_INL_H_
110 changes: 110 additions & 0 deletions src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file dnnl_quantized_reshape.cc
* \author: Adam Grabowski, adam.grabowski@intel.com
*/

#if MXNET_USE_ONEDNN == 1
#include "./dnnl_quantized_reshape-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(QuantizedReshapeParam);

NNVM_REGISTER_OP(_contrib_quantized_reshape)
    .add_alias("_npx_quantized_reshape")
    .set_num_inputs(3)
    .set_num_outputs(3)
    .set_attr_parser(ParamParser<QuantizedReshapeParam>)
    .set_attr<nnvm::FListInputNames>(
        "FListInputNames",
        [](const NodeAttrs& attrs) {
          return std::vector<std::string>{"data", "min_data", "max_data"};
        })
    .set_attr<nnvm::FListOutputNames>(
        "FListOutputNames",
        [](const NodeAttrs& attrs) {
          return std::vector<std::string>{"output", "min_output", "max_output"};
        })
    .set_attr<nnvm::FInplaceOption>(
        "FInplaceOption",
        [](const NodeAttrs& attrs) {
          return std::vector<std::pair<int, int>>{{0, 0}, {1, 1}, {2, 2}};
        })
    .set_attr<FComputeEx>("FComputeEx<cpu>", DNNLQuantizedReshapeForward)
    .set_attr<FResourceRequest>("FResourceRequest",
                                [](const NodeAttrs& n) {
                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
                                })
    .set_attr<FInferStorageType>("FInferStorageType", QuantizedReshapeStorageType)
    .set_attr<mxnet::FInferShape>("FInferShape", QuantizedReshapeInferShape)
    .set_attr<nnvm::FInferType>("FInferType", QuantizedReshapeType)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
    .set_attr<FQuantizable>("FQuantizable",
                            [](const NodeAttrs& attrs) { return QuantizeType::kSupport; })
    .add_argument("data", "NDArray-or-Symbol", "Array to be reshaped.")
    .add_argument("min_data",
                  "NDArray-or-Symbol",
                  "The minimum scalar value possibly produced for the data")
    .add_argument("max_data",
                  "NDArray-or-Symbol",
                  "The maximum scalar value possibly produced for the data")
    .add_arguments(QuantizedReshapeParam::__FIELDS__());

template <bool is_numpy_op>
nnvm::ObjectPtr QuantizedReshapeNode(const NodeAttrs& attrs) {
  QuantizedReshapeParam param;
  if (is_numpy_op) {
    const NumpyXReshapeParam& _param = nnvm::get<NumpyXReshapeParam>(attrs.parsed);
    param.newshape     = _param.newshape;
    param.reverse      = _param.reverse;
    param.order        = _param.order;
    param.keep_highest = false;
    param.is_numpy_op  = true;
  } else {
    const ReshapeParam& _param = nnvm::get<ReshapeParam>(attrs.parsed);
    param.shape        = _param.shape;
    param.keep_highest = _param.keep_highest;
    param.reverse      = _param.reverse;
    param.is_numpy_op  = false;
  }

  nnvm::ObjectPtr node = nnvm::Node::Create();
  node->attrs.op       = Op::Get("_contrib_quantized_reshape");
  node->attrs.name     = "quantized_" + attrs.name;
  param.SetAttrDict(&(node->attrs.dict));
  if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
    node->op()->attr_parser(&(node->attrs));
  }
  return node;
}

NNVM_REGISTER_OP(_npx_reshape).set_attr<FQuantizedOp>("FQuantizedOp", QuantizedReshapeNode<true>);

NNVM_REGISTER_OP(Reshape).set_attr<FQuantizedOp>("FQuantizedOp", QuantizedReshapeNode<false>);

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_ONEDNN == 1
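For orientation, a hedged sketch of the consumer side of the two FQuantizedOp registrations above: during the quantization graph pass (src/operator/quantization/quantize_graph_pass.cc, not part of this diff), nodes whose op registered the hook are swapped for their quantized counterparts. The helper below is illustrative only, assumes the MXNet source tree, and is not the actual pass:

    #include <mxnet/op_attr_types.h>
    #include <nnvm/node.h>
    #include <nnvm/op.h>

    static const auto& fquantized_map =
        nnvm::Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");

    nnvm::ObjectPtr MaybeQuantize(const nnvm::ObjectPtr& node) {
      if (node->op() != nullptr && fquantized_map.count(node->op())) {
        // For Reshape / _npx_reshape this calls QuantizedReshapeNode<...>, which
        // returns a _contrib_quantized_reshape node carrying the original reshape
        // parameters in its attribute dictionary.
        return fquantized_map[node->op()](node->attrs);
      }
      return node;  // no quantized counterpart registered; keep the FP32 op
    }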
22 changes: 22 additions & 0 deletions tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -73,6 +73,28 @@ def forward(self, x):
    check_fusion(net, data_shape, attr)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('use_bias', [True, False])
def test_conv_reshape_conv(use_bias, data_shape):

    class Conv_Reshape_Conv(nn.HybridBlock):
        def __init__(self, **kwargs):
            super(Conv_Reshape_Conv, self).__init__(**kwargs)
            self.conv0 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)
            self.conv1 = nn.Conv2D(channels=32, kernel_size=(5, 5), strides=1, use_bias=use_bias)

        def forward(self, x):
            out = self.conv0(x)
            out = mx.npx.reshape(out, newshape=(-1, int(out.shape[1]/4), out.shape[2]*2, out.shape[3]*2))
            out = self.conv1(out)
            return out

    attr = {'conv': []}
    net = Conv_Reshape_Conv()
    check_fusion(net, data_shape, attr)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('use_bias', [True, False])
21 changes: 21 additions & 0 deletions tests/python/dnnl/subgraphs/test_fc_subgraph.py
@@ -54,6 +54,27 @@ def forward(self, x):
    check_fusion(net, data_shape, attrs, check_quantization=flatten)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('use_bias', [True, False])
@pytest.mark.parametrize('flatten', [True, False])
def test_fc_reshape(data_shape, use_bias, flatten):

    class FC_Reshape(nn.HybridBlock):
        def __init__(self, use_bias, flatten, **kwargs):
            super(FC_Reshape, self).__init__(**kwargs)
            self.fc = nn.Dense(units=64, use_bias=use_bias, flatten=flatten)

        def forward(self, x):
            out = self.fc(x)
            out = mx.npx.reshape(out, newshape=(1, -1))
            return out

    attrs = {'fc': {}}
    net = FC_Reshape(use_bias, flatten)
    check_fusion(net, data_shape, attrs, check_quantization=flatten)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('use_bias', [True, False])
40 changes: 40 additions & 0 deletions tests/python/quantization/test_quantization.py
@@ -945,6 +945,46 @@ def check_quantized_bn(data_shape, qdtype):
        check_quantized_bn((32, 3, 224, 224), qdtype)


def test_quantized_reshape():
    test_cases = [((2, 3, 5, 5), (-2, -1), False, (2, 75)),
                  ((2, 3, 5, 5), (-2, -2, -1), False, (2, 3, 25)),
                  ((5, 3, 4, 5), (-2, -1, -2), False, (5, 15, 4)),
                  ((2, 3, 5, 4), (-1, -2, -2), False, (8, 3, 5)),
                  ((2, 3, 5, 5), (-2, -2, -2, -2), False, (2, 3, 5, 5)),
                  ((2, 1, 4, 5), (-2, -3, -2, -2), False, (2, 4, 5)),
                  ((1, 1, 4, 1), (-3, -3, -2, -2), False, (4, 1)),
                  ((1, 1, 1, 1), (-3, -3, -3, -3), False, ()),
                  ((2, 4, 5, 3), (-1, 2, 2, 1), False, (30, 2, 2, 1)),
                  ((2, 3, 5, 6), (-4,), False, (2, 3, 5, 6)),
                  ((2, 3, 5, 6), (6, 1, -4), False, (6, 1, 5, 6)),
                  ((2, 3, 5, 6), (-5, -5), False, (6, 30)),
                  ((2, 3, 5, 6), (-5, -1), False, (6, 30)),
                  ((64,), (-6, 16, 4), False, (16, 4)),
                  ((64,), (-6, 16, -1), False, (16, 4)),
                  ((64, 1, 2, 3), (-6, 16, -1, -4), False, (16, 4, 1, 2, 3)),
                  ((8, 5, 4, 6), (-4, -1, 3, -6), True, (8, 5, 4, 2, 3))]

    def check_quantized_reshape(shape, qdtype, newshape, reverse, expected_ret_shape):
        # uint8 data must be non-negative; int8 spans the symmetric range.
        if qdtype == 'uint8':
            data_low = 0.0
            data_high = 127.0
        else:
            data_low = -127.0
            data_high = 127.0
        qdata = mx.np.random.uniform(low=data_low, high=data_high, size=shape).astype(qdtype)
        min_data = mx.np.array([-1023.343], dtype='float32')
        max_data = mx.np.array([2343.324275], dtype='float32')

        qoutput, min_output, max_output = npx.quantized_reshape(
            qdata, min_data, max_data, newshape=newshape, reverse=reverse)

        # The payload and its quantization range must pass through untouched.
        assert qoutput.shape == expected_ret_shape
        assert same(qdata.asnumpy().flatten(), qoutput.asnumpy().flatten())
        assert same(min_data.asnumpy(), min_output.asnumpy())
        assert same(max_data.asnumpy(), max_output.asnumpy())

    for qdtype in ['int8', 'uint8']:
        for shape, newshape, reverse, expected_ret_shape in test_cases:
            check_quantized_reshape(shape, qdtype, newshape, reverse, expected_ret_shape)
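For reference, the special newshape codes these cases exercise follow npx.reshape semantics (hedged summary; see the npx.reshape documentation, and note the test cases above corroborate each rule):

    # -1  infer this output dimension from the remaining elements
    # -2  copy the corresponding dimension from the input
    # -3  skip the current input dimension (its size must be 1)
    # -4  copy all remaining input dimensions to the output
    # -5  merge two consecutive input dimensions into one
    # -6  split one input dimension into the two sizes that follow it
    # reverse=True applies the same rules right to left, as in the last case.
    import mxnet as mx
    x = mx.np.zeros((2, 3, 5, 6))
    assert mx.npx.reshape(x, newshape=(-5, -5)).shape == (6, 30)  # (2*3, 5*6)
    assert mx.npx.reshape(x, newshape=(6, 1, -4)).shape == (6, 1, 5, 6)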


def test_quantize_params():
    if is_test_for_native_cpu():
        print('skipped testing quantized_params for native cpu since it is not supported yet')
