This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Random.normal() with backward #16330

Merged
5 commits merged on Nov 18, 2019
6 changes: 4 additions & 2 deletions src/operator/numpy/random/dist_common.h
@@ -144,7 +144,6 @@ inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const DistParam &param = nnvm::get<DistParam>(attrs.parsed);
CHECK_EQ(out_attrs->size(), 1U);
if (param.size.has_value()) {
// Size declared.
std::vector<dim_t> oshape_vec;
@@ -173,7 +172,10 @@ inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs,
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1))
}
}
return shape_is_known(out_attrs->at(0));
if (out_attrs->size() == 2U) {
SHAPE_ASSIGN_CHECK(*out_attrs, 1, out_attrs->at(0));
}
return true;
}

template <typename DistParam>
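Context for the change above: the shape function now assigns the broadcast output shape to a second output when one is present, because the forward op emits both the samples and the underlying standard-normal noise (the noise output is hidden from the user and reused by the backward pass). A minimal NumPy sketch of that two-output contract, with a hypothetical helper name that is not part of MXNet:

import numpy as np

def normal_with_noise(loc, scale, size=None):
    # Sketch only: the samples and the hidden noise share one broadcast shape.
    loc, scale = np.asarray(loc), np.asarray(scale)
    out_shape = size if size is not None else np.broadcast(loc, scale).shape
    noise = np.random.standard_normal(out_shape)   # second (hidden) output
    samples = loc + scale * noise                  # first (visible) output
    return samples, noise
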
37 changes: 35 additions & 2 deletions src/operator/numpy/random/np_normal_op.cc
@@ -40,7 +40,11 @@ NNVM_REGISTER_OP(_npi_normal)
return num_inputs;
}
)
.set_num_outputs(1)
.set_num_outputs(2)
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
[](const NodeAttrs& attrs) {
return 1;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const NumpyNormalParam& param = nnvm::get<NumpyNormalParam>(attrs.parsed);
@@ -60,10 +64,39 @@ NNVM_REGISTER_OP(_npi_normal)
ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyNormalForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_normal"})
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_argument("input2", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyNormalParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_broadcast_normal)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<NumpyNormalParam>)
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyNormalParam& param = nnvm::get<NumpyNormalParam>(attrs.parsed);
int num_inputs = 6;
if (param.loc.has_value()) num_inputs -= 1;
if (param.scale.has_value()) num_inputs -= 1;
return num_inputs;
}
)
.set_num_outputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyNormalParam& param = nnvm::get<NumpyNormalParam>(attrs.parsed);
int num_outputs = 2;
if (param.loc.has_value()) num_outputs -= 1;
if (param.scale.has_value()) num_outputs -= 1;
return num_outputs;
}
)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NormalReparamBackward<cpu>)
.add_arguments(NumpyNormalParam::__FIELDS__());


} // namespace op
} // namespace mxnet
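
Note on the registration above: ElemwiseGradUseInOut passes the head gradients followed by the forward inputs and outputs to the backward node, which is why _backward_broadcast_normal expects six inputs in the tensor/tensor case (two head gradients, loc, scale, samples, noise) and one fewer for each scalar parameter. A rough NumPy reference for the gradients it computes under the reparameterization samples = loc + scale * noise; the function names below are illustrative, not the MXNet API:

import numpy as np

def reduce_to(grad, shape):
    # Sum over the axes that were broadcast so grad matches the parameter shape.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for axis, dim in enumerate(shape):
        if dim == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

def normal_backward(ograd, loc, scale, noise):
    grad_loc = reduce_to(ograd, loc.shape)               # d samples / d loc = 1
    grad_scale = reduce_to(ograd * noise, scale.shape)   # d samples / d scale = noise
    return grad_loc, grad_scale
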
3 changes: 3 additions & 0 deletions src/operator/numpy/random/np_normal_op.cu
@@ -31,5 +31,8 @@ namespace op {
NNVM_REGISTER_OP(_npi_normal)
.set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);

NNVM_REGISTER_OP(_backward_broadcast_normal)
.set_attr<FCompute>("FCompute<gpu>", NormalReparamBackward<gpu>);

} // namespace op
} // namespace mxnet
126 changes: 118 additions & 8 deletions src/operator/numpy/random/np_normal_op.h
@@ -26,6 +26,7 @@
#define MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_

#include <mxnet/operator_util.h>
#include <cstdio>
#include <algorithm>
#include <string>
#include <vector>
@@ -78,6 +79,7 @@ inline bool NumpyNormalOpType(const nnvm::NodeAttrs &attrs,
} else {
(*out_attrs)[0] = mshadow::kFloat32;
}
(*out_attrs)[1] = mshadow::kFloat32;
return true;
}

@@ -130,7 +132,7 @@ template <typename IType>
struct check_legal_scale_kernel {
MSHADOW_XINLINE static void Map(index_t i, IType *scalar, float* flag) {
if (scalar[i] < 0) {
flag[0] = -1.0;
*flag = -1.0;
}
}
};
@@ -146,22 +148,18 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs,
using namespace mshadow;
using namespace mxnet_op;
const NumpyNormalParam &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
CHECK_EQ(outputs.size(), 1);
Stream<xpu> *s = ctx.get_stream<xpu>();

// Generate base random number.
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
index_t output_len = outputs[0].Size();
Tensor<xpu, 1, float> workspace =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(output_len + 1), s);
Tensor<xpu, 1, float> normal_tensor = workspace.Slice(0, output_len);
Tensor<xpu, 1, float> indicator_device = workspace.Slice(output_len, output_len + 1);
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(1), s);
Tensor<xpu, 1, float> normal_tensor = outputs[1].FlatTo1D<xpu, float>(s);
Tensor<xpu, 1, float> indicator_device = workspace;
float indicator_host = 1.0;
float *indicator_device_ptr = indicator_device.dptr_;
Kernel<set_zero, xpu>::Launch(s, 1, indicator_device_ptr);
prnd->SampleGaussian(&normal_tensor, 0.0, 1.0);
mxnet::TShape new_lshape, new_hshape, new_oshape;

// [scalar scalar] case
if (inputs.size() == 0U) {
CHECK_GE(param.scale.value(), 0.0) << "ValueError: scale < 0";
@@ -228,6 +226,118 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs,
}
}

template<typename xpu, int ndim, typename DType>
inline void NormalReparamBackwardImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& new_lshape,
const mxnet::TShape& new_rshape,
const mxnet::TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob lgrad = outputs[0].reshape(new_lshape);
const TBlob rgrad = outputs[1].reshape(new_rshape);
const TBlob ograd = inputs[0].reshape(new_oshape);
// Mean
const TBlob lhs = inputs[2].reshape(new_lshape);
// Variance
const TBlob rhs = inputs[3].reshape(new_rshape);
const TBlob samples = inputs[4].reshape(new_oshape);
const TBlob noise = inputs[5].reshape(new_oshape);
size_t workspace_size_l = ReduceWorkspaceSize<ndim, DType>(
s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_);
size_t workspace_size_r = ReduceWorkspaceSize<ndim, DType>(
s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_);
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s,
lgrad, req[0], workspace, ograd);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
s, rgrad, req[1], workspace, ograd, noise, rhs);
}

template<typename xpu, int ndim, typename DType>
inline void ScalarNormalReparamBackwardImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& new_ishape,
const mxnet::TShape& new_oshape,
const bool loc_is_tensor) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob igrad = outputs[0].reshape(new_ishape);
// inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
// samples, noise]
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob itensor = inputs[2].reshape(new_ishape);
const TBlob samples = inputs[3].reshape(new_oshape);
const TBlob noise = inputs[4].reshape(new_oshape);
size_t workspace_size =
ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
if (loc_is_tensor) {
Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s, igrad, req[0],
workspace, ograd);
} else {
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
s, igrad, req[0], workspace, ograd, noise, noise);
}
}

// Allow normal sampling to be differentiable,
// using the reparameterization trick described in:
// Auto-Encoding Variational Bayes.
// Kingma, D. P., & Welling, M. (2013).
template<typename xpu>
void NormalReparamBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
// skip kernel launch for zero-size tensors
if (inputs[0].shape_.Size() == 0U) {
return;
}
// [scalar scalar] case
if (outputs.size() == 0U) {
return;
}
const NumpyNormalParam &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
// [tensor tensor] case
if (inputs.size() == 6U) {
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = FillShape(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_,
&new_lshape, &new_rshape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
NormalReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape);
});
});
}
// [tensor scalar], [scalar tensor] case
if (inputs.size() == 5U) {
mxnet::TShape new_ishape, new_oshape;
int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_,
&new_ishape, &new_ishape, &new_oshape);
bool loc_is_tensor = !param.loc.has_value();
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
ScalarNormalReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor);
});
});
}
}

} // namespace op
} // namespace mxnet

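For reference, the math behind NormalReparamBackward (a restatement in LaTeX of the reparameterization trick cited in the comment above, not new analysis):

s = \mu + \sigma \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, 1)

\frac{\partial L}{\partial \mu} = \sum_i \frac{\partial L}{\partial s_i}, \qquad
\frac{\partial L}{\partial \sigma} = \sum_i \frac{\partial L}{\partial s_i}\, \varepsilon_i

where each sum runs over the axes along which \mu or \sigma was broadcast; this matches the two Reduce calls above, one with mshadow_op::identity and one with mshadow_op::mul over the saved noise.
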
40 changes: 40 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -2605,6 +2605,46 @@ def hybrid_forward(self, F, a, *args, **kwargs):
check_unary_func(func, ref_grad, shape, low, high)


@with_seed()
@use_np
def test_np_normal_grad():
class TestNormalGrad(HybridBlock):
def __init__(self, shape):
super(TestNormalGrad, self).__init__()
self._shape = shape

def hybrid_forward(self, F, loc, scale):
return F.np.random.normal(loc, scale, self._shape)

dtypes = ['float16', 'float32', 'float64']
param_shape = [
[(3, 2), (3, 2)],
[(3, 2, 2), (3, 2, 2)],
[(3, 4, 5), (4, 1)],
]
output_shapes = [
(3, 2),
(4, 3, 2, 2),
(3, 4, 5)
]
for hybridize in [False, True]:
for dtype in dtypes:
for ((shape1, shape2), out_shape) in zip(param_shape, output_shapes):
test_normal_grad = TestNormalGrad(out_shape)
if hybridize:
test_normal_grad.hybridize()
loc = np.zeros(shape1)
loc.attach_grad()
scale = np.ones(shape2)
scale.attach_grad()
with mx.autograd.record():
samples = test_normal_grad(loc, scale)
samples.backward()
assert loc.grad.shape == shape1
assert scale.grad.shape == shape2
assert_almost_equal(loc.grad.asnumpy().sum(), _np.ones(out_shape).sum(), rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_random():
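
As a closing illustration, a minimal user-level sketch of what the new FGradient registration enables, using the same mxnet.numpy frontend as the test above (assumes npx.set_np() and default float32 arrays; this snippet is not part of the PR's test suite):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

loc = np.zeros((3, 2))
scale = np.ones((3, 2))
loc.attach_grad()
scale.attach_grad()
with mx.autograd.record():
    samples = np.random.normal(loc, scale)
samples.backward()
print(loc.grad)    # all ones, since d samples / d loc = 1
print(scale.grad)  # the standard-normal noise used to draw the samples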