diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index c23a5a852dcb..f6af58bce995 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -607,6 +607,11 @@ class OpSignature {
     eles.push_back(val);
   }
 
+  void AddSign(float val) {
+    hash = dmlc::HashCombine(hash, val);
+    eles.push_back(val);
+  }
+
   bool operator==(const OpSignature &sign) const {
     if (hash != sign.hash)
       return false;
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
index 2078ac4fead8..06a0ea37f95b 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
@@ -39,6 +39,57 @@ static inline float GetScale(const NDArray& data, float min, float max) {
   return data_range / MaxAbs(min, max);
 }
 
+class MKLDNNQuantizedElemwiseAddFwd {
+ public:
+  mkldnn::sum::primitive_desc fwd_pd;
+
+  MKLDNNQuantizedElemwiseAddFwd(
+      const mkldnn::memory::desc &output_desc,
+      const std::vector<float> &scales,
+      const std::vector<mkldnn::memory::desc> &data_md)
+      : fwd_pd(output_desc, scales, data_md, CpuEngine::Get()->get_engine()) {
+    fwd_ = std::make_shared<mkldnn::sum>(fwd_pd);
+    data_.resize(data_md.size());
+  }
+
+  const mkldnn::sum &GetFwd() const { return *fwd_; }
+
+ private:
+  std::shared_ptr<mkldnn::sum> fwd_;
+  std::vector<std::shared_ptr<mkldnn::memory>> data_;
+  std::shared_ptr<mkldnn::memory> out_;
+};
+
+static MKLDNNQuantizedElemwiseAddFwd &GetQuantizedElemwiseAddForward(
+    const mkldnn::memory::desc &output_desc, const std::vector<float> &scales,
+    const std::vector<NDArray> &in_data, const std::vector<NDArray> &out_data,
+    const std::vector<mkldnn::memory::desc> &data_md) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature,
+                                         MKLDNNQuantizedElemwiseAddFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature,
+                                            MKLDNNQuantizedElemwiseAddFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(in_data);
+  key.AddSign(in_data[quantized_elemwise_add_enum::kAMin].data().dptr<float>()[0]);
+  key.AddSign(in_data[quantized_elemwise_add_enum::kAMax].data().dptr<float>()[0]);
+  key.AddSign(in_data[quantized_elemwise_add_enum::kBMin].data().dptr<float>()[0]);
+  key.AddSign(in_data[quantized_elemwise_add_enum::kBMax].data().dptr<float>()[0]);
+  key.AddSign(out_data);
+  key.AddSign(out_data[quantized_elemwise_add_enum::kMin].data().dptr<float>()[0]);
+  key.AddSign(out_data[quantized_elemwise_add_enum::kMax].data().dptr<float>()[0]);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNQuantizedElemwiseAddFwd fwd(output_desc, scales, data_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+
 static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                                               const std::vector<NDArray>& in_data,
                                               const std::vector<OpReqType>& req,
@@ -166,16 +217,17 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons
   auto output_desc = mkldnn::memory::desc(i_dims,
                                           output_data_type,
                                           mkldnn::memory::format_tag::any);
-  mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_desc, engine);
+  MKLDNNQuantizedElemwiseAddFwd &fwd = GetQuantizedElemwiseAddForward(output_desc, scales,
+                                                                      in_data, out_data, in_desc);
   auto mem = CreateMKLDNNMem(out_data[quantized_elemwise_add_enum::kOut],
-                             pdesc.dst_desc(),
+                             fwd.fwd_pd.dst_desc(),
                              req[0],
                              &in_data[0]);
   mkldnn_args_map_t args({{MKLDNN_ARG_MULTIPLE_SRC, *dataA_mem},
                           {MKLDNN_ARG_MULTIPLE_SRC + 1, *dataB_mem},
                           {MKLDNN_ARG_DST, *mem.second}});
   MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrimArgs(mkldnn::sum(pdesc), args);
+  stream->RegisterPrimArgs(fwd.GetFwd(), args);
   CommitOutput(out_data[quantized_elemwise_add_enum::kOut], mem);
   stream->Submit();
diff --git a/src/operator/quantization/quantized_elemwise_add.cc b/src/operator/quantization/quantized_elemwise_add.cc
index 0e7034e88b8c..f821e6598192 100644
--- a/src/operator/quantization/quantized_elemwise_add.cc
+++ b/src/operator/quantization/quantized_elemwise_add.cc
@@ -125,9 +125,6 @@ and max thresholds representing the threholds for quantizing the float32 output
 .add_argument("rhs_max", "NDArray-or-Symbol", "6th input");
 
-// TODO(zhangrong): need extra condition check if there's benefited if it's switched on
-// Since it's not compute-intensive.
-#if 0
 NNVM_REGISTER_OP(elemwise_add)
 .set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
     nnvm::NodePtr node = nnvm::Node::Create();
@@ -139,7 +136,6 @@ NNVM_REGISTER_OP(elemwise_add)
     }
     return node;
   });
-#endif
 
 } // namespace op
 } // namespace mxnet
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 6fe33f5ee52b..a371abddd22e 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -958,6 +958,8 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N
             excluded_sym_names = excluded_names
         else:
             excluded_sym_names = excluded_names + optional_names
+        if name == 'sym4':
+            excluded_op_names += ['elemwise_add']
         qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
                                                                          arg_params=arg_params,
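
Note on the caching pattern above: GetQuantizedElemwiseAddForward() memoizes constructed mkldnn::sum primitives in a thread-local map keyed by an OpSignature, and the new AddSign(float) overload lets the input/output min/max scalars participate in that key. That matters because the sum scales are derived from those scalars, so a primitive built for one quantization range must not be reused for another. Below is a minimal, self-contained sketch of the same pattern under simplified assumptions; Signature, SignatureHash, SumPrimitive, and GetCachedPrimitive are hypothetical stand-ins rather than MXNet or oneDNN API, and a boost-style hash mix stands in for dmlc::HashCombine.

// Illustrative sketch only: a thread-local primitive cache keyed by a
// hash-combined signature, mirroring the patch's approach. All names here
// are hypothetical stand-ins for the MXNet types used in the diff.
#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

// Stand-in for an expensive-to-construct primitive (mkldnn::sum in the patch).
struct SumPrimitive {
  std::vector<float> scales;
  explicit SumPrimitive(std::vector<float> s) : scales(std::move(s)) {
    std::cout << "constructing primitive\n";  // runs once per unique key
  }
};

// Stand-in for OpSignature: folds each value into a running hash and also
// keeps the raw elements so operator== can reject hash collisions.
struct Signature {
  std::size_t hash = 0;
  std::vector<float> eles;

  void AddSign(float val) {
    // Boost-style hash mix; the real code uses dmlc::HashCombine.
    hash ^= std::hash<float>{}(val) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
    eles.push_back(val);
  }
  bool operator==(const Signature &other) const {
    return hash == other.hash && eles == other.eles;
  }
};

struct SignatureHash {
  std::size_t operator()(const Signature &s) const { return s.hash; }
};

// Analogous to GetQuantizedElemwiseAddForward(): look up by signature,
// construct and cache on a miss, return the cached primitive on a hit.
SumPrimitive &GetCachedPrimitive(const std::vector<float> &scales) {
  static thread_local std::unordered_map<Signature, SumPrimitive, SignatureHash> cache;
  Signature key;
  for (float s : scales) key.AddSign(s);
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, SumPrimitive(scales)).first;
  return it->second;
}

int main() {
  GetCachedPrimitive({0.5f, 2.0f});   // miss: constructs
  GetCachedPrimitive({0.5f, 2.0f});   // hit: reuses, prints nothing
  GetCachedPrimitive({0.25f, 4.0f});  // different scales, hence a new primitive
  return 0;
}

The same reasoning explains why the key in the diff hashes the min/max tensor contents and not just the array descriptors: two calls with identical shapes but different quantization ranges yield different scales, so they must resolve to different cached primitives.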