diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 298ad14f9e04b..2139605fb2048 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -64,6 +64,16 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
   }
 }
 
+// Used in MatrixRankTolInferMeta
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
+  auto x_vec = phi::vectorize(dim_x);
+  if (x_vec.size() == 2) {
+    return phi::make_ddim({1});
+  }
+  x_vec.erase(x_vec.end() - 2, x_vec.end());
+  return phi::make_ddim(x_vec);
+}
+
 }  // namespace detail
 
 void AllValueCompareInferMeta(const MetaTensor& x,
@@ -1465,6 +1475,47 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void MatrixRankTolInferMeta(const MetaTensor& x,
+                            const MetaTensor& atol_tensor,
+                            bool use_default_tol,
+                            bool hermitian,
+                            MetaTensor* out) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_GE(
+      dim_x.size(),
+      2,
+      phi::errors::InvalidArgument("The dims of input must be at least 2"));
+
+  if (hermitian) {
+    int rows = dim_x[dim_x.size() - 2];
+    int cols = dim_x[dim_x.size() - 1];
+    PADDLE_ENFORCE_EQ(rows,
+                      cols,
+                      phi::errors::InvalidArgument(
+                          "if hermitian == true, matrix should be n*n"));
+  }
+  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+  auto dim_tol = atol_tensor.dims();
+  if (dim_x_batch == dim_tol) {
+    out->set_dims(dim_x_batch);
+  } else {
+    int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
+    int axis = std::abs(dim_x_batch.size() - dim_tol.size());
+    std::vector<int> x_batch_dims_array(max_dim);
+    std::vector<int> tol_dims_array(max_dim);
+    std::vector<int> out_dims_array(max_dim);
+    phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
+                                       dim_tol,
+                                       x_batch_dims_array.data(),
+                                       tol_dims_array.data(),
+                                       out_dims_array.data(),
+                                       max_dim,
+                                       axis);
+    out->set_dims(phi::make_ddim(out_dims_array));
+  }
+  out->share_lod(x);
+}
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
   auto dim_x = x.dims();
   auto dim_vec = vec.dims();
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 70c3c9dfe849d..192fa214c905f 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -218,6 +218,12 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
                                 int y_num_col_dims,
                                 MetaTensor* out);
 
+void MatrixRankTolInferMeta(const MetaTensor& x,
+                            const MetaTensor& atol_tensor,
+                            bool use_default_tol,
+                            bool hermitian,
+                            MetaTensor* out);
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
 
 void PReluInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index e0ea637074c20..23511de39eef4 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -31,6 +31,18 @@ limitations under the License. */
 
 namespace phi {
 
+namespace detail {
+// Used in MatrixRankInferMeta
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
+  auto x_vec = phi::vectorize(dim_x);
+  if (x_vec.size() == 2) {
+    return phi::make_ddim({1});
+  }
+  x_vec.erase(x_vec.end() - 2, x_vec.end());
+  return phi::make_ddim(x_vec);
+}
+}  // namespace detail
+
 void ArgMinMaxInferMeta(const MetaTensor& x,
                         int64_t axis,
                         bool keepdims,
@@ -901,6 +913,29 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
   out->set_dtype(x.dtype());
 }
 
+void MatrixRankInferMeta(const MetaTensor& x,
+                         bool use_default_tol,
+                         bool hermitian,
+                         MetaTensor* out) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_GE(
+      dim_x.size(),
+      2,
+      phi::errors::InvalidArgument("The dims of input must be at least 2"));
+
+  if (hermitian) {
+    int rows = dim_x[dim_x.size() - 2];
+    int cols = dim_x[dim_x.size() - 1];
+    PADDLE_ENFORCE_EQ(rows,
+                      cols,
+                      phi::errors::InvalidArgument(
+                          "if hermitian == true, matrix should be n*n"));
+  }
+  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+  out->set_dims(dim_x_batch);
+  out->share_lod(x);
+}
+
 void MaxOutInferMeta(const MetaTensor& x,
                      int groups,
                      int axis,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 5106c6f448733..ad2a9b5311c58 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -142,6 +142,11 @@ void LogsumexpInferMeta(const MetaTensor& input,
 
 void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
 
+void MatrixRankInferMeta(const MetaTensor& x,
+                         bool use_default_tol,
+                         bool hermitian,
+                         MetaTensor* out);
+
 void MaxOutInferMeta(const MetaTensor& x,
                      int groups,
                      int axis,
diff --git a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
index d0b84a0d7e108..b13b346261762 100644
--- a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
@@ -30,8 +30,13 @@
 np.random.seed(SEED)
 
 
+def matrix_rank_wraper(x, tol=None, use_default_tol=True, hermitian=False):
+    return paddle.linalg.matrix_rank(x, tol, hermitian)
+
+
 class TestMatrixRankOP(OpTest):
     def setUp(self):
+        self.python_api = matrix_rank_wraper
         self.op_type = "matrix_rank"
         self.init_data()
         self.inputs = {'X': self.x}
@@ -44,7 +49,7 @@ def setUp(self):
         self.outputs = {'Out': self.out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def init_data(self):
         self.x = np.eye(3, dtype=np.float32)
@@ -110,6 +115,28 @@ def init_data(self):
                                          self.hermitian)
 
 
+class TestMatrixRankOP6(TestMatrixRankOP):
+    def init_data(self):
+        self.x = np.random.rand(3, 4, 5, 6).astype(np.float32)
+        self.tol_tensor = None
+        self.tol = None
+        self.use_default_tol = False
+        self.hermitian = False
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)
+
+
+class TestMatrixRankOP7(TestMatrixRankOP):
+    def init_data(self):
+        self.x = np.eye(200, dtype=np.float64)
+        self.tol_tensor = np.random.random([200, 200]).astype(self.x.dtype)
+        self.tol = None
+        self.use_default_tol = True
+        self.hermitian = True
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)
+
+
 class TestMatrixRankAPI(unittest.TestCase):
     def test_dygraph(self):
         paddle.disable_static()
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index c4814bd2b2f9c..b8aeaa362b40c 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -1284,8 +1284,26 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
             # [1, 1, 1, 1]]
     """
+    if in_dygraph_mode():
+        if isinstance(tol, Variable):
+            if tol.dtype != x.dtype:
+                tol_tensor = cast(tol, x.dtype)
+            else:
+                tol_tensor = tol
+            use_default_tol = False
+            return _C_ops.final_state_matrix_rank_tol(
+                x, tol_tensor, use_default_tol, hermitian)
 
-    if paddle.in_dynamic_mode():
+        if tol is None:
+            tol_attr = 0.0
+            use_default_tol = True
+        else:
+            tol_attr = float(tol)
+            use_default_tol = False
+        return _C_ops.final_state_matrix_rank(x, tol_attr, use_default_tol,
+                                              hermitian)
+
+    if _in_legacy_dygraph():
         if tol is None:
             tol_tensor = None
             tol_attr = 0.0
             use_default_tol = True
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 93d14b1744e93..c55c8224a9f64 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -1090,6 +1090,23 @@
     func : matrix_power
   backward : matrix_power_grad
 
+- api : matrix_rank
+  args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankInferMeta
+    param : [x, use_default_tol, hermitian]
+  kernel :
+    func : matrix_rank
+
+- api : matrix_rank_tol
+  args : (Tensor x, Tensor atol_tensor, bool use_default_tol=true, bool hermitian=false)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankTolInferMeta
+  kernel :
+    func : matrix_rank_tol
+
 - api : max
   args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
   output : Tensor(out)
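
Note (reviewer sketch, not part of the patch): detail::CheckAndGetOutputDim plus the
GetBroadcastDimsArrays branch above fully determine the output shape of both new ops:
drop the trailing two matrix dims (a plain 2-D input collapses to shape [1]), then, for
matrix_rank_tol, broadcast the remaining batch dims against the tol tensor's dims. A
minimal numpy mirror of that rule, assuming numpy broadcasting matches phi's semantics
(rank_out_shape is a hypothetical helper name, not in the patch):

    import numpy as np

    def rank_out_shape(x_shape, tol_shape=None):
        # detail::CheckAndGetOutputDim: strip the two matrix dims;
        # a plain 2-D input yields shape [1].
        batch = [1] if len(x_shape) == 2 else list(x_shape[:-2])
        if tol_shape is None:
            return batch  # MatrixRankInferMeta: batch dims only
        # MatrixRankTolInferMeta: elementwise broadcast of batch and tol dims.
        return list(np.broadcast_shapes(tuple(batch), tuple(tol_shape)))

    assert rank_out_shape([3, 4]) == [1]
    assert rank_out_shape([3, 4, 5, 6]) == [3, 4]
    assert rank_out_shape([200, 200], [200, 200]) == [200, 200]  # cf. TestMatrixRankOP7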
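
Note (usage sketch, not part of the patch): how the rewritten matrix_rank entry point
dispatches under eager mode — tol=None or a float float goes through final_state_matrix_rank,
while a Tensor tol goes through final_state_matrix_rank_tol. The test matrix and the
resulting ranks below are illustrative only:

    import paddle

    # Diagonal matrix with one zero singular value -> rank 3.
    x = paddle.diag(paddle.to_tensor([1.0, 2.0, 3.0, 0.0]))

    paddle.linalg.matrix_rank(x)                              # tol=None: default tolerance path
    paddle.linalg.matrix_rank(x, tol=1e-6)                    # float tol: matrix_rank kernel
    paddle.linalg.matrix_rank(x, tol=paddle.to_tensor(1e-6))  # Tensor tol: matrix_rank_tol kernel
    paddle.linalg.matrix_rank(x, hermitian=True)              # requires a square (n*n) input

Each call should return a tensor holding 3 for this input; a non-square x with
hermitian=True would trip the n*n check added in the infermeta functions.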