diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py
index 643ef08d74320..8f121158d8c7f 100644
--- a/python/paddle/tensor/stat.py
+++ b/python/paddle/tensor/stat.py
@@ -471,14 +471,16 @@ def median(x, axis=None, keepdim=False, name=None):
     return out_tensor
 
 
-def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False):
+def _compute_quantile(
+    x, q, axis=None, keepdim=False, interpolation="linear", ignore_nan=False
+):
     """
     Compute the quantile of the input along the specified axis.
 
     Args:
         x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64.
-        q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
-            each q will be calculated and the first dimension of output is same to the number of ``q`` .
+        q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
+            a 1-D Tensor or a 0-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` .
         axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
             ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
             If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
@@ -489,6 +491,8 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False):
             the output Tensor is the same as ``x`` except in the reduced
             dimensions(it is of size 1 in this case). Otherwise, the shape of
             the output Tensor is squeezed in ``axis`` . Default is False.
+        interpolation (str, optional): The interpolation method to use
+            when the desired quantile falls between two data points. Default is linear.
        ignore_nan: (bool, optional): Whether to ignore NaN of input Tensor.
            If ``ignore_nan`` is True, it will calculate nanquantile.
            Otherwise it will calculate quantile. Default is False.
@@ -507,9 +511,33 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False):
     elif isinstance(q, (list, tuple)):
         if len(q) <= 0:
             raise ValueError("q should not be empty")
+    elif isinstance(q, Variable):
+        if len(q.shape) > 1:
+            raise ValueError("q should be a 0-D tensor or a 1-D tensor")
+        if len(q.shape) == 0:
+            q = [q]
     else:
-        raise TypeError("Type of q should be int, float, list or tuple.")
+        raise TypeError(
+            "Type of q should be int, float, list or tuple, or tensor"
+        )
+    for q_num in q:
+        if not in_dynamic_mode() and isinstance(q_num, Variable):
+            break
+        if q_num < 0 or q_num > 1:
+            raise ValueError("q should be in range [0, 1]")
+    if interpolation not in [
+        'linear',
+        'lower',
+        'higher',
+        'nearest',
+        'midpoint',
+    ]:
+        raise ValueError(
+            "interpolation must be one of 'linear', 'lower', 'higher', 'nearest' or 'midpoint', but got {}".format(
+                interpolation
+            )
+        )
 
     # Validate axis
     dims = len(x.shape)
     out_shape = list(x.shape)
@@ -557,8 +585,6 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False):
 
     indices = []
     for q_num in q:
-        if q_num < 0 or q_num > 1:
-            raise ValueError("q should be in range [0, 1]")
         if in_dynamic_or_pir_mode():
             q_num = paddle.to_tensor(q_num, dtype='float64')
         if ignore_nan:
@@ -573,31 +599,47 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False):
 
     sorted_tensor = paddle.sort(x, axis)
 
-    outputs = []
+    def _compute_index(index):
+        if interpolation == "nearest":
+            idx = paddle.round(index).astype(paddle.int32)
+            return paddle.take_along_axis(sorted_tensor, idx, axis=axis)
 
-    # TODO(chenjianye): replace the for-loop to directly take elements.
-    for index in indices:
-        indices_below = paddle.floor(index).astype('int32')
-        indices_upper = paddle.ceil(index).astype('int32')
-        tensor_upper = paddle.take_along_axis(
-            sorted_tensor, indices_upper, axis=axis
-        )
+        indices_below = paddle.floor(index).astype(paddle.int32)
         tensor_below = paddle.take_along_axis(
             sorted_tensor, indices_below, axis=axis
         )
-        weights = index - indices_below.astype('float64')
-        out = paddle.lerp(
-            tensor_below.astype('float64'),
-            tensor_upper.astype('float64'),
+        if interpolation == "lower":
+            return tensor_below
+
+        indices_upper = paddle.ceil(index).astype(paddle.int32)
+        tensor_upper = paddle.take_along_axis(
+            sorted_tensor, indices_upper, axis=axis
+        )
+        if interpolation == "higher":
+            return tensor_upper
+
+        if interpolation == "midpoint":
+            return (tensor_upper + tensor_below) / 2
+
+        weights = (index - indices_below).astype(paddle.float64)
+        return paddle.lerp(
+            tensor_below.astype(paddle.float64),
+            tensor_upper.astype(paddle.float64),
             weights,
         )
+
+    outputs = []
+
+    # TODO(chenjianye): replace the for-loop to directly take elements.
+    for index in indices:
+        out = _compute_index(index)
         if not keepdim:
             out = paddle.squeeze(out, axis=axis)
         else:
             out = out.reshape(out_shape)
         outputs.append(out)
 
-    if len(q) > 1:
+    if len(outputs) > 1:
         outputs = paddle.stack(outputs, 0)
     else:
         outputs = outputs[0]
@@ -605,15 +647,15 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False):
     return outputs
 
 
-def quantile(x, q, axis=None, keepdim=False):
+def quantile(x, q, axis=None, keepdim=False, interpolation="linear"):
     """
     Compute the quantile of the input along the specified axis.
     If any values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
 
     Args:
         x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64.
-        q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
-            each q will be calculated and the first dimension of output is same to the number of ``q`` .
+        q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
+            a 1-D Tensor or a 0-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` .
         axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
             ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
             If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
@@ -624,6 +666,8 @@ def quantile(x, q, axis=None, keepdim=False):
             the output Tensor is the same as ``x`` except in the reduced
             dimensions(it is of size 1 in this case). Otherwise, the shape of
             the output Tensor is squeezed in ``axis`` . Default is False.
+        interpolation (str, optional): The interpolation method to use
+            when the desired quantile falls between two data points. Default is linear.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.
 
@@ -670,18 +714,25 @@ def quantile(x, q, axis=None, keepdim=False):
             [6.80000000]])
 
     """
-    return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=False)
+    return _compute_quantile(
+        x,
+        q,
+        axis=axis,
+        keepdim=keepdim,
+        interpolation=interpolation,
+        ignore_nan=False,
+    )
 
 
-def nanquantile(x, q, axis=None, keepdim=False):
+def nanquantile(x, q, axis=None, keepdim=False, interpolation="linear"):
     """
     Compute the quantile of the input as if NaN values in input did not exist.
     If all values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
 
     Args:
         x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64.
-        q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
-            each q will be calculated and the first dimension of output is same to the number of ``q`` .
+        q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list or
+            a 1-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` .
         axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
             ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
             If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
@@ -692,6 +743,8 @@ def nanquantile(x, q, axis=None, keepdim=False):
             the output Tensor is the same as ``x`` except in the reduced
             dimensions(it is of size 1 in this case). Otherwise, the shape of
             the output Tensor is squeezed in ``axis`` . Default is False.
+        interpolation (str, optional): The interpolation method to use
+            when the desired quantile falls between two data points. Default is linear.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.
 
@@ -740,4 +793,11 @@ def nanquantile(x, q, axis=None, keepdim=False):
             [nan]])
 
     """
-    return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=True)
+    return _compute_quantile(
+        x,
+        q,
+        axis=axis,
+        keepdim=keepdim,
+        interpolation=interpolation,
+        ignore_nan=True,
+    )
diff --git a/test/legacy_test/test_quantile_and_nanquantile.py b/test/legacy_test/test_quantile_and_nanquantile.py
index 815520ccfff6a..b2726e2d48f9d 100644
--- a/test/legacy_test/test_quantile_and_nanquantile.py
+++ b/test/legacy_test/test_quantile_and_nanquantile.py
@@ -119,6 +119,89 @@ def test_nanquantile_all_NaN(self):
             paddle_res.numpy(), np_res, rtol=1e-05, equal_nan=True
         )
 
+    def test_nanquantile_interpolation(self):
+        input_data = np.random.randn(2, 3, 4)
+        input_data[0, 1, 1] = np.nan
+        x = paddle.to_tensor(input_data)
+        for mode in ['lower', 'higher', 'midpoint', 'nearest']:
+            paddle_res = paddle.nanquantile(
+                x, q=0.35, axis=0, interpolation=mode
+            )
+            np_res = np.nanquantile(input_data, q=0.35, axis=0, method=mode)
+            np.testing.assert_allclose(
+                paddle_res.numpy(), np_res, rtol=1e-05, equal_nan=True
+            )
+
+    def test_backward(self):
+        def check_grad(x, q, axis, target_grad, apis=None):
+            x = np.array(x, dtype='float32')
+            paddle.disable_static()
+            for op, _ in apis or API_list:
+                x_p = paddle.to_tensor(x, dtype='float32', stop_gradient=False)
+                op(x_p, q, axis).sum().backward()
+                np.testing.assert_allclose(
+                    x_p.grad.numpy(),
+                    np.array(target_grad, dtype='float32'),
+                    rtol=1e-05,
+                    equal_nan=True,
+                )
+            paddle.enable_static()
+            opt = paddle.optimizer.SGD(learning_rate=0.01)
+            for op, _ in apis or API_list:
+                s_p = paddle.static.Program()
+                m_p = paddle.static.Program()
+                with paddle.static.program_guard(m_p, s_p):
+                    x_p = paddle.static.data(
+                        name="x",
+                        shape=x.shape,
+                        dtype=paddle.float32,
+                    )
+                    x_p.stop_gradient = False
+                    q_p = paddle.static.data(
+                        name="q",
+                        shape=[len(q)] if isinstance(q, list) else [],
+                        dtype=paddle.float32,
+                    )
+                    loss = op(x_p, q_p, axis).sum()
+                    opt.minimize(loss)
+                    exe = paddle.static.Executor()
+                    exe.run(paddle.static.default_startup_program())
+                    o = exe.run(
+                        paddle.static.default_main_program(),
+                        feed={"x": x, "q": np.array(q, dtype='float32')},
+                        fetch_list=['x@GRAD'],
+                    )[0]
+                    np.testing.assert_allclose(
+                        o,
+                        np.array(target_grad, dtype='float32'),
+                        rtol=1e-05,
+                        equal_nan=True,
+                    )
+            paddle.disable_static()
+
+        check_grad([1, 2, 3], 0.5, 0, [0, 1, 0])
+        check_grad(
+            [1, 2, 3, 4] * 2, [0.55, 0.7], 0, [0, 0, 0.95, 0, 0, 0.15, 0.9, 0]
+        )
+        check_grad(
+            [[1, 2, 3], [4, 5, 6]],
+            [0.3, 0.7],
+            1,
+            [[0.4, 1.2, 0.4], [0.4, 1.2, 0.4]],
+        )
+        # quantile
+        check_grad(
+            [1, float('nan'), 3], 0.5, 0, [0, 1, 0], [(paddle.quantile, None)]
+        )
+        # nanquantile
+        check_grad(
+            [1, float('nan'), 3],
+            0.5,
+            0,
+            [0.5, 0, 0.5],
+            [(paddle.nanquantile, None)],
+        )
+
 
 class TestMuitlpleQ(unittest.TestCase):
     """
@@ -150,6 +233,24 @@ def test_quantile_multiple_axis_keepdim(self):
         )
         np.testing.assert_allclose(paddle_res.numpy(), np_res, rtol=1e-05)
 
+    def test_quantile_with_tensor_input(self):
+        x = paddle.to_tensor(self.input_data)
+        paddle_res = paddle.quantile(
+            x, q=paddle.to_tensor([0.1, 0.2]), axis=[1, 2], keepdim=True
+        )
+        np_res = np.quantile(
+            self.input_data, q=[0.1, 0.2], axis=[1, 2], keepdims=True
+        )
+        np.testing.assert_allclose(paddle_res.numpy(), np_res, rtol=1e-05)
+
+    def test_quantile_with_zero_dim_tensor_input(self):
+        x = paddle.to_tensor(self.input_data)
+        paddle_res = paddle.quantile(
+            x, q=paddle.to_tensor(0.1), axis=[1, 2], keepdim=True
+        )
+        np_res = np.quantile(self.input_data, q=0.1, axis=[1, 2], keepdims=True)
+        np.testing.assert_allclose(paddle_res.numpy(), np_res, rtol=1e-05)
+
 
 class TestError(unittest.TestCase):
     """
@@ -210,6 +311,26 @@ def test_axis_value_error_2():
 
         self.assertRaises(ValueError, test_axis_value_error_2)
 
+        # Test error when q is not a 1-D tensor
+        def test_tensor_input_1():
+            paddle_res = paddle.quantile(
+                self.x, q=paddle.randn((2, 3)), axis=[1, -10]
+            )
+
+        self.assertRaises(ValueError, test_tensor_input_1)
+
+        def test_type_q():
+            paddle_res = paddle.quantile(self.x, q={1}, axis=[1, -10])
+
+        self.assertRaises(TypeError, test_type_q)
+
+        def test_interpolation():
+            paddle_res = paddle.quantile(
+                self.x, q={1}, axis=[1, -10], interpolation=' '
+            )
+
+        self.assertRaises(TypeError, test_interpolation)
+
 
 class TestQuantileRuntime(unittest.TestCase):
     """
@@ -267,11 +388,101 @@ def test_static(self):
                 )
                 np_res = res_func(np_input_data, q=0.5, axis=1)
                 np_res_fp64 = res_func(np_input_data_fp64, q=0.5, axis=1)
-                self.assertTrue(
-                    np.allclose(paddle_res, np_res)
-                    and np.allclose(paddle_res_fp64, np_res_fp64)
-                )
+                np.testing.assert_allclose(paddle_res, np_res, rtol=1e-05)
+                np.testing.assert_allclose(
+                    paddle_res_fp64, np_res_fp64, rtol=1e-05
+                )
+
+    def test_static_tensor(self):
+        paddle.enable_static()
+        for func, res_func in API_list:
+            s_p = paddle.static.Program()
+            m_p = paddle.static.Program()
+            with paddle.static.program_guard(m_p, s_p):
+                for device in self.devices:
+                    x = paddle.static.data(
+                        name="x",
+                        shape=self.input_data.shape,
+                        dtype=paddle.float32,
+                    )
+                    q = paddle.static.data(
+                        name="q", shape=(3,), dtype=paddle.float32
+                    )
+                    x_fp64 = paddle.static.data(
+                        name="x_fp64",
+                        shape=self.input_data.shape,
+                        dtype=paddle.float64,
+                    )
+
+                    results = func(x, q=q, axis=1)
+                    np_input_data = self.input_data.astype("float32")
+                    results_fp64 = func(x_fp64, q=q, axis=1)
+                    np_input_data_fp64 = self.input_data.astype("float64")
+                    q_data = np.array([0.5, 0.5, 0.5]).astype("float32")
+
+                    exe = paddle.static.Executor(device)
+                    paddle_res, paddle_res_fp64 = exe.run(
+                        paddle.static.default_main_program(),
+                        feed={
+                            "x": np_input_data,
+                            "x_fp64": np_input_data_fp64,
+                            "q": q_data,
+                        },
+                        fetch_list=[results, results_fp64],
+                    )
+                    np_res = res_func(np_input_data, q=[0.5, 0.5, 0.5], axis=1)
+                    np_res_fp64 = res_func(
+                        np_input_data_fp64, q=[0.5, 0.5, 0.5], axis=1
+                    )
+                    np.testing.assert_allclose(paddle_res, np_res, rtol=1e-05)
+                    np.testing.assert_allclose(
+                        paddle_res_fp64, np_res_fp64, rtol=1e-05
+                    )
+
+    def test_static_0d_tensor(self):
+        paddle.enable_static()
+        for func, res_func in API_list:
+            for device in self.devices:
+                s_p = paddle.static.Program()
+                m_p = paddle.static.Program()
+                with paddle.static.program_guard(m_p, s_p):
+                    x = paddle.static.data(
+                        name="x",
+                        shape=self.input_data.shape,
+                        dtype=paddle.float32,
+                    )
+                    q = paddle.static.data(
+                        name="q", shape=[], dtype=paddle.float32
+                    )
+                    x_fp64 = paddle.static.data(
+                        name="x_fp64",
+                        shape=self.input_data.shape,
+                        dtype=paddle.float64,
+                    )
+
+                    results = func(x, q=q, axis=1)
+                    np_input_data = self.input_data.astype("float32")
+                    results_fp64 = func(x_fp64, q=q, axis=1)
+                    np_input_data_fp64 = self.input_data.astype("float64")
+                    q_data = np.array(0.3).astype("float32")
+
+                    exe = paddle.static.Executor(device)
+                    paddle_res, paddle_res_fp64 = exe.run(
+                        paddle.static.default_main_program(),
+                        feed={
+                            "x": np_input_data,
+                            "x_fp64": np_input_data_fp64,
+                            "q": q_data,
+                        },
+                        fetch_list=[results, results_fp64],
+                    )
+                    np_res = res_func(np_input_data, q=0.3, axis=1)
+                    np_res_fp64 = res_func(np_input_data_fp64, q=0.3, axis=1)
+                    np.testing.assert_allclose(paddle_res, np_res, rtol=1e-05)
+                    np.testing.assert_allclose(
+                        paddle_res_fp64, np_res_fp64, rtol=1e-05
+                    )
 
 
 if __name__ == '__main__':
    unittest.main()
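
Reviewer note (not part of the patch): a minimal sketch of how the new `interpolation` argument added to `_compute_quantile`, `quantile`, and `nanquantile` is expected to behave, cross-checked against NumPy. It assumes this change is applied, that the code runs in dynamic (eager) mode, and that NumPy >= 1.22 is installed (where `np.quantile` takes `method=` rather than the older `interpolation=` keyword); it mirrors the comparison done in `test_nanquantile_interpolation` above.

```python
# Illustrative sketch only -- assumes the patched paddle.quantile and NumPy >= 1.22.
import numpy as np
import paddle

data = np.random.randn(2, 3, 4).astype("float64")
x = paddle.to_tensor(data)

# 'linear' interpolates between the two bracketing order statistics,
# 'lower'/'higher' take the floor/ceil element, 'nearest' rounds the index,
# and 'midpoint' averages the two bracketing elements.
for mode in ("linear", "lower", "higher", "nearest", "midpoint"):
    pd_out = paddle.quantile(x, q=0.35, axis=0, interpolation=mode)
    np_out = np.quantile(data, q=0.35, axis=0, method=mode)
    np.testing.assert_allclose(pd_out.numpy(), np_out, rtol=1e-5)
```

For 'nearest', indices that land exactly halfway between two elements depend on the rounding convention (`paddle.round` here vs. NumPy's), so an exact match at such q values is not guaranteed; the q and axis values above avoid that case.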
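Reviewer note (not part of the patch): the other user-visible change is that `q` may now be a 0-D or 1-D Tensor. The sketch below, again assuming the patch is applied and eager mode, mirrors `test_quantile_with_tensor_input` and `test_quantile_with_zero_dim_tensor_input`: a 0-D tensor behaves like a scalar `q`, while a 1-D tensor stacks one result per quantile along a new leading dimension, exactly like a Python list.

```python
# Illustrative sketch only -- assumes the patched paddle.quantile.
import numpy as np
import paddle

data = np.arange(24, dtype="float64").reshape(2, 3, 4)
x = paddle.to_tensor(data)

# 0-D tensor q: same result as passing the Python scalar 0.1.
out_scalar = paddle.quantile(
    x, q=paddle.to_tensor(0.1), axis=[1, 2], keepdim=True
)
np.testing.assert_allclose(
    out_scalar.numpy(),
    np.quantile(data, 0.1, axis=(1, 2), keepdims=True),
    rtol=1e-6,
)

# 1-D tensor q: one slice per quantile, stacked along dimension 0.
out_vector = paddle.quantile(
    x, q=paddle.to_tensor([0.1, 0.2]), axis=[1, 2], keepdim=True
)
np.testing.assert_allclose(
    out_vector.numpy(),
    np.quantile(data, [0.1, 0.2], axis=(1, 2), keepdims=True),
    rtol=1e-6,
)
```

The slightly loose `rtol` accounts for `q` being fed as float32 tensors while the reference quantiles are computed from exact float64 scalars.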