From d704e9a5e2e9c7323aa6917b5668f263bcf53daf Mon Sep 17 00:00:00 2001 From: zhoujianqian <15205085056@163.com> Date: Sat, 19 Mar 2022 09:50:03 +0000 Subject: [PATCH 1/5] modify matrix_rank --- paddle/phi/infermeta/binary.cc | 52 ++++++++++++++++++++++++++++++++++ paddle/phi/infermeta/binary.h | 6 ++++ paddle/phi/infermeta/unary.cc | 34 ++++++++++++++++++++++ paddle/phi/infermeta/unary.h | 5 ++++ 4 files changed, 97 insertions(+) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 521f2a9bf0648..323266cb3d914 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1002,6 +1002,58 @@ void ValueCompareInferMeta(const MetaTensor& x, out->set_dtype(DataType::BOOL); } +namespace detail { +static DDim CheckAndGetOutputDim(const DDim& dim_x) { + auto x_vec = phi::vectorize(dim_x); + if (x_vec.size() == 2) { + return phi::make_ddim({1}); + } + x_vec.erase(x_vec.end() - 2, x_vec.end()); + return phi::make_ddim(x_vec); +} +} // namespace detail + +void MatrixRankTolInferMeta(const MetaTensor& x, + const MetaTensor& atol_tensor, + bool use_default_tol, + bool hermitian, + MetaTensor* out) { + auto dim_x = x.dims(); + PADDLE_ENFORCE_GE( + dim_x.size(), + 2, + phi::errors::InvalidArgument("The dims of input must be greater than 2")); + + if (hermitian) { + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + PADDLE_ENFORCE_EQ(rows, + cols, + phi::errors::InvalidArgument( + "if hermitian == true, matrix should be n*n")); + } + DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x); + auto dim_tol = atol_tensor.dims(); + if (dim_x_batch == dim_tol) { + out->set_dims(dim_x_batch); + } else { + int max_dim = std::max(dim_x_batch.size(), dim_tol.size()); + int axis = std::abs(dim_x_batch.size() - dim_tol.size()); + std::vector x_batch_dims_array(max_dim); + std::vector tol_dims_array(max_dim); + std::vector out_dims_array(max_dim); + phi::funcs::GetBroadcastDimsArrays(dim_x_batch, + dim_tol, + x_batch_dims_array.data(), + tol_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + out->set_dims(phi::make_ddim(out_dims_array)); + } + out->share_lod(x); +} + } // namespace phi PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 9e1a35640ad29..ad86350bbd9a0 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -165,4 +165,10 @@ void ValueCompareInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void MatrixRankTolInferMeta(const MetaTensor& x, + const MetaTensor& atol_tensor, + bool use_default_tol, + bool hermitian, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index f81f4a1b7c739..50e50911edfaf 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1752,6 +1752,40 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { out->set_dtype(DataType::INT64); } +namespace detail { +static DDim CheckAndGetOutputDim(const DDim& dim_x) { + auto x_vec = phi::vectorize(dim_x); + if (x_vec.size() == 2) { + return phi::make_ddim({1}); + } + x_vec.erase(x_vec.end() - 2, x_vec.end()); + return phi::make_ddim(x_vec); +} +} // namespace detail + +void MatrixRankInferMeta(const MetaTensor& x, + bool use_default_tol, + bool hermitian, + MetaTensor* out) { + auto dim_x = x.dims(); + PADDLE_ENFORCE_GE( + dim_x.size(), + 2, + phi::errors::InvalidArgument("The dims of input must be greater than 2")); + + if (hermitian) { + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + PADDLE_ENFORCE_EQ(rows, + cols, + phi::errors::InvalidArgument( + "if hermitian == true, matrix should be n*n")); + } + DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x); + out->set_dims(dim_x_batch); + out->share_lod(x); +} + } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index eb894003e5354..154dc561c5d75 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -255,4 +255,9 @@ void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out); void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); +void MatrixRankInferMeta(const MetaTensor& x, + bool use_default_tol, + bool hermitian, + MetaTensor* out); + } // namespace phi From b44ee6ef14ebcbb73bc285bf8f6cc58c7a8b66fc Mon Sep 17 00:00:00 2001 From: zhoujianqian <15205085056@163.com> Date: Wed, 30 Mar 2022 10:22:38 +0000 Subject: [PATCH 2/5] add matrix_rank shape --- paddle/phi/infermeta/binary.cc | 20 +++++++++----------- paddle/phi/infermeta/unary.cc | 2 ++ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 2989da2d91c0d..1b28c3fb44d10 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -63,6 +63,15 @@ static void BinarySameInputDimsCheck(const MetaTensor& x, } } +static DDim CheckAndGetOutputDim(const DDim& dim_x) { + auto x_vec = phi::vectorize(dim_x); + if (x_vec.size() == 2) { + return phi::make_ddim({1}); + } + x_vec.erase(x_vec.end() - 2, x_vec.end()); + return phi::make_ddim(x_vec); +} + } // namespace detail void AllValueCompareInferMeta(const MetaTensor& x, @@ -1789,17 +1798,6 @@ void ValueCompareInferMeta(const MetaTensor& x, out->set_dtype(DataType::BOOL); } -namespace detail { -static DDim CheckAndGetOutputDim(const DDim& dim_x) { - auto x_vec = phi::vectorize(dim_x); - if (x_vec.size() == 2) { - return phi::make_ddim({1}); - } - x_vec.erase(x_vec.end() - 2, x_vec.end()); - return phi::make_ddim(x_vec); -} -} // namespace detail - void MatrixRankTolInferMeta(const MetaTensor& x, const MetaTensor& atol_tensor, bool use_default_tol, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index dabba36d451c2..d9f1db113385b 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2551,6 +2551,7 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { } namespace detail { + static DDim CheckAndGetOutputDim(const DDim& dim_x) { auto x_vec = phi::vectorize(dim_x); if (x_vec.size() == 2) { @@ -2559,6 +2560,7 @@ static DDim CheckAndGetOutputDim(const DDim& dim_x) { x_vec.erase(x_vec.end() - 2, x_vec.end()); return phi::make_ddim(x_vec); } + } // namespace detail void MatrixRankInferMeta(const MetaTensor& x, From 8618de1fe6fd98a7a73d00f1781a39d747997c1d Mon Sep 17 00:00:00 2001 From: zhoujianqian <15205085056@163.com> Date: Wed, 30 Mar 2022 10:24:49 +0000 Subject: [PATCH 3/5] add matrix_rank shape --- paddle/phi/infermeta/unary.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d9f1db113385b..dabba36d451c2 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2551,7 +2551,6 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { } namespace detail { - static DDim CheckAndGetOutputDim(const DDim& dim_x) { auto x_vec = phi::vectorize(dim_x); if (x_vec.size() == 2) { @@ -2560,7 +2559,6 @@ static DDim CheckAndGetOutputDim(const DDim& dim_x) { x_vec.erase(x_vec.end() - 2, x_vec.end()); return phi::make_ddim(x_vec); } - } // namespace detail void MatrixRankInferMeta(const MetaTensor& x, From 60c3f4b24bcd4355092ef99e6f1c1d6f5dbc408c Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 6 Apr 2022 19:35:15 +0800 Subject: [PATCH 4/5] Add yaml for matrix_rank OP --- paddle/phi/infermeta/binary.cc | 83 ++++++++++--------- paddle/phi/infermeta/binary.h | 12 +-- paddle/phi/infermeta/unary.cc | 69 +++++++-------- paddle/phi/infermeta/unary.h | 10 +-- .../tests/unittests/test_matrix_rank_op.py | 7 +- python/paddle/tensor/linalg.py | 20 ++++- python/paddle/utils/code_gen/api.yaml | 17 ++++ 7 files changed, 130 insertions(+), 88 deletions(-) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 748a0904490ae..2139605fb2048 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -64,6 +64,7 @@ static void BinarySameInputDimsCheck(const MetaTensor& x, } } +// Used in MatrixRankTolInferMeta static DDim CheckAndGetOutputDim(const DDim& dim_x) { auto x_vec = phi::vectorize(dim_x); if (x_vec.size() == 2) { @@ -1474,6 +1475,47 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x, out->share_lod(x); } +void MatrixRankTolInferMeta(const MetaTensor& x, + const MetaTensor& atol_tensor, + bool use_default_tol, + bool hermitian, + MetaTensor* out) { + auto dim_x = x.dims(); + PADDLE_ENFORCE_GE( + dim_x.size(), + 2, + phi::errors::InvalidArgument("The dims of input must be greater than 2")); + + if (hermitian) { + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + PADDLE_ENFORCE_EQ(rows, + cols, + phi::errors::InvalidArgument( + "if hermitian == true, matrix should be n*n")); + } + DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x); + auto dim_tol = atol_tensor.dims(); + if (dim_x_batch == dim_tol) { + out->set_dims(dim_x_batch); + } else { + int max_dim = std::max(dim_x_batch.size(), dim_tol.size()); + int axis = std::abs(dim_x_batch.size() - dim_tol.size()); + std::vector x_batch_dims_array(max_dim); + std::vector tol_dims_array(max_dim); + std::vector out_dims_array(max_dim); + phi::funcs::GetBroadcastDimsArrays(dim_x_batch, + dim_tol, + x_batch_dims_array.data(), + tol_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + out->set_dims(phi::make_ddim(out_dims_array)); + } + out->share_lod(x); +} + void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { auto dim_x = x.dims(); auto dim_vec = vec.dims(); @@ -1921,47 +1963,6 @@ void ValueCompareInferMeta(const MetaTensor& x, out->set_dtype(DataType::BOOL); } -void MatrixRankTolInferMeta(const MetaTensor& x, - const MetaTensor& atol_tensor, - bool use_default_tol, - bool hermitian, - MetaTensor* out) { - auto dim_x = x.dims(); - PADDLE_ENFORCE_GE( - dim_x.size(), - 2, - phi::errors::InvalidArgument("The dims of input must be greater than 2")); - - if (hermitian) { - int rows = dim_x[dim_x.size() - 2]; - int cols = dim_x[dim_x.size() - 1]; - PADDLE_ENFORCE_EQ(rows, - cols, - phi::errors::InvalidArgument( - "if hermitian == true, matrix should be n*n")); - } - DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x); - auto dim_tol = atol_tensor.dims(); - if (dim_x_batch == dim_tol) { - out->set_dims(dim_x_batch); - } else { - int max_dim = std::max(dim_x_batch.size(), dim_tol.size()); - int axis = std::abs(dim_x_batch.size() - dim_tol.size()); - std::vector x_batch_dims_array(max_dim); - std::vector tol_dims_array(max_dim); - std::vector out_dims_array(max_dim); - phi::funcs::GetBroadcastDimsArrays(dim_x_batch, - dim_tol, - x_batch_dims_array.data(), - tol_dims_array.data(), - out_dims_array.data(), - max_dim, - axis); - out->set_dims(phi::make_ddim(out_dims_array)); - } - out->share_lod(x); -} - } // namespace phi PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index cbec4f79529c6..192fa214c905f 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -218,6 +218,12 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x, int y_num_col_dims, MetaTensor* out); +void MatrixRankTolInferMeta(const MetaTensor& x, + const MetaTensor& atol_tensor, + bool use_default_tol, + bool hermitian, + MetaTensor* out); + void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out); void PReluInferMeta(const MetaTensor& x, @@ -278,10 +284,4 @@ void ValueCompareInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); -void MatrixRankTolInferMeta(const MetaTensor& x, - const MetaTensor& atol_tensor, - bool use_default_tol, - bool hermitian, - DenseTensor* out); - } // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 1a00b5e5a7f5d..23511de39eef4 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -31,6 +31,18 @@ limitations under the License. */ namespace phi { +namespace detail { +// Used in MatrixRankInferMeta +static DDim CheckAndGetOutputDim(const DDim& dim_x) { + auto x_vec = phi::vectorize(dim_x); + if (x_vec.size() == 2) { + return phi::make_ddim({1}); + } + x_vec.erase(x_vec.end() - 2, x_vec.end()); + return phi::make_ddim(x_vec); +} +} // namespace detail + void ArgMinMaxInferMeta(const MetaTensor& x, int64_t axis, bool keepdims, @@ -901,6 +913,29 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) { out->set_dtype(x.dtype()); } +void MatrixRankInferMeta(const MetaTensor& x, + bool use_default_tol, + bool hermitian, + MetaTensor* out) { + auto dim_x = x.dims(); + PADDLE_ENFORCE_GE( + dim_x.size(), + 2, + phi::errors::InvalidArgument("The dims of input must be greater than 2")); + + if (hermitian) { + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + PADDLE_ENFORCE_EQ(rows, + cols, + phi::errors::InvalidArgument( + "if hermitian == true, matrix should be n*n")); + } + DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x); + out->set_dims(dim_x_batch); + out->share_lod(x); +} + void MaxOutInferMeta(const MetaTensor& x, int groups, int axis, @@ -2863,40 +2898,6 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { out->set_dtype(DataType::INT64); } -namespace detail { -static DDim CheckAndGetOutputDim(const DDim& dim_x) { - auto x_vec = phi::vectorize(dim_x); - if (x_vec.size() == 2) { - return phi::make_ddim({1}); - } - x_vec.erase(x_vec.end() - 2, x_vec.end()); - return phi::make_ddim(x_vec); -} -} // namespace detail - -void MatrixRankInferMeta(const MetaTensor& x, - bool use_default_tol, - bool hermitian, - MetaTensor* out) { - auto dim_x = x.dims(); - PADDLE_ENFORCE_GE( - dim_x.size(), - 2, - phi::errors::InvalidArgument("The dims of input must be greater than 2")); - - if (hermitian) { - int rows = dim_x[dim_x.size() - 2]; - int cols = dim_x[dim_x.size() - 1]; - PADDLE_ENFORCE_EQ(rows, - cols, - phi::errors::InvalidArgument( - "if hermitian == true, matrix should be n*n")); - } - DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x); - out->set_dims(dim_x_batch); - out->share_lod(x); -} - } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 4b06a73df75eb..ad2a9b5311c58 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -142,6 +142,11 @@ void LogsumexpInferMeta(const MetaTensor& input, void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out); +void MatrixRankInferMeta(const MetaTensor& x, + bool use_default_tol, + bool hermitian, + MetaTensor* out); + void MaxOutInferMeta(const MetaTensor& x, int groups, int axis, @@ -426,9 +431,4 @@ void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out); void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); -void MatrixRankInferMeta(const MetaTensor& x, - bool use_default_tol, - bool hermitian, - MetaTensor* out); - } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py index d0b84a0d7e108..2a622f15dedab 100644 --- a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py +++ b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py @@ -30,8 +30,13 @@ np.random.seed(SEED) +def matrix_rank_wraper(x, tol=None, use_default_tol=True, hermitian=False): + return paddle.linalg.matrix_rank(x, tol, hermitian) + + class TestMatrixRankOP(OpTest): def setUp(self): + self.python_api = matrix_rank_wraper self.op_type = "matrix_rank" self.init_data() self.inputs = {'X': self.x} @@ -44,7 +49,7 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def init_data(self): self.x = np.eye(3, dtype=np.float32) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index c4814bd2b2f9c..b8aeaa362b40c 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1284,8 +1284,26 @@ def matrix_rank(x, tol=None, hermitian=False, name=None): # [1, 1, 1, 1]] """ + if in_dygraph_mode(): + if isinstance(tol, Variable): + if tol.dtype != x.dtype: + tol_tensor = cast(tol, x.dtype) + else: + tol_tensor = tol + use_default_tol = False + return _C_ops.final_state_matrix_rank_tol( + x, tol_tensor, use_default_tol, hermitian) - if paddle.in_dynamic_mode(): + if tol is None: + tol_attr = 0.0 + use_default_tol = True + else: + tol_attr = float(tol) + use_default_tol = False + return _C_ops.final_state_matrix_rank(x, tol_attr, use_default_tol, + hermitian) + + if _in_legacy_dygraph(): if tol is None: tol_tensor = None tol_attr = 0.0 diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 93d14b1744e93..c55c8224a9f64 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -1090,6 +1090,23 @@ func : matrix_power backward : matrix_power_grad +- api : matrix_rank + args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false) + output : Tensor(out) + infer_meta : + func : MatrixRankInferMeta + param : [x, use_default_tol, hermitian] + kernel : + func : matrix_rank + +- api : matrix_rank_tol + args : (Tensor x, Tensor atol_tensor, bool use_default_tol=true, bool hermitian=false) + output : Tensor(out) + infer_meta : + func : MatrixRankTolInferMeta + kernel : + func : matrix_rank_tol + - api : max args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) output : Tensor(out) From d631a63b87a9d830d92773af38d81c3d064d26e2 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Thu, 7 Apr 2022 10:54:17 +0800 Subject: [PATCH 5/5] Add UT --- .../tests/unittests/test_matrix_rank_op.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py index 2a622f15dedab..b13b346261762 100644 --- a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py +++ b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py @@ -115,6 +115,28 @@ def init_data(self): self.hermitian) +class TestMatrixRankOP6(TestMatrixRankOP): + def init_data(self): + self.x = np.random.rand(3, 4, 5, 6).astype(np.float32) + self.tol_tensor = None + self.tol = None + self.use_default_tol = False + self.hermitian = False + self.out = np.linalg.matrix_rank(self.x, self.tol_tensor, + self.hermitian) + + +class TestMatrixRankOP7(TestMatrixRankOP): + def init_data(self): + self.x = np.eye(200, dtype=np.float64) + self.tol_tensor = np.random.random([200, 200]).astype(self.x.dtype) + self.tol = None + self.use_default_tol = True + self.hermitian = True + self.out = np.linalg.matrix_rank(self.x, self.tol_tensor, + self.hermitian) + + class TestMatrixRankAPI(unittest.TestCase): def test_dygraph(self): paddle.disable_static()