Skip to content

Commit

Permalink
reduce_mean error if keepdim=True and reduce_all=True (#26614)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Aug 25, 2020
1 parent a31dbc8 commit c80fcf9
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 104 deletions.
4 changes: 1 addition & 3 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,4 @@ using CUDAReduceMeanGradKernel =
ops::MeanGradFunctor, true>;

REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<float>,
CUDAReduceMeanGradKernel<double>,
CUDAReduceMeanGradKernel<int>,
CUDAReduceMeanGradKernel<int64_t>);
CUDAReduceMeanGradKernel<double>);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ class ReduceGradKernel : public framework::OpKernel<T> {

if (reduce_all) {
auto x = EigenVector<T>::Flatten(*input0);
auto x_reduce = EigenVector<T>::From(*input1);
auto x_reduce_grad = EigenVector<T>::From(*input2);
auto x_reduce = EigenVector<T>::Flatten(*input1);
auto x_reduce_grad = EigenVector<T>::Flatten(*input2);
auto x_grad = EigenVector<T>::Flatten(*output);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Expand Down
119 changes: 116 additions & 3 deletions python/paddle/fluid/tests/unittests/test_mean_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

np.random.seed(10)


class TestMeanOp(OpTest):
def setUp(self):
Expand Down Expand Up @@ -74,10 +76,105 @@ def test_checkout_grad(self):
place, ['X'], 'Out', max_relative_error=0.8)


def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
if isinstance(axis, list):
axis = tuple(axis)
if reduce_all:
axis = None
return np.mean(x, axis=axis, keepdims=keepdim)


class TestReduceMeanOp(OpTest):
def setUp(self):
self.op_type = 'reduce_mean'
self.dtype = 'float64'
self.shape = [2, 3, 4, 5]
self.axis = [0]
self.keepdim = False
self.reduce_all = False
self.set_attrs()

np.random.seed(10)
x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all)
self.inputs = {'X': x_np}
self.outputs = {'Out': out_np}
self.attrs = {
'dim': self.axis,
'keep_dim': self.keepdim,
'reduce_all': self.reduce_all
}

def set_attrs(self):
pass

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], ['Out'])


class TestReduceMeanOpDefaultAttrs(TestReduceMeanOp):
def setUp(self):
self.op_type = 'reduce_mean'
self.dtype = 'float64'
self.shape = [2, 3, 4, 5]

x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out_np = np.mean(x_np, axis=0)
self.inputs = {'X': x_np}
self.outputs = {'Out': out_np}


class TestReduceMeanOpFloat32(TestReduceMeanOp):
def set_attrs(self):
self.dtype = 'float32'


class TestReduceMeanOpShape1D(TestReduceMeanOp):
def set_attrs(self):
self.shape = [100]


class TestReduceMeanOpShape6D(TestReduceMeanOp):
def set_attrs(self):
self.shape = [2, 3, 4, 5, 6, 7]


class TestReduceMeanOpAxisAll(TestReduceMeanOp):
def set_attrs(self):
self.axis = [0, 1, 2, 3]


class TestReduceMeanOpAxisTuple(TestReduceMeanOp):
def set_attrs(self):
self.axis = (0, 1, 2)


class TestReduceMeanOpAxisNegative(TestReduceMeanOp):
def set_attrs(self):
self.axis = [-2, -1]


class TestReduceMeanOpKeepdimTrue1(TestReduceMeanOp):
def set_attrs(self):
self.keepdim = True


class TestReduceMeanOpKeepdimTrue2(TestReduceMeanOp):
def set_attrs(self):
self.axis = [0, 1, 2, 3]
self.keepdim = True


class TestReduceMeanOpReduceAllTrue(TestReduceMeanOp):
def set_attrs(self):
self.reduce_all = True


class TestMeanAPI(unittest.TestCase):
"""
test paddle.tensor.stat.mean
"""
# test paddle.tensor.stat.mean

def setUp(self):
self.x_shape = [2, 3, 4, 5]
Expand Down Expand Up @@ -128,6 +225,22 @@ def test_case(x, axis=None, keepdim=False):
test_case(self.x, [0, 1, 2, 3])
paddle.enable_static()

def test_fluid_api(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = fluid.data("x", shape=[10, 10], dtype="float32")
out = fluid.layers.reduce_mean(input=x, dim=1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
x_np = np.random.rand(10, 10).astype(np.float32)
res = exe.run(feed={"x": x_np}, fetch_list=[out])
self.assertEqual(np.allclose(res[0], np.mean(x_np, axis=1)), True)

with fluid.dygraph.guard():
x_np = np.random.rand(10, 10).astype(np.float32)
x = fluid.dygraph.to_variable(x_np)
out = fluid.layers.reduce_mean(input=x, dim=1)
self.assertEqual(np.allclose(out.numpy(), np.mean(x_np, axis=1)), True)

def test_errors(self):
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 12]).astype('float32')
Expand Down
96 changes: 0 additions & 96 deletions python/paddle/fluid/tests/unittests/test_reduce_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,6 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestMeanOp(OpTest):
def setUp(self):
self.op_type = "reduce_mean"
self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")}
self.attrs = {'dim': [1]}
self.outputs = {
'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')


@skip_check_grad_ci(
reason="reduce_max is discontinuous non-derivable function,"
" its gradient check is not supported by unittest framework.")
Expand Down Expand Up @@ -318,21 +302,6 @@ def setUp(self):
self.outputs = {'Out': self.inputs['X'].sum()}


## reduction in multi dims
class TestReduceMeanOpMultiAxises(OpTest):
def setUp(self):
self.op_type = "reduce_mean"
self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")}
self.attrs = {'dim': [1, 2]}
self.outputs = {'Out': self.inputs['X'].mean(axis=(1, 2))}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')


@skip_check_grad_ci(
reason="reduce_max is discontinuous non-derivable function,"
" its gradient check is not supported by unittest framework.")
Expand Down Expand Up @@ -420,40 +389,6 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestReduceMeanWithDimOne(OpTest):
def setUp(self):
self.op_type = "reduce_mean"
self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")}
self.attrs = {'dim': [1], 'keep_dim': False}
self.outputs = {
'Out': self.inputs['X'].mean(
axis=tuple(self.attrs['dim']), keepdims=False)
}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestReduceMeanWithNumelOne(OpTest):
def setUp(self):
self.op_type = "reduce_mean"
self.inputs = {'X': np.random.random((100, 1)).astype("float64")}
self.attrs = {'dim': [1], 'keep_dim': True}
self.outputs = {
'Out': self.inputs['X'].mean(
axis=tuple(self.attrs['dim']), keepdims=True)
}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestReduceAll(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
Expand Down Expand Up @@ -536,18 +471,6 @@ def test_errors(self):
self.assertRaises(TypeError, fluid.layers.reduce_sum, x2)


class TestReduceMeanOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of reduce_mean_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.reduce_mean, x1)
# The input dtype of reduce_mean_op must be float32 or float64 or int32 or int64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.reduce_mean, x2)


class API_TestSumOpError(unittest.TestCase):
def test_errors(self):
def test_dtype1():
Expand Down Expand Up @@ -649,24 +572,5 @@ def test_dygraph(self):
self.assertTrue((out3 == np.sum(np_x, axis=(0, 1, 2))).all())


class API_TestReduceMeanOp(unittest.TestCase):
def test_static(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = fluid.data("x", shape=[10, 10], dtype="float32")
out = fluid.layers.reduce_mean(input=x, dim=1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
x_np = np.random.rand(10, 10).astype(np.float32)
res = exe.run(feed={"x": x_np}, fetch_list=[out])
self.assertEqual(np.allclose(res[0], np.mean(x_np, axis=1)), True)

def test_dygraph(self):
with fluid.dygraph.guard():
x_np = np.random.rand(10, 10).astype(np.float32)
x = fluid.dygraph.to_variable(x_np)
out = fluid.layers.reduce_mean(input=x, dim=1)
self.assertEqual(np.allclose(out.numpy(), np.mean(x_np, axis=1)), True)


if __name__ == '__main__':
unittest.main()

0 comments on commit c80fcf9

Please sign in to comment.