Skip to content

Commit

Permalink
Fix LeakyRelu behaviour on empty input (apache#18934) (apache#18996)
Browse files Browse the repository at this point in the history
* Fix LeakyRelu behaviour on empty input

* Remove duplicated declarations
  • Loading branch information
bgawrych authored and chinakook committed Nov 17, 2020
1 parent f97e39d commit 6c5730a
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ void LeakyReLUCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (inputs[0].Size() == 0U) return;
const LeakyReLUParam &param = nnvm::get<LeakyReLUParam>(attrs.parsed);
const std::vector<TBlob> no_use_but_adapt_origin_api;
size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
Expand All @@ -352,6 +353,7 @@ void LeakyReLUGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (inputs[0].Size() == 0U) return;
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
const std::vector<TBlob> no_use_but_adapt_origin_api;
// inputs: out_grad, input_data, input_gamma, output, output_mask
Expand Down
2 changes: 2 additions & 0 deletions src/operator/leaky_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ static void LeakyReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
CHECK_EQ(inputs.size(), expected);
Expand All @@ -107,6 +108,7 @@ void LeakyReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
std::vector<NDArray> in_data{inputs[0], inputs[1]};
Expand Down
7 changes: 0 additions & 7 deletions src/operator/nn/mkldnn/mkldnn_act-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ MKLDNNActForward &GetActForward(const MKLDNNActParam& param,
const OpContext &ctx, const NDArray &in_data,
const mkldnn::memory &in_mem);

void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
const MKLDNNActParam &param, const mkldnn::memory &input_mem,
const mkldnn::memory &diff_dst_memory);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantized_act.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
*/
#if MXNET_USE_MKLDNN == 1

#include "../../nn/mkldnn/mkldnn_act-inl.h"
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
#include "../quantization_utils.h"

namespace mxnet {
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,14 @@ def test_18933_channel_0():
with autograd.record():
a = npx.instance_norm(arr, gamma, beta)
a.backward()

@use_np
@with_environment('MXNET_ENGINE_TYPE', 'NaiveEngine')
def test_18934_empty_leaky_relu():
arr = np.random.rand(0,2)
arr_grad = np.empty_like(arr)

autograd.mark_variables([arr], [arr_grad])
with autograd.record():
res = npx.leaky_relu(arr)
res.backward()

0 comments on commit 6c5730a

Please sign in to comment.