
Commit

reduce_mean error if keepdim=True and reduce_all=True
zhupengyang committed Aug 24, 2020
1 parent 79539cf commit 5a63442
Showing 4 changed files with 119 additions and 104 deletions.
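For context, here is a hypothetical repro sketch of the case this commit fixes. It is not part of the commit; it assumes the fluid dygraph API of this Paddle version (fluid.dygraph.guard, fluid.layers.reduce_mean, reduce_sum), the same calls used in the tests below. With dim=None the op sets reduce_all, and keep_dim=True keeps a rank-4 output of shape [1, 1, 1, 1]; the failure was in the backward pass, apparently because the reduce_all branch of the gradient kernel mapped the reduced tensor with EigenVector::From, which assumes a 1-D shape (see the reduce_op.h hunk below).

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(
        np.random.uniform(-1, 1, [2, 3, 4, 5]).astype('float32'))
    x.stop_gradient = False
    # dim=None -> reduce over all axes; keep_dim=True keeps the rank-4 shape [1, 1, 1, 1]
    out = fluid.layers.reduce_mean(input=x, dim=None, keep_dim=True)
    loss = fluid.layers.reduce_sum(out)  # scalar loss so we can call backward()
    loss.backward()  # before this commit, the keepdim + reduce_all grad path failed here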
4 changes: 1 addition & 3 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
@@ -21,6 +21,4 @@ using CUDAReduceMeanGradKernel =
                          ops::MeanGradFunctor, true>;

REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<float>,
-                        CUDAReduceMeanGradKernel<double>,
-                        CUDAReduceMeanGradKernel<int>,
-                        CUDAReduceMeanGradKernel<int64_t>);
+                        CUDAReduceMeanGradKernel<double>);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -236,8 +236,8 @@ class ReduceGradKernel : public framework::OpKernel<T> {

    if (reduce_all) {
      auto x = EigenVector<T>::Flatten(*input0);
-     auto x_reduce = EigenVector<T>::From(*input1);
-     auto x_reduce_grad = EigenVector<T>::From(*input2);
+     auto x_reduce = EigenVector<T>::Flatten(*input1);
+     auto x_reduce_grad = EigenVector<T>::Flatten(*input2);
      auto x_grad = EigenVector<T>::Flatten(*output);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
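The essence of the fix is in the hunk above: EigenVector<T>::From maps a tensor using its stored dimensions and so assumes the tensor is already 1-D, while Flatten reshapes a tensor of any rank into a 1-D view. With reduce_all=True and keep_dim=True the reduced output keeps its rank (every dimension becomes 1), so the reduce_all branch of the gradient kernel now uses Flatten for the reduced tensor and its gradient. A rough NumPy analogy of the shapes involved (illustrative only, not Paddle code):

import numpy as np

x = np.random.rand(2, 3, 4, 5)
m = x.mean(keepdims=True)    # reduce over all axes but keep rank: shape (1, 1, 1, 1)
v = m.reshape(-1)            # what Flatten does: a 1-D view regardless of rank
print(m.shape, v.shape)      # (1, 1, 1, 1) (1,)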
119 changes: 116 additions & 3 deletions python/paddle/fluid/tests/unittests/test_mean_op.py
@@ -22,6 +22,8 @@
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

np.random.seed(10)


class TestMeanOp(OpTest):
    def setUp(self):
@@ -74,10 +76,105 @@ def test_checkout_grad(self):
            place, ['X'], 'Out', max_relative_error=0.8)


def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
    if isinstance(axis, list):
        axis = tuple(axis)
    if reduce_all:
        axis = None
    return np.mean(x, axis=axis, keepdims=keepdim)


class TestReduceMeanOp(OpTest):
    def setUp(self):
        self.op_type = 'reduce_mean'
        self.dtype = 'float64'
        self.shape = [2, 3, 4, 5]
        self.axis = [0]
        self.keepdim = False
        self.reduce_all = False
        self.set_attrs()

        np.random.seed(10)
        x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
        out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all)
        self.inputs = {'X': x_np}
        self.outputs = {'Out': out_np}
        self.attrs = {
            'dim': self.axis,
            'keep_dim': self.keepdim,
            'reduce_all': self.reduce_all
        }

    def set_attrs(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], ['Out'])


class TestReduceMeanOpDefaultAttrs(TestReduceMeanOp):
    def setUp(self):
        self.op_type = 'reduce_mean'
        self.dtype = 'float64'
        self.shape = [2, 3, 4, 5]

        x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
        out_np = np.mean(x_np, axis=0)
        self.inputs = {'X': x_np}
        self.outputs = {'Out': out_np}


class TestReduceMeanOpFloat32(TestReduceMeanOp):
    def set_attrs(self):
        self.dtype = 'float32'


class TestReduceMeanOpShape1D(TestReduceMeanOp):
    def set_attrs(self):
        self.shape = [100]


class TestReduceMeanOpShape6D(TestReduceMeanOp):
    def set_attrs(self):
        self.shape = [2, 3, 4, 5, 6, 7]


class TestReduceMeanOpAxisAll(TestReduceMeanOp):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]


class TestReduceMeanOpAxisTuple(TestReduceMeanOp):
    def set_attrs(self):
        self.axis = (0, 1, 2)


class TestReduceMeanOpAxisNegative(TestReduceMeanOp):
    def set_attrs(self):
        self.axis = [-2, -1]


class TestReduceMeanOpKeepdimTrue1(TestReduceMeanOp):
    def set_attrs(self):
        self.keepdim = True


class TestReduceMeanOpKeepdimTrue2(TestReduceMeanOp):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]
        self.keepdim = True


class TestReduceMeanOpReduceAllTrue(TestReduceMeanOp):
    def set_attrs(self):
        self.reduce_all = True


class TestMeanAPI(unittest.TestCase):
-    """
-    test paddle.tensor.stat.mean
-    """
+    # test paddle.tensor.stat.mean

    def setUp(self):
        self.x_shape = [2, 3, 4, 5]
@@ -128,6 +225,22 @@ def test_case(x, axis=None, keepdim=False):
        test_case(self.x, [0, 1, 2, 3])
        paddle.enable_static()

    def test_fluid_api(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x = fluid.data("x", shape=[10, 10], dtype="float32")
            out = fluid.layers.reduce_mean(input=x, dim=1)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            x_np = np.random.rand(10, 10).astype(np.float32)
            res = exe.run(feed={"x": x_np}, fetch_list=[out])
            self.assertEqual(np.allclose(res[0], np.mean(x_np, axis=1)), True)

        with fluid.dygraph.guard():
            x_np = np.random.rand(10, 10).astype(np.float32)
            x = fluid.dygraph.to_variable(x_np)
            out = fluid.layers.reduce_mean(input=x, dim=1)
            self.assertEqual(np.allclose(out.numpy(), np.mean(x_np, axis=1)), True)

    def test_errors(self):
        paddle.disable_static()
        x = np.random.uniform(-1, 1, [10, 12]).astype('float32')
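A quick sanity check (not part of the commit) of the attribute semantics the new ref_reduce_mean helper mirrors: 'reduce_all' overrides whatever axis list is passed, and 'keep_dim' preserves the rank, which is exactly the combination named in the commit title. The helper is repeated here so the snippet is self-contained.

import numpy as np

def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
    # same helper as defined in test_mean_op.py above
    if isinstance(axis, list):
        axis = tuple(axis)
    if reduce_all:
        axis = None
    return np.mean(x, axis=axis, keepdims=keepdim)

x = np.random.rand(2, 3, 4, 5)
assert ref_reduce_mean(x, axis=[0], keepdim=True, reduce_all=True).shape == (1, 1, 1, 1)
assert ref_reduce_mean(x, axis=[0], keepdim=True, reduce_all=False).shape == (1, 3, 4, 5)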
96 changes: 0 additions & 96 deletions python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -67,22 +67,6 @@ def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestMeanOp(OpTest):
    def setUp(self):
        self.op_type = "reduce_mean"
        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")}
        self.attrs = {'dim': [1]}
        self.outputs = {
            'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


@skip_check_grad_ci(
    reason="reduce_max is discontinuous non-derivable function,"
    " its gradient check is not supported by unittest framework.")
@@ -318,21 +302,6 @@ def setUp(self):
        self.outputs = {'Out': self.inputs['X'].sum()}


## reduction in multi dims
class TestReduceMeanOpMultiAxises(OpTest):
    def setUp(self):
        self.op_type = "reduce_mean"
        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")}
        self.attrs = {'dim': [1, 2]}
        self.outputs = {'Out': self.inputs['X'].mean(axis=(1, 2))}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


@skip_check_grad_ci(
    reason="reduce_max is discontinuous non-derivable function,"
    " its gradient check is not supported by unittest framework.")
@@ -420,40 +389,6 @@ def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestReduceMeanWithDimOne(OpTest):
    def setUp(self):
        self.op_type = "reduce_mean"
        self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")}
        self.attrs = {'dim': [1], 'keep_dim': False}
        self.outputs = {
            'Out': self.inputs['X'].mean(
                axis=tuple(self.attrs['dim']), keepdims=False)
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestReduceMeanWithNumelOne(OpTest):
    def setUp(self):
        self.op_type = "reduce_mean"
        self.inputs = {'X': np.random.random((100, 1)).astype("float64")}
        self.attrs = {'dim': [1], 'keep_dim': True}
        self.outputs = {
            'Out': self.inputs['X'].mean(
                axis=tuple(self.attrs['dim']), keepdims=True)
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestReduceAll(OpTest):
    def setUp(self):
        self.op_type = "reduce_sum"
@@ -536,18 +471,6 @@ def test_errors(self):
            self.assertRaises(TypeError, fluid.layers.reduce_sum, x2)


class TestReduceMeanOpError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
            # The input type of reduce_mean_op must be Variable.
            x1 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            self.assertRaises(TypeError, fluid.layers.reduce_mean, x1)
            # The input dtype of reduce_mean_op must be float32 or float64 or int32 or int64.
            x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
            self.assertRaises(TypeError, fluid.layers.reduce_mean, x2)


class API_TestSumOpError(unittest.TestCase):
    def test_errors(self):
        def test_dtype1():
@@ -649,24 +572,5 @@ def test_dygraph(self):
            self.assertTrue((out3 == np.sum(np_x, axis=(0, 1, 2))).all())


class API_TestReduceMeanOp(unittest.TestCase):
    def test_static(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x = fluid.data("x", shape=[10, 10], dtype="float32")
            out = fluid.layers.reduce_mean(input=x, dim=1)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            x_np = np.random.rand(10, 10).astype(np.float32)
            res = exe.run(feed={"x": x_np}, fetch_list=[out])
            self.assertEqual(np.allclose(res[0], np.mean(x_np, axis=1)), True)

    def test_dygraph(self):
        with fluid.dygraph.guard():
            x_np = np.random.rand(10, 10).astype(np.float32)
            x = fluid.dygraph.to_variable(x_np)
            out = fluid.layers.reduce_mean(input=x, dim=1)
            self.assertEqual(np.allclose(out.numpy(), np.mean(x_np, axis=1)), True)


if __name__ == '__main__':
    unittest.main()

1 comment on commit 5a63442

@paddle-bot-old

Congratulations! Your pull request passed all required CI checks. You can ask the reviewer(s) to approve and merge. 🎉
