Skip to content

Commit

Permalink
[oneDNN] Concat refactoring and disabling caching (#35002)
Browse files Browse the repository at this point in the history
* - concat refactoring draft

* - compilation fixes

* - yet another compilation fix

* - fix

* - compilation fix

* - fixes to compilation

* - another compilation fix

* - fix

* - Added overloaded AcquirePrimitiveDesc for concat

* - fix

* - reserve introduced

* - UT fixes

* - test concat int8 improved

* - fixes

* - fix to crash

* - lint fixes

* - fixes after review

* - some other fixes from review
  • Loading branch information
jczaja authored Aug 24, 2021
1 parent 3b0d8a7 commit d9c0f09
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 160 deletions.
233 changes: 73 additions & 160 deletions paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,69 @@ using mkldnn::concat;
using mkldnn::stream;
using platform::to_void_cast;

template <typename T>
class ConcatMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
 public:
  // Builds a non-cached oneDNN concat primitive descriptor for `inputs`.
  // If the optional "AxisTensor" input is present, its value overrides the
  // "axis" attribute and `output` is resized to the resulting concat shape.
  ConcatMKLDNNHandler(const framework::ExecutionContext& ctx,
                      const mkldnn::engine mkldnn_engine,
                      const std::vector<const Tensor*>& inputs, Tensor* output)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::concat>(mkldnn_engine,
                                                           ctx.GetPlace()) {
    int concat_axis = ctx.Attr<int>("axis");
    const int rank = inputs[0]->dims().size();
    PADDLE_ENFORCE_EQ(
        concat_axis >= -rank && concat_axis < rank, true,
        platform::errors::InvalidArgument(
            "The axis is expected to be in range of [%d, %d), but got %d",
            -rank, rank, concat_axis));

    if (ctx.HasInput("AxisTensor")) {
      auto* axis_tensor = ctx.Input<Tensor>("AxisTensor");
      concat_axis = GetDataFromTensor(axis_tensor)[0];
      // The runtime-provided axis bypassed the attribute check above, so it
      // must be re-validated here; it must also be normalized BEFORE it is
      // used to index out_dims, otherwise a negative axis would index the
      // DDim out of range.
      PADDLE_ENFORCE_EQ(
          concat_axis >= -rank && concat_axis < rank, true,
          platform::errors::InvalidArgument(
              "The axis is expected to be in range of [%d, %d), but got %d",
              -rank, rank, concat_axis));
      if (concat_axis < 0) concat_axis += rank;

      auto out_dims = inputs[0]->dims();
      for (size_t i = 1; i < inputs.size(); ++i) {
        out_dims[concat_axis] += inputs[i]->dims()[concat_axis];
      }
      output->Resize(out_dims);
    }

    // Normalize the attribute-provided axis (no-op if AxisTensor was used,
    // since it was normalized above).
    if (concat_axis < 0) {
      concat_axis = concat_axis + rank;
    }

    memory::data_type dt = framework::ToMKLDNNDataType(inputs[0]->type());
    std::vector<memory::desc> srcs_md;
    srcs_md.reserve(inputs.size());

    // Create memory descriptors for each of inputs
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto dims = framework::vectorize<int64_t>(inputs[i]->dims());
      srcs_md.emplace_back(memory::desc(dims, dt, inputs[i]->format()));
    }

    auto dst_dims = framework::vectorize<int64_t>(output->dims());
    // "any" lets oneDNN pick the most efficient destination layout.
    auto dst_md = memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);

    this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md);
  }

  // (jczaja) concat oneDNN prim is not having .desc attribute so
  // we cannot use base AcquireForwardPrimitiveDescriptor
  void AcquireForwardPrimitiveDescriptor(
      const memory::desc& dst_md, const int concat_axis,
      const std::vector<memory::desc>& srcs_md) {
    this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
        dst_md, concat_axis, srcs_md, this->engine_));
  }

  // Wraps the i-th input tensor's raw buffer in a oneDNN memory object that
  // matches the primitive descriptor's i-th source descriptor.
  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(const Tensor& input, int i) {
    const T* input_data = input.data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
                                            to_void_cast<T>(input_data));
  }
};

static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
for (auto* input : inputs) {
PADDLE_ENFORCE_EQ(
Expand All @@ -40,28 +103,6 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
}
}

// Builds a oneDNN memory descriptor from the tensor's dims, the requested
// data type, and the tensor's current memory format.
static memory::desc CreateMemDesc(const Tensor& input,
                                  const memory::data_type& dt) {
  return memory::desc(paddle::framework::vectorize<int64_t>(input.dims()), dt,
                      input.format());
}

// Extracts the execution place from the context, enforcing that it is a CPU
// place (oneDNN kernels run on CPU only).
static platform::CPUPlace GetCpuPlace(
    const paddle::framework::ExecutionContext& ctx) {
  const auto cpu_place = ctx.GetPlace();
  PADDLE_ENFORCE(paddle::platform::is_cpu_place(cpu_place),
                 platform::errors::InvalidArgument("It must use CPUPlace."));
  return BOOST_GET_CONST(platform::CPUPlace, cpu_place);
}

// Returns the oneDNN engine owned by the MKLDNN device context.
static const mkldnn::engine& GetMKLDNNEngine(
    const paddle::framework::ExecutionContext& ctx) {
  return ctx.template device_context<platform::MKLDNNDeviceContext>()
      .GetEngine();
}

// From a multi-input, gather only nonempty inputs
static const std::vector<const Tensor*> ReduceMultiInput(
const std::vector<const Tensor*>& inputs) {
Expand All @@ -72,160 +113,32 @@ static const std::vector<const Tensor*> ReduceMultiInput(
return reduced;
}

// Builds the dims portion of a cache key: the full dims of the first input
// followed by only the leading dim of each remaining input (keeps keys short).
static const std::vector<int> GetDimsForKey(
    const std::vector<const Tensor*>& inputs) {
  auto key_dims = paddle::framework::vectorize<int>(inputs[0]->dims());
  for (size_t i = 1; i < inputs.size(); ++i) {
    key_dims.push_back(inputs[i]->dims()[0]);
  }
  return key_dims;
}

// Legacy factory that creates and owns the oneDNN memories/descriptors needed
// to build (and later re-bind) a cached concat primitive.
template <typename T>
class ConcatPrimitiveFactory {
 public:
  // Creates the concat primitive descriptor; as a side effect fills the
  // member caches srcs_d / srcs for the given inputs.
  concat::primitive_desc CreateConcatPrimDescriptor(
      const std::vector<const Tensor*> multi_input, Tensor* output,
      int concat_axis, const mkldnn::engine& mkldnn_engine,
      const memory::data_type& dt = memory::data_type::f32) {
    CreateSourcesDescriptors(multi_input, mkldnn_engine, dt);
    auto dst_desc = CreateDstMemDescriptor(output, dt);
    return concat::primitive_desc(dst_desc, concat_axis, srcs_d, mkldnn_engine);
  }

  // Allocates the output buffer, stores it in dst_mem, and returns the
  // concat primitive built from the given descriptor.
  concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd,
                               Tensor* output, platform::CPUPlace place,
                               const mkldnn::engine& mkldnn_engine) {
    dst_mem = mkldnn::memory(
        concat_pd.dst_desc(), mkldnn_engine,
        output->mutable_data<T>(place, concat_pd.dst_desc().get_size()));

    return concat(concat_pd);
  }

  // Re-binds the i-th source memory in `srcs` to a new data pointer (used
  // when a cached primitive is reused with fresh input tensors).
  // NOTE(review): mutating through a const memory& compiles, presumably
  // because mkldnn::memory has shared-handle semantics — verify.
  void SetSrcDataHandleByIndex(const std::vector<memory>& srcs, const size_t& i,
                               void* handler) {
    srcs[i].set_data_handle(handler);
  }

  // Re-binds the destination memory to a new data pointer.
  void SetDstDataHandle(const memory& dst_mem, void* handler) {
    dst_mem.set_data_handle(handler);
  }

  // Returns a copy of the cached source memory objects.
  std::vector<memory> GetSrcs() { return srcs; }

  // NOTE(review): dst_mem is paddle::optional — calling .get() before
  // CreateConcatPrimitive has populated it looks unsafe; confirm callers
  // always create the primitive first.
  memory GetDst() { return dst_mem.get(); }

 private:
  memory::desc CreateDstMemDescriptor(Tensor* output,
                                      const memory::data_type& dt) {
    auto dst_dims = paddle::framework::vectorize<int64_t>(output->dims());
    // "any" lets oneDNN choose the destination layout.
    return memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
  }

  // Builds one memory descriptor and one memory object per input tensor,
  // caching them in srcs_d and srcs respectively.
  void CreateSourcesDescriptors(const std::vector<const Tensor*> multi_input,
                                const mkldnn::engine& mkldnn_engine,
                                const memory::data_type& dt) {
    for (size_t i = 0; i < multi_input.size(); i++) {
      auto mem_desc = CreateMemDesc(*multi_input[i], dt);
      srcs_d.push_back(mem_desc);
      srcs.push_back(memory(mem_desc, mkldnn_engine,
                            to_void_cast(multi_input[i]->data<T>())));
    }
  }

 private:
  std::vector<memory::desc> srcs_d;          // per-input memory descriptors
  std::vector<mkldnn::memory> srcs;          // per-input memory objects
  paddle::optional<mkldnn::memory> dst_mem;  // set by CreateConcatPrimitive
};

template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
// If any of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out");
int concat_axis = ctx.Attr<int>("axis");
const int rank = multi_input[0]->dims().size();
PADDLE_ENFORCE_EQ(
concat_axis >= -rank && concat_axis < rank, true,
platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank, rank, concat_axis));
platform::MKLDNNDeviceContext::tls().log_lib_version();

if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<Tensor>("AxisTensor");
concat_axis = GetDataFromTensor(axis_tensor)[0];
auto out_dims = multi_input[0]->dims();
for (size_t i = 1; i < multi_input.size(); ++i) {
out_dims[concat_axis] += multi_input[i]->dims()[concat_axis];
}
output->Resize(out_dims);
}
ConcatMKLDNNHandler<T> handler(ctx, mkldnn_engine, multi_input, output);

if (concat_axis < 0) {
concat_axis = concat_axis + rank;
}
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
auto place = GetCpuPlace(ctx);

memory::data_type dt =
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());

ConcatPrimitiveFactory<T> prim_creator;
std::string key =
platform::CreateKey(dev_ctx, GetDimsForKey(multi_input),
multi_input.size(), ctx.OutputName("Out"), dt);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
std::vector<std::shared_ptr<memory>> srcs;
srcs.reserve(multi_input.size());

const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd";
const std::string key_srcs = key + "@concat_srcs";
const std::string key_dst = key + "@concat_dst";

std::shared_ptr<concat::primitive_desc> concat_pd;
std::shared_ptr<std::vector<memory>> srcs;
std::shared_ptr<memory> dst_mem;
auto concat_p = std::static_pointer_cast<concat>(dev_ctx.GetBlob(key_prim));

const auto& mkldnn_engine = dev_ctx.GetEngine();
if (concat_p == nullptr) {
concat_pd = std::make_shared<concat::primitive_desc>(
prim_creator.CreateConcatPrimDescriptor(
multi_input, output, concat_axis, mkldnn_engine, dt));
concat_p = std::make_shared<concat>(prim_creator.CreateConcatPrimitive(
*concat_pd, output, place, mkldnn_engine));
srcs = std::make_shared<std::vector<memory>>(prim_creator.GetSrcs());
dst_mem = std::make_shared<memory>(prim_creator.GetDst());
dev_ctx.SetBlob(key_prim, concat_p);
dev_ctx.SetBlob(key_concat_pd, concat_pd);
dev_ctx.SetBlob(key_srcs, srcs);
dev_ctx.SetBlob(key_dst, dst_mem);
} else {
srcs = std::static_pointer_cast<std::vector<memory>>(
dev_ctx.GetBlob(key_srcs));
dst_mem = std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_dst));
concat_pd = std::static_pointer_cast<concat::primitive_desc>(
dev_ctx.GetBlob(key_concat_pd));
for (size_t i = 0; i < multi_input.size(); i++) {
prim_creator.SetSrcDataHandleByIndex(
*srcs, i, to_void_cast<T>(multi_input[i]->data<T>()));
}
prim_creator.SetDstDataHandle(
*dst_mem,
output->mutable_data<T>(place, concat_pd->dst_desc().get_size()));
}
auto dst_mem = handler.AcquireDstMemory(output);
auto concat_p = handler.AcquireForwardPrimitive();

auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
std::unordered_map<int, memory> args;
for (size_t i = 0; i < multi_input.size(); ++i) {
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, (*srcs).at(i)});
srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs.at(i))});
}
args.insert({MKLDNN_ARG_DST, *dst_mem});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,6 @@ def init_shape(self):
create_test_int8_class(TestConcatOp2)

if __name__ == '__main__':
    # Run the suite in static-graph mode.
    from paddle import enable_static
    enable_static()
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,6 @@ def init_kernel_type(self):


if __name__ == '__main__':
    # Run the suite in static-graph mode.
    from paddle import enable_static
    enable_static()
    unittest.main()

0 comments on commit d9c0f09

Please sign in to comment.