Skip to content

Commit

Permalink
fix dropout bug in backward when input is 1d tensor (#26837)
Browse files Browse the repository at this point in the history
* fix dropout bug in backward when input is 1d tensor, test=develop

* add test case and refine error message, test=develop

* refine error message, test=develop
  • Loading branch information
huangjun12 committed Sep 3, 2020
1 parent 64a118f commit 8240c91
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
10 changes: 7 additions & 3 deletions paddle/fluid/operators/dropout_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

// Rank-agnostic Eigen view over a tensor's elements. The dropout grad
// kernel uses EigenVector<T>::Flatten(*t) instead of
// EigenMatrix<T>::Reshape(*t, 1) so that 1-D inputs are handled
// correctly in backward (Reshape to a matrix assumed rank >= 2).
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -116,9 +120,9 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace());

auto M = EigenMatrix<uint8_t>::Reshape(*mask, 1);
auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
auto M = EigenVector<uint8_t>::Flatten(*mask);
auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(*grad_y);

auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Expand Down
24 changes: 24 additions & 0 deletions python/paddle/fluid/tests/unittests/test_dropout_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')


class TestDropoutOpInput1d(OpTest):
    """Dropout on a rank-1 input.

    Regression test for the backward pass when X is a 1-D tensor: with
    dropout_prob == 0 every element is kept, so Out equals X and the
    mask is all ones.
    """

    def setUp(self):
        self.op_type = "dropout"
        x = np.random.random(2000).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
        # Nothing is dropped, so the output is the input unchanged and
        # every mask entry is 1.
        self.outputs = {'Out': x, 'Mask': np.ones(2000).astype('uint8')}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')


class TestDropoutOp2(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
Expand Down Expand Up @@ -436,6 +453,13 @@ def test_axis_max():

self.assertRaises(ValueError, test_axis_max)

def test_axis_min():
    # minimum of axis should be greater than or equal to 0;
    # a negative entry (here -1) must raise ValueError
    x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32")
    paddle.nn.functional.dropout(x2, axis=[0, -1])

self.assertRaises(ValueError, test_axis_min)

def test_axis_len():
# length of axis should not greater than dimensions of x
x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32")
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/nn/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,12 +910,12 @@ def get_attrs(prog, dropout_prob, is_test, seed):
#get mask shape
input_shape = x.shape
drop_axes = [axis] if isinstance(axis, int) else axis
if max(drop_axes) > len(input_shape) - 1:
raise ValueError("axis value should less than dimensions of x:{}, but get drop_axes value:{} " \
if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1:
raise ValueError("axis value should be greater than or equal to 0 and less than dimensions of x:{}, but get axis value:{} " \
.format(len(input_shape), max(drop_axes)))
if len(drop_axes) > len(input_shape):
raise ValueError(
"length of axis should not greater than dimensions of x:{}, but get length of drop axes: {}".
"length of axis should not be greater than dimensions of x:{}, but get length of axis: {}".
format(len(input_shape), len(drop_axes)))
mask_shape = [1] * len(input_shape)
for i in drop_axes:
Expand Down

1 comment on commit 8240c91

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI checks. You can ask the reviewer(s) to approve and merge. 🎉

Please sign in to comment.