Skip to content

Commit

Permalink
[OneDNN] fix fast_rcnn bug (PaddlePaddle#57598)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglirong1999 authored and jiahy0825 committed Oct 16, 2023
1 parent 36a7c2d commit dc07a55
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions paddle/phi/kernels/funcs/data_layout_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ void TransDataLayoutFromOneDNN(DataLayout in_layout,
auto& pool = DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<OneDNNContext*>(pool.Get(place));
auto& cpu_engine = dev_ctx->GetEngine();
auto in_dims = vectorize<int64_t>(in.dims());

auto md_dims = !in_dims.empty() ? in_dims : std::vector<int64_t>{1};
const auto src_mem_desc =
!in_dims.empty() ? in.mem_desc()
: dnnl::memory::desc(md_dims,
ToOneDNNDataType(in.dtype()),
dnnl::memory::format_tag::x);

dnnl::memory::desc out_mem_desc = make_memory_desc(in, out_layout);

Expand All @@ -100,8 +108,7 @@ void TransDataLayoutFromOneDNN(DataLayout in_layout,

ReorderOneDNNHandler handler(in_tz, in.dtype(), in_type, cpu_engine);

auto reorder_src_memory_p =
handler.AcquireSrcMemory(in.mem_desc(), in_data);
auto reorder_src_memory_p = handler.AcquireSrcMemory(src_mem_desc, in_data);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(out, out->mem_desc(), place);
auto reorder_p =
Expand Down

0 comments on commit dc07a55

Please sign in to comment.