[MKL-DNN] Enable and optimize s8 eltwise_add (#16931)
* optimization for s8 sum

* fix lint

* fix lint

* exclude sum in lstm cell

* remove debug info

* remove todo
xinyu-intel authored and pengzhao-intel committed Dec 9, 2019
1 parent ce97e22 commit 538b18b
Showing 4 changed files with 62 additions and 7 deletions.
5 changes: 5 additions & 0 deletions src/operator/operator_common.h
@@ -607,6 +607,11 @@ class OpSignature {
     eles.push_back(val);
   }
 
+  void AddSign(float val) {
+    hash = dmlc::HashCombine(hash, val);
+    eles.push_back(val);
+  }
+
   bool operator==(const OpSignature &sign) const {
     if (hash != sign.hash)
       return false;
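The new overload above folds a float (here, a quantization min/max threshold) into the signature hash that keys the MKL-DNN primitive cache used in the next file. A minimal standalone sketch of the idea, with a hand-rolled combiner standing in for dmlc::HashCombine and a simplified Signature in place of OpSignature (both stand-ins are illustrative, not the dmlc/MXNet implementations):

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <vector>

    // Illustrative stand-in for dmlc::HashCombine: fold a value's hash into a seed.
    template <typename T>
    std::size_t HashCombine(std::size_t seed, const T &val) {
      return seed ^ (std::hash<T>()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    // Simplified signature keeping both the running hash and the raw values,
    // mirroring how OpSignature keeps `hash` and `eles` for exact equality.
    struct Signature {
      std::size_t hash = 0;
      std::vector<float> eles;

      void AddSign(float val) {
        hash = HashCombine(hash, val);
        eles.push_back(val);
      }

      bool operator==(const Signature &other) const {
        return hash == other.hash && eles == other.eles;
      }
    };

    int main() {
      Signature a, b;
      a.AddSign(-1.0f);  a.AddSign(1.0f);   // e.g. data_min, data_max
      b.AddSign(-1.0f);  b.AddSign(0.5f);   // different max -> different cache key
      std::cout << (a == b ? "same key" : "different key") << "\n";
    }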
58 changes: 55 additions & 3 deletions src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
@@ -39,6 +39,57 @@ static inline float GetScale(const NDArray& data, float min, float max) {
   return data_range / MaxAbs(min, max);
 }
 
+class MKLDNNQuantizedElemwiseAddFwd {
+ public:
+  mkldnn::sum::primitive_desc fwd_pd;
+
+  MKLDNNQuantizedElemwiseAddFwd(
+      const mkldnn::memory::desc &output_desc,
+      const std::vector<float> &scales,
+      const std::vector<mkldnn::memory::desc> &data_md)
+      : fwd_pd(output_desc, scales, data_md, CpuEngine::Get()->get_engine()) {
+    fwd_ = std::make_shared<mkldnn::sum>(fwd_pd);
+    data_.resize(data_md.size());
+  }
+
+  const mkldnn::sum &GetFwd() const { return *fwd_; }
+
+ private:
+  std::shared_ptr<mkldnn::sum> fwd_;
+  std::vector<std::shared_ptr<mkldnn::memory>> data_;
+  std::shared_ptr<mkldnn::memory> out_;
+};
+
+static MKLDNNQuantizedElemwiseAddFwd &GetQuantizedElemwiseAddForward(
+    const mkldnn::memory::desc &output_desc, const std::vector<float> &scales,
+    const std::vector<NDArray> &in_data, const std::vector<NDArray> &out_data,
+    const std::vector<mkldnn::memory::desc> &data_md) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature,
+                                         MKLDNNQuantizedElemwiseAddFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature,
+                                            MKLDNNQuantizedElemwiseAddFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(in_data);
+  key.AddSign(in_data[quantized_elemwise_add_enum::kAMin].data().dptr<float>()[0]);
+  key.AddSign(in_data[quantized_elemwise_add_enum::kAMax].data().dptr<float>()[0]);
+  key.AddSign(in_data[quantized_elemwise_add_enum::kBMin].data().dptr<float>()[0]);
+  key.AddSign(in_data[quantized_elemwise_add_enum::kBMax].data().dptr<float>()[0]);
+  key.AddSign(out_data);
+  key.AddSign(out_data[quantized_elemwise_add_enum::kMin].data().dptr<float>()[0]);
+  key.AddSign(out_data[quantized_elemwise_add_enum::kMax].data().dptr<float>()[0]);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNQuantizedElemwiseAddFwd fwd(output_desc, scales, data_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+
 static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                                               const std::vector<NDArray>& in_data,
                                               const std::vector<OpReqType>& req,
@@ -166,16 +217,17 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons
   auto output_desc = mkldnn::memory::desc(i_dims,
                                           output_data_type,
                                           mkldnn::memory::format_tag::any);
-  mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_desc, engine);
+  MKLDNNQuantizedElemwiseAddFwd &fwd = GetQuantizedElemwiseAddForward(output_desc, scales,
+                                                                      in_data, out_data, in_desc);
   auto mem = CreateMKLDNNMem(out_data[quantized_elemwise_add_enum::kOut],
-                             pdesc.dst_desc(),
+                             fwd.fwd_pd.dst_desc(),
                              req[0],
                              &in_data[0]);
   mkldnn_args_map_t args({{MKLDNN_ARG_MULTIPLE_SRC, *dataA_mem},
                           {MKLDNN_ARG_MULTIPLE_SRC + 1, *dataB_mem},
                           {MKLDNN_ARG_DST, *mem.second}});
   MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrimArgs(mkldnn::sum(pdesc), args);
+  stream->RegisterPrimArgs(fwd.GetFwd(), args);
   CommitOutput(out_data[quantized_elemwise_add_enum::kOut], mem);
   stream->Submit();
 
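This rewrite is the core of the optimization: instead of constructing a mkldnn::sum primitive on every forward call, the primitive is built once and reused from a thread-local cache keyed by the OpSignature (input/output shapes plus the min/max thresholds hashed via the new AddSign(float)). A condensed, compilable sketch of that caching pattern, with placeholder Key/Fwd types standing in for OpSignature and the primitive wrapper:

    #include <cstddef>
    #include <unordered_map>

    // Placeholder key and value types; the real code uses OpSignature/OpHash and
    // a class that owns the mkldnn::sum primitive plus its primitive_desc.
    struct Key {
      std::size_t hash;
      bool operator==(const Key &o) const { return hash == o.hash; }
    };
    struct KeyHash {
      std::size_t operator()(const Key &k) const { return k.hash; }
    };
    struct Fwd { /* would own the mkldnn::sum primitive */ };

    Fwd &GetCachedFwd(const Key &key) {
      // One cache per thread: no locking on the hot path, and a primitive is
      // never shared across threads.
      static thread_local std::unordered_map<Key, Fwd, KeyHash> fwds;
      auto it = fwds.find(key);
      if (it == fwds.end()) {
        // First call with this signature: pay the construction cost once.
        it = fwds.emplace(key, Fwd{}).first;
      }
      return it->second;
    }

    int main() {
      Fwd &f1 = GetCachedFwd(Key{42});
      Fwd &f2 = GetCachedFwd(Key{42});
      return (&f1 == &f2) ? 0 : 1;  // second lookup returns the cached instance
    }

Because the min/max thresholds feed the key, a change in calibration ranges produces a fresh primitive rather than reusing a stale one with the wrong scales.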
4 changes: 0 additions & 4 deletions src/operator/quantization/quantized_elemwise_add.cc
@@ -125,9 +125,6 @@ and max thresholds representing the thresholds for quantizing the float32 output
 .add_argument("rhs_max", "NDArray-or-Symbol", "6th input");
 
 
-// TODO(zhangrong): need extra condition check if there's benefited if it's switched on
-// Since it's not compute-intensive.
-#if 0
 NNVM_REGISTER_OP(elemwise_add)
 .set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
   nnvm::NodePtr node = nnvm::Node::Create();
@@ -139,7 +136,6 @@ NNVM_REGISTER_OP(elemwise_add)
   }
   return node;
 });
-#endif
 
 }  // namespace op
 }  // namespace mxnet
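Removing the #if 0 / #endif guard (and the stale TODO) switches this FQuantizedOp registration on, so the quantization pass can now substitute an int8 elemwise_add for the float32 one. A hedged sketch of what such a mapping lambda does; the actual body sits in the lines elided above, and the op name here follows the usual MXNet _contrib_quantized_ convention but is an assumption:

    #include <nnvm/node.h>
    #include <nnvm/op.h>

    // Sketch only: given the FP32 node's attrs, build the node for the quantized op.
    nnvm::NodePtr QuantizedElemwiseAddNode(const nnvm::NodeAttrs &attrs) {
      nnvm::NodePtr node = nnvm::Node::Create();
      // "_contrib_quantized_elemwise_add" is assumed to be the registered int8 op;
      // the quantization pass wires up its six inputs (lhs, rhs, and their min/max).
      node->attrs.op = nnvm::Op::Get("_contrib_quantized_elemwise_add");
      node->attrs.name = "quantized_" + attrs.name;
      return node;
    }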
2 changes: 2 additions & 0 deletions tests/python/quantization/test_quantization.py
@@ -958,6 +958,8 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N
             excluded_sym_names = excluded_names
         else:
             excluded_sym_names = excluded_names + optional_names
+        if name == 'sym4':
+            excluded_op_names += ['elemwise_add']
 
         qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
                                                                          arg_params=arg_params,
