Skip to content

Commit

Permalink
Add working implementation of elementwise_sub grad
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrekobi committed Sep 20, 2021
1 parent fc02000 commit b4d7c9e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,14 @@ class EltwiseSubMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} else {
// Broadcasting

dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
dnnl::primitive_attr attr;
attr.set_post_ops(po);

platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), attr);

auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
Expand Down
27 changes: 26 additions & 1 deletion paddle/fluid/platform/mkldnn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include <string>
#include <utility>
#include <vector>

#include "boost/optional.hpp"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/operator.h"
Expand Down Expand Up @@ -929,7 +930,7 @@ class BroadcastDataMKLDNNHandler
// Allocates the output tensor's buffer, zero-fills it, and wraps it in a
// oneDNN memory object described by the forward primitive's destination
// descriptor.
//
// @param output  Tensor that receives the destination buffer; it is
//                (re)allocated on this->place_ with the byte size reported
//                by the destination descriptor.
// @return        Memory object obtained from AcquireMemoryFromPrimitive for
//                the destination descriptor and the freshly zeroed buffer.
//
// NOTE(review): the buffer is cleared before use, presumably because the
// primitive accumulates into / only partially writes dst -- confirm.
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
  // Query the destination size once and reuse it for both the allocation and
  // the clear, so the two can never disagree.
  const auto dst_bytes = this->fwd_pd_->dst_desc().get_size();

  T_out* ptr = output->mutable_data<T_out>(this->place_, dst_bytes);
  memset(ptr, 0, dst_bytes);

  return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
Expand Down Expand Up @@ -961,6 +962,30 @@ class ReductionMKLDNNHandler

this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
}

// Constructs a reduction handler whose forward primitive descriptor is
// created with caller-supplied oneDNN primitive attributes.
//
// @param algo       Reduction algorithm forwarded to the primitive descriptor.
// @param p, eps     Forwarded unchanged to the primitive descriptor.
// @param engine     Engine passed to the non-caching handler base class.
// @param cpu_place  Place passed to the non-caching handler base class.
// @param x          Source tensor; must have the kMKLDNN layout and a
//                   defined (non-undef) memory format.
// @param y          Not read by this constructor.
// @param y_tz       Dimensions used for the destination memory descriptor.
// @param attr       Primitive attributes forwarded to
//                   AcquireForwardPrimitiveDescriptor.
ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p,
                       const float eps, const mkldnn::engine engine,
                       platform::Place cpu_place, const Tensor* x,
                       const Tensor* y, std::vector<int64_t> y_tz,
                       const dnnl::primitive_attr& attr)
    : platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction>(engine,
                                                            cpu_place) {
  // Validate the source tensor before any descriptor is built from it.
  PADDLE_ENFORCE_EQ(
      x->layout(), DataLayout::kMKLDNN,
      platform::errors::InvalidArgument("Wrong layout set for X tensor."));
  PADDLE_ENFORCE_NE(
      x->format(), MKLDNNMemoryFormat::undef,
      platform::errors::InvalidArgument("Wrong format set for X tensor."));

  const auto src_dims = framework::vectorize(x->dims());
  const auto elem_type = platform::MKLDNNGetDataType<T>();

  // Both descriptors share the element type and x's memory format; only the
  // dimensions differ (x's own dims versus the requested y_tz).
  const auto src_md = dnnl::memory::desc(src_dims, elem_type, x->format());
  const auto dst_md = memory::desc(y_tz, elem_type, x->format());

  this->AcquireForwardPrimitiveDescriptor(attr, algo, src_md, dst_md, p, eps);
}
};

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,18 @@ def init_input_output(self):
self.out = np.subtract(self.x, self.y)


# class TestMKLDNNElementwiseSubOp4(TestMKLDNNElementwiseSubOp):
# def init_input_output(self):
# self.x = np.random.uniform(1, 2, [2, 3, 4, 32]).astype(self.dtype)
# self.y = np.random.uniform(1, 2, [4, 32]).astype(self.dtype)
# self.out = np.subtract(self.x, self.y)

# class TestMKLDNNElementwiseSubOp5(TestMKLDNNElementwiseSubOp):
# def init_input_output(self):
# self.x = np.random.uniform(1, 2, [2, 3, 4, 100]).astype(self.dtype)
# self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
# self.out = np.subtract(self.x, self.y)
class TestMKLDNNElementwiseSubOp4(TestMKLDNNElementwiseSubOp):
    """Subtraction of a [4, 32] operand from a [2, 3, 4, 32] operand."""

    def init_input_output(self):
        # x is drawn before y, both uniformly from [1, 2).
        x_shape = [2, 3, 4, 32]
        y_shape = [4, 32]
        self.x = np.random.uniform(1, 2, x_shape).astype(self.dtype)
        self.y = np.random.uniform(1, 2, y_shape).astype(self.dtype)
        # Reference result computed by NumPy.
        self.out = np.subtract(self.x, self.y)


class TestMKLDNNElementwiseSubOp5(TestMKLDNNElementwiseSubOp):
    """Subtraction of a [100] operand from a [2, 3, 4, 100] operand."""

    def init_input_output(self):
        # x is drawn before y, both uniformly from [1, 2).
        x_shape = [2, 3, 4, 100]
        y_shape = [100]
        self.x = np.random.uniform(1, 2, x_shape).astype(self.dtype)
        self.y = np.random.uniform(1, 2, y_shape).astype(self.dtype)
        # Reference result computed by NumPy.
        self.out = np.subtract(self.x, self.y)


class TestMKLDNNElementwiseSubOp_broadcast_3(TestMKLDNNElementwiseSubOp):
Expand Down

0 comments on commit b4d7c9e

Please sign in to comment.