Commit
[cherrypick] Fix elementwise_sub sign reverse for mkldnn (#46107)
* fix sub sign reverse for mkldnn

* refactor code as comment

* remove useless
zh794390558 authored Sep 26, 2022
1 parent 0ab2672 commit 6990edf
Showing 3 changed files with 42 additions and 8 deletions.
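
Background for the fix: oneDNN's binary primitive is optimized for broadcasting the second operand into the first, so the gradient kernel swaps x and y whenever x->numel() < y->numel(). Addition commutes, but subtraction does not — x - y = -(y - x) — so every value computed from the swapped operands must have its sign flipped. A minimal standalone sketch of that compensation (plain C++, illustrative only, no oneDNN involved):

    #include <cassert>

    int main() {
      float x = 2.0f, y = 5.0f;

      // What a kernel sees after its operands are swapped:
      float swapped = y - x;                // 3.0f, sign-reversed

      // The compensation this commit applies for binary_sub:
      float compensated = -1.0f * swapped;  // -3.0f

      assert(compensated == x - y);
      return 0;
    }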
27 changes: 19 additions & 8 deletions paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
@@ -149,12 +149,20 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    VLOG(4) << "element sub: dx " << dx << " dy " << dy << " dout " << dout;
 
     // oneDNN's binary is optimized for broadcasting y into x, so in other case
     // we have to swap tensors to achieve optimal performance
+    bool swap_x_y = false;
     if (x->numel() < y->numel()) {
       std::swap(x, y);
       std::swap(dx, dy);
+      swap_x_y = true;
     }
+
+    std::vector<float> scales{1.0};
+    if (swap_x_y) {
+      scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1;
+    }
 
     int axis = ctx.Attr<int>("axis");
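
For out = x - y the true gradients are dx = dout and dy = -dout. Because the kernel swaps the dx and dy pointers together with x and y, anything it writes through those pointers lands in the other input's gradient buffer, which is why a swapped binary_sub needs scales[0] = -1 while binary_add and the unswapped case stay at 1. A hedged restatement of that rule (grad_sign_scale is a hypothetical helper, not kernel code):

    #include <oneapi/dnnl/dnnl.hpp>  // assumes oneDNN 2.x headers

    // Only a swapped subtraction needs the sign flip; addition commutes.
    float grad_sign_scale(bool swap_x_y, dnnl::algorithm op) {
      if (swap_x_y && op == dnnl::algorithm::binary_sub) return -1.0f;
      return 1.0f;
    }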
@@ -172,7 +180,6 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
         dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-
     if (dx) {
       std::shared_ptr<dnnl::memory> dst_memory;
 
@@ -181,15 +188,19 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
           BINARY_OP == dnnl::algorithm::binary_sub) {
         dst_memory = reorder_handler.AcquireDstMemory(
             dx, dout->mem_desc(), ctx.GetPlace());
-        auto reorder_p =
-            reorder_handler.AcquireReorder(dst_memory, reorder_src_memory_p);
+
+        dnnl::primitive_attr reorder_attr;
+        reorder_attr.set_output_scales(0, scales);
+        auto reorder_p = reorder_handler.AcquireReorder(
+            dst_memory, reorder_src_memory_p, reorder_attr);
         platform::RecordEvent record_reorder(
             "int_reorder",
             platform::TracerEventType::UserDefined,
             2,
             platform::EventRole::kUniqueOp);
 
         reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory);
+
       } else {  // elementwise_mul & elementwise_div
         platform::BinaryMKLDNNHandler<T> binary_handler(BINARY_OP,
                                                         axis,
@@ -233,11 +244,10 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
             dy, dout->mem_desc(), ctx.GetPlace());
 
         dnnl::primitive_attr reorder_attr;
-        std::vector<float> scales(1);
-        scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1;
         reorder_attr.set_output_scales(0, scales);
-        auto reorder_p = std::make_shared<dnnl::reorder>(
-            *(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr);
+
+        auto reorder_p = reorder_handler.AcquireReorder(
+            reorder_dst_memory_p, reorder_src_memory_p, reorder_attr);
         platform::RecordEvent record_reorder(
             "int_reorder",
             platform::TracerEventType::UserDefined,
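
The dy path above is also the "refactor code as comment" part of the commit message: the hand-built std::make_shared<dnnl::reorder>(...) is replaced by reorder_handler.AcquireReorder(...), matching the dx path, and the locally hard-coded scales vector gives way to the swap-aware one computed up front — presumably so a single place owns the sign logic.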
@@ -331,7 +341,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
       // Broadcasting
       if (BINARY_OP == dnnl::algorithm::binary_sub) {
         dnnl::post_ops po;
-        po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
+        po.append_eltwise(
+            1.0f, dnnl::algorithm::eltwise_linear, scales[0], 0);
         broadcast_reduction_attr.set_post_ops(po);
       }
 
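Two oneDNN details make the last hunk work. dnnl::algorithm::eltwise_linear computes alpha * x + beta, so appending it with alpha = scales[0] (and beta = 0) to the reduction's post-ops multiplies the broadcast-reduced gradient by the same swap-aware sign instead of the previous unconditional -1.0f. Similarly, a reorder is a copy primitive, and primitive_attr::set_output_scales makes the copy scale every element in flight, which is how a gradient can be negated without an extra kernel. A self-contained sketch of the scaled reorder, assuming oneDNN 2.x (set_output_scales was replaced by scale masks in oneDNN 3.x):

    #include <oneapi/dnnl/dnnl.hpp>

    #include <cassert>
    #include <vector>

    int main() {
      using namespace dnnl;
      engine eng(engine::kind::cpu, 0);
      stream strm(eng);

      // A tiny stand-in for dout.
      std::vector<float> grad{1.f, -2.f, 3.f};
      memory::desc md({3}, memory::data_type::f32, memory::format_tag::a);
      memory src(md, eng, grad.data());
      memory dst(md, eng);

      // Copy with an output scale of -1: dst = -src, the same trick the
      // kernel uses for the swapped binary_sub gradient.
      primitive_attr attr;
      attr.set_output_scales(0, {-1.f});
      auto pd = reorder::primitive_desc(eng, md, eng, md, attr);
      reorder(pd).execute(strm, src, dst);
      strm.wait();

      auto* out = static_cast<float*>(dst.get_data_handle());
      assert(out[0] == -1.f && out[1] == 2.f && out[2] == -3.f);
      return 0;
    }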
4 changes: 4 additions & 0 deletions paddle/phi/backends/onednn/onednn_reuse.h
@@ -840,6 +840,10 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
         CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops);
 
     if (x->numel() < y->numel()) {
+      if (algo == dnnl::algorithm::binary_sub) {
+        attributes = CreateAttributes(
+            algo, -1.0 * scale_x, -1.0 * scale_y, scale_out, post_ops);
+      }
       this->AcquireForwardPrimitiveDescriptor(
           attributes, algo, src1_md, src0_md, dst_md);
     } else {
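The forward-path fix mirrors the backward one. oneDNN's binary primitive applies the per-source scales before the operation, so once the sources are swapped (src0 = y, src1 = x) a subtraction computes s0 * y - s1 * x; negating both scales restores the original result because (-s) * y - (-s) * x = s * (x - y). A plain-C++ check of that algebra (illustrative, assuming scale_x and scale_y keep their source order after the swap):

    #include <cassert>

    int main() {
      const float x = 7.f, y = 4.f, s = 1.f;

      float reference = s * x - s * y;      // 3.f, unswapped
      float swapped = (-s) * y - (-s) * x;  // 3.f as well

      assert(swapped == reference);
      return 0;
    }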
19 changes: 19 additions & 0 deletions python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_sub_mkldnn_op.py
@@ -20,6 +20,8 @@
 from paddle.fluid.framework import _current_expected_place
 import paddle.fluid.core as core
 
+import sys
+
 
 @OpTestTool.skip_if(not (isinstance(_current_expected_place(), core.CPUPlace)),
                     "GPU is not supported")
@@ -108,6 +110,23 @@ def init_axis(self):
         self.axis = 1
 
 
+class TestMKLDNNElementwiseSubOp40(TestMKLDNNElementwiseSubOp):
+
+    def init_input_output(self):
+        self.x = np.random.uniform(0.1, 2, [180, 1]).astype(self.dtype)
+        self.y = np.random.uniform(0.1, 1, [1, 256]).astype(self.dtype)
+        self.out = np.subtract(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
+
+
 class TestElementwiseSubOp_xsize_lessthan_ysize_sub(TestMKLDNNElementwiseSubOp):
 
     def init_input_output(self):
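The new TestMKLDNNElementwiseSubOp40 case is the regression test for the swap path: x carries 180 elements ([180, 1]) and y carries 256 ([1, 256]), so x->numel() < y->numel() holds and both the forward handler and the gradient kernel take the swapped branch, with all three check_grad variants verifying the gradient signs.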

0 comments on commit 6990edf
