diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
index 8901c0afb369a..57a56776736ff 100644
--- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
@@ -29,6 +29,69 @@ using mkldnn::concat;
 using mkldnn::stream;
 using platform::to_void_cast;
 
+template <typename T>
+class ConcatMKLDNNHandler
+    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
+ public:
+  ConcatMKLDNNHandler(const framework::ExecutionContext& ctx,
+                      const mkldnn::engine mkldnn_engine,
+                      const std::vector<const Tensor*>& inputs, Tensor* output)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::concat>(mkldnn_engine,
+                                                           ctx.GetPlace()) {
+    int concat_axis = ctx.Attr<int>("axis");
+    const int rank = inputs[0]->dims().size();
+    PADDLE_ENFORCE_EQ(
+        concat_axis >= -rank && concat_axis < rank, true,
+        platform::errors::InvalidArgument(
+            "The axis is expected to be in range of [%d, %d), but got %d",
+            -rank, rank, concat_axis));
+
+    if (ctx.HasInput("AxisTensor")) {
+      auto* axis_tensor = ctx.Input<Tensor>("AxisTensor");
+      concat_axis = GetDataFromTensor<int>(axis_tensor)[0];
+      auto out_dims = inputs[0]->dims();
+      for (size_t i = 1; i < inputs.size(); ++i) {
+        out_dims[concat_axis] += inputs[i]->dims()[concat_axis];
+      }
+      output->Resize(out_dims);
+    }
+
+    if (concat_axis < 0) {
+      concat_axis = concat_axis + rank;
+    }
+
+    memory::data_type dt = framework::ToMKLDNNDataType(inputs[0]->type());
+    std::vector<memory::desc> srcs_md;
+    srcs_md.reserve(inputs.size());
+
+    // Create memory descriptors for each of the inputs
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      const auto dims = framework::vectorize(inputs[i]->dims());
+      srcs_md.emplace_back(memory::desc(dims, dt, inputs[i]->format()));
+    }
+
+    auto dst_dims = framework::vectorize(output->dims());
+    auto dst_md = memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
+
+    this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md);
+  }
+
+  // (jczaja) concat oneDNN prim does not have a .desc attribute, so
+  // we cannot use the base AcquireForwardPrimitiveDescriptor
+  void AcquireForwardPrimitiveDescriptor(
+      const memory::desc& dst_md, const int concat_axis,
+      const std::vector<memory::desc>& srcs_md) {
+    this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
+        dst_md, concat_axis, srcs_md, this->engine_));
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(const Tensor& input, int i) {
+    const T* input_data = input.data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
+                                            to_void_cast<T>(input_data));
+  }
+};
+
 static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
   for (auto* input : inputs) {
     PADDLE_ENFORCE_EQ(
@@ -40,28 +103,6 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
   }
 }
 
-static memory::desc CreateMemDesc(const Tensor& input,
-                                  const memory::data_type& dt) {
-  const auto dims = paddle::framework::vectorize(input.dims());
-  const auto format = input.format();
-  auto mem_desc = memory::desc(dims, dt, format);
-  return mem_desc;
-}
-
-static platform::CPUPlace GetCpuPlace(
-    const paddle::framework::ExecutionContext& ctx) {
-  auto place = ctx.GetPlace();
-  PADDLE_ENFORCE(paddle::platform::is_cpu_place(place),
-                 platform::errors::InvalidArgument("It must use CPUPlace."));
-  return BOOST_GET_CONST(platform::CPUPlace, place);
-}
-
-static const mkldnn::engine& GetMKLDNNEngine(
-    const paddle::framework::ExecutionContext& ctx) {
-  auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
-  return dev_ctx.GetEngine();
-}
-
 // From a multi-input, gather only nonempty inputs
 static const std::vector<const Tensor*> ReduceMultiInput(
     const std::vector<const Tensor*>& inputs) {
@@ -72,160 +113,32 @@ static const std::vector<const Tensor*> ReduceMultiInput(
   return reduced;
 }
 
-static const std::vector<int64_t> GetDimsForKey(
-    const std::vector<const Tensor*>& inputs) {
-  auto dims_key = paddle::framework::vectorize(inputs[0]->dims());
-  for (auto it = std::next(inputs.begin()); it != inputs.end(); ++it) {
-    dims_key.push_back((*it)->dims()[0]);
-  }
-  return dims_key;
-}
-
-template <typename T>
-class ConcatPrimitiveFactory {
- public:
-  concat::primitive_desc CreateConcatPrimDescriptor(
-      const std::vector<const Tensor*> multi_input, Tensor* output,
-      int concat_axis, const mkldnn::engine& mkldnn_engine,
-      const memory::data_type& dt = memory::data_type::f32) {
-    CreateSourcesDescriptors(multi_input, mkldnn_engine, dt);
-    auto dst_desc = CreateDstMemDescriptor(output, dt);
-    return concat::primitive_desc(dst_desc, concat_axis, srcs_d, mkldnn_engine);
-  }
-
-  concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd,
-                               Tensor* output, platform::CPUPlace place,
-                               const mkldnn::engine& mkldnn_engine) {
-    dst_mem = mkldnn::memory(
-        concat_pd.dst_desc(), mkldnn_engine,
-        output->mutable_data<T>(place, concat_pd.dst_desc().get_size()));
-
-    return concat(concat_pd);
-  }
-
-  void SetSrcDataHandleByIndex(const std::vector<memory>& srcs, const size_t& i,
-                               void* handler) {
-    srcs[i].set_data_handle(handler);
-  }
-
-  void SetDstDataHandle(const memory& dst_mem, void* handler) {
-    dst_mem.set_data_handle(handler);
-  }
-
-  std::vector<memory> GetSrcs() { return srcs; }
-
-  memory GetDst() { return dst_mem.get(); }
-
- private:
-  memory::desc CreateDstMemDescriptor(Tensor* output,
-                                      const memory::data_type& dt) {
-    auto dst_dims = paddle::framework::vectorize(output->dims());
-    return memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
-  }
-
-  void CreateSourcesDescriptors(const std::vector<const Tensor*> multi_input,
-                                const mkldnn::engine& mkldnn_engine,
-                                const memory::data_type& dt) {
-    for (size_t i = 0; i < multi_input.size(); i++) {
-      auto mem_desc = CreateMemDesc(*multi_input[i], dt);
-      srcs_d.push_back(mem_desc);
-      srcs.push_back(memory(mem_desc, mkldnn_engine,
-                            to_void_cast<T>(multi_input[i]->data<T>())));
-    }
-  }
-
- private:
-  std::vector<memory::desc> srcs_d;
-  std::vector<memory> srcs;
-  paddle::optional<memory> dst_mem;
-};
-
 template <typename T>
 class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
     // If any of the multiple inputs of concat has an input size of 0, the
     // actual size of the multi_input will change
     auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
     EnforceLayouts(multi_input);
     Tensor* output = ctx.Output<Tensor>("Out");
-    int concat_axis = ctx.Attr<int>("axis");
-    const int rank = multi_input[0]->dims().size();
-    PADDLE_ENFORCE_EQ(
-        concat_axis >= -rank && concat_axis < rank, true,
-        platform::errors::InvalidArgument(
-            "The axis is expected to be in range of [%d, %d), but got %d",
-            -rank, rank, concat_axis));
-
     platform::MKLDNNDeviceContext::tls().log_lib_version();
-    if (ctx.HasInput("AxisTensor")) {
-      auto* axis_tensor = ctx.Input<Tensor>("AxisTensor");
-      concat_axis = GetDataFromTensor<int>(axis_tensor)[0];
-      auto out_dims = multi_input[0]->dims();
-      for (size_t i = 1; i < multi_input.size(); ++i) {
-        out_dims[concat_axis] += multi_input[i]->dims()[concat_axis];
-      }
-      output->Resize(out_dims);
-    }
+    ConcatMKLDNNHandler<T> handler(ctx, mkldnn_engine, multi_input, output);
 
-    if (concat_axis < 0) {
-      concat_axis = concat_axis + rank;
-    }
-    auto& dev_ctx =
-        ctx.template device_context<platform::MKLDNNDeviceContext>();
-    auto place = GetCpuPlace(ctx);
-
-    memory::data_type dt =
-        paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
-
-    ConcatPrimitiveFactory<T> prim_creator;
-    std::string key =
-        platform::CreateKey(dev_ctx, GetDimsForKey(multi_input),
-                            multi_input.size(), ctx.OutputName("Out"), dt);
-    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
+    std::vector<std::shared_ptr<mkldnn::memory>> srcs;
+    srcs.reserve(multi_input.size());
 
-    const std::string key_prim = key + "@concat_p";
-    const std::string key_concat_pd = key + "@concat_pd";
-    const std::string key_srcs = key + "@concat_srcs";
-    const std::string key_dst = key + "@concat_dst";
-
-    std::shared_ptr<concat::primitive_desc> concat_pd;
-    std::shared_ptr<std::vector<memory>> srcs;
-    std::shared_ptr<memory> dst_mem;
-    auto concat_p = std::static_pointer_cast<concat>(dev_ctx.GetBlob(key_prim));
-
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-    if (concat_p == nullptr) {
-      concat_pd = std::make_shared<concat::primitive_desc>(
-          prim_creator.CreateConcatPrimDescriptor(
-              multi_input, output, concat_axis, mkldnn_engine, dt));
-      concat_p = std::make_shared<concat>(prim_creator.CreateConcatPrimitive(
-          *concat_pd, output, place, mkldnn_engine));
-      srcs = std::make_shared<std::vector<memory>>(prim_creator.GetSrcs());
-      dst_mem = std::make_shared<memory>(prim_creator.GetDst());
-      dev_ctx.SetBlob(key_prim, concat_p);
-      dev_ctx.SetBlob(key_concat_pd, concat_pd);
-      dev_ctx.SetBlob(key_srcs, srcs);
-      dev_ctx.SetBlob(key_dst, dst_mem);
-    } else {
-      srcs = std::static_pointer_cast<std::vector<memory>>(
-          dev_ctx.GetBlob(key_srcs));
-      dst_mem = std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_dst));
-      concat_pd = std::static_pointer_cast<concat::primitive_desc>(
-          dev_ctx.GetBlob(key_concat_pd));
-      for (size_t i = 0; i < multi_input.size(); i++) {
-        prim_creator.SetSrcDataHandleByIndex(
-            *srcs, i, to_void_cast<T>(multi_input[i]->data<T>()));
-      }
-      prim_creator.SetDstDataHandle(
-          *dst_mem,
-          output->mutable_data<T>(place, concat_pd->dst_desc().get_size()));
-    }
+    auto dst_mem = handler.AcquireDstMemory(output);
+    auto concat_p = handler.AcquireForwardPrimitive();
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
 
     std::unordered_map<int, memory> args;
     for (size_t i = 0; i < multi_input.size(); ++i) {
-      args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, (*srcs).at(i)});
+      srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
+      args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs.at(i))});
     }
     args.insert({MKLDNN_ARG_DST, *dst_mem});
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_concat_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_int8_mkldnn_op.py
index ca15ea2aaf975..ef2fa1c1cc268 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_concat_int8_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_int8_mkldnn_op.py
@@ -122,4 +122,6 @@ def init_shape(self):
 create_test_int8_class(TestConcatOp2)
 
 if __name__ == '__main__':
+    from paddle import enable_static
+    enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py
index 4f3dece5be342..4900b42d3618d 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py
@@ -87,4 +87,6 @@ def init_kernel_type(self):
 
 
 if __name__ == '__main__':
+    from paddle import enable_static
+    enable_static()
     unittest.main()