Skip to content

Commit

Permalink
【Hackathon 4 No.9】Add pca_lowrank API to Paddle (#53743)
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick-Star125 authored Jun 13, 2023
1 parent 7309f8a commit 4ebb476
Show file tree
Hide file tree
Showing 7 changed files with 648 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .tensor.linalg import matrix_rank # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import pca_lowrank # noqa: F401
from .tensor.linalg import pinv # noqa: F401
from .tensor.linalg import qr # noqa: F401
from .tensor.linalg import slogdet # noqa: F401
Expand All @@ -50,6 +51,7 @@
'matrix_rank',
'svd',
'qr',
'pca_lowrank',
'lu',
'lu_unpack',
'matrix_power',
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .unary import log1p
from .unary import abs
from .unary import pow
from .unary import pca_lowrank
from .unary import cast
from .unary import neg
from .unary import coalesce
Expand Down Expand Up @@ -69,6 +70,7 @@
'log1p',
'abs',
'pow',
'pca_lowrank',
'cast',
'neg',
'deg2rad',
Expand Down
203 changes: 203 additions & 0 deletions python/paddle/sparse/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np

import paddle
from paddle import _C_ops, in_dynamic_mode
from paddle.common_ops_import import Variable
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
Expand Down Expand Up @@ -920,3 +921,205 @@ def slice(x, axes, starts, ends, name=None):
type=op_type, inputs={'x': x}, outputs={'out': out}, attrs=attrs
)
return out


def pca_lowrank(x, q=None, center=True, niter=2, name=None):
    r"""
    Performs linear Principal Component Analysis (PCA) on a sparse matrix.

    Let :math:`X` be the input matrix or a batch of input matrices, the output should satisfy:

    .. math::

        X = U * diag(S) * V^{T}

    Args:
        x (Tensor): The input tensor. Its shape should be `[N, M]`,
            N and M can be arbitrary positive number.
            The data type of x should be float32 or float64.
        q (int, optional): a slightly overestimated rank of :math:`X`.
            Default value is :math:`q=min(6,N,M)`.
        center (bool, optional): if True, center the input tensor.
            Default value is True.
        niter (int, optional): number of subspace iterations to conduct
            in the randomized range finder; must be a non-negative integer.
            Default value is 2.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        - Tensor U, is N x q matrix.
        - Tensor S, is a vector with length q.
        - Tensor V, is M x q matrix.

        tuple (U, S, V): which is the nearly optimal approximation of a singular value decomposition of a centered matrix :math:`X`.

    Examples:
        .. code-block:: python

            import paddle

            format = "coo"
            dense_x = paddle.randn((5, 5), dtype='float64')

            if format == "coo":
                sparse_x = dense_x.to_sparse_coo(len(dense_x.shape))
            else:
                sparse_x = dense_x.to_sparse_csr()

            print("sparse.pca_lowrank API only support CUDA 11.x")
            U, S, V = None, None, None
            # use code below when your device CUDA version >= 11.0
            # U, S, V = paddle.sparse.pca_lowrank(sparse_x)

            print(U)
            # Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
            #        [[ 0.02206024,  0.53170082, -0.22392168, -0.48450657,  0.65720625],
            #         [ 0.02206024,  0.53170082, -0.22392168, -0.32690402, -0.74819812],
            #         [ 0.02206024,  0.53170082, -0.22392168,  0.81141059,  0.09099187],
            #         [ 0.15045792,  0.37840027,  0.91333217, -0.00000000,  0.00000000],
            #         [ 0.98787775, -0.09325209, -0.12410317, -0.00000000, -0.00000000]])

            print(S)
            # Tensor(shape=[5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
            #        [2.28621761, 0.93618564, 0.53234942, 0.00000000, 0.00000000])

            print(V)
            # Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
            #        [[ 0.26828910, -0.57116436, -0.26548201,  0.67342660, -0.27894114],
            #         [-0.19592125, -0.31629129,  0.02001645, -0.50484498, -0.77865626],
            #         [-0.82913017, -0.09391036,  0.37975388,  0.39938099, -0.00241046],
            #         [-0.41163516,  0.27490410, -0.86666276,  0.03382656, -0.05230341],
            #         [ 0.18092947,  0.69952818,  0.18385126,  0.36190987, -0.55959343]])
    """

    def get_floating_dtype(x):
        # Keep the dtype if it is already floating; otherwise fall back to
        # float32 (NOTE(review): x itself is not cast — callers are expected
        # to pass float32/float64 input as the docstring states).
        dtype = x.dtype
        if dtype in (paddle.float16, paddle.float32, paddle.float64):
            return dtype
        return paddle.float32

    def conjugate(x):
        # Complex conjugate; identity for real tensors.
        if x.is_complex():
            return x.conj()
        return x

    def transpose(x):
        # Swap the last two axes, dispatching to the sparse transpose
        # when the operand is a sparse tensor.
        shape = x.shape
        perm = list(range(0, len(shape)))
        perm = perm[:-2] + [perm[-1]] + [perm[-2]]
        if x.is_sparse():
            return paddle.sparse.transpose(x, perm)
        return paddle.transpose(x, perm)

    def transjugate(x):
        # Conjugate transpose (Hermitian transpose).
        return conjugate(transpose(x))

    def get_approximate_basis(x, q, niter=2, M=None):
        """Return a q-column orthonormal basis Q approximating the range of
        (x - M), obtained via randomized subspace iteration with `niter`
        power iterations (QR re-orthonormalization at each step)."""
        niter = 2 if niter is None else niter
        n = x.shape[-1]
        qr = paddle.linalg.qr

        # Random Gaussian test matrix used to probe the range of x.
        R = paddle.randn((n, q), dtype=x.dtype)

        A_t = transpose(x)
        A_H = conjugate(A_t)
        if M is None:
            Q = qr(paddle.sparse.matmul(x, R))[0]
            for _ in range(niter):
                Q = qr(paddle.sparse.matmul(A_H, Q))[0]
                Q = qr(paddle.sparse.matmul(x, Q))[0]
        else:
            # M is a dense correction term (the column-mean outer product),
            # so products with M use dense matmul.
            M_H = transjugate(M)
            Q = qr(paddle.sparse.matmul(x, R) - paddle.matmul(M, R))[0]
            for _ in range(niter):
                Q = qr(paddle.sparse.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0]
                Q = qr(paddle.sparse.matmul(x, Q) - paddle.matmul(M, Q))[0]

        return Q

    def svd_lowrank(x, q=6, niter=2, M=None):
        """Randomized truncated SVD of (x - M): project x onto an
        approximate range basis, then run an exact SVD on the small
        projected matrix."""
        q = 6 if q is None else q
        m, n = x.shape[-2:]
        if M is None:
            M_t = None
        else:
            M_t = transpose(M)
        A_t = transpose(x)

        # Work on whichever orientation yields the smaller projected matrix
        # so the dense SVD below stays cheap.
        if m < n or n > q:
            Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
            Q_c = conjugate(Q)
            if M is None:
                B_t = paddle.sparse.matmul(x, Q_c)
            else:
                B_t = paddle.sparse.matmul(x, Q_c) - paddle.matmul(M, Q_c)
            assert B_t.shape[-2] == m, (B_t.shape, m)
            assert B_t.shape[-1] == q, (B_t.shape, q)
            assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
            U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
            V = transjugate(Vh)
            # Lift V back from the projected space to the original space.
            V = Q.matmul(V)
        else:
            Q = get_approximate_basis(x, q, niter=niter, M=M)
            Q_c = conjugate(Q)
            if M is None:
                B = paddle.sparse.matmul(A_t, Q_c)
            else:
                B = paddle.sparse.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c)
            B_t = transpose(B)
            assert B_t.shape[-2] == q, (B_t.shape, q)
            assert B_t.shape[-1] == n, (B_t.shape, n)
            assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
            U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
            V = transjugate(Vh)
            # Lift U back from the projected space to the original space.
            U = Q.matmul(U)

        return U, S, V

    # ---- input validation -------------------------------------------------
    if not paddle.is_tensor(x):
        raise ValueError(f'Input must be tensor, but got {type(x)}')

    if not x.is_sparse():
        raise ValueError('Input must be sparse, but got dense')

    # paddle.sparse matmul kernels require CUDA 11+.
    cuda_version = paddle.version.cuda()
    if (
        cuda_version is None
        or cuda_version == 'False'
        or int(cuda_version.split('.')[0]) < 11
    ):
        raise ValueError('sparse.pca_lowrank API only support CUDA 11.x')

    (m, n) = x.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            'q(={}) must be non-negative integer'
            ' and not greater than min(m, n)={}'.format(q, min(m, n))
        )
    if not (niter >= 0):
        raise ValueError(f'niter(={niter}) must be non-negative integer')

    dtype = get_floating_dtype(x)

    # Without centering, PCA reduces to a plain randomized SVD.
    if not center:
        return svd_lowrank(x, q, niter=niter, M=None)

    if len(x.shape) != 2:
        raise ValueError('input is expected to be 2-dimensional tensor')
    # ---- centering: build M = ones_m x column_means without densifying x --
    # TODO: complement sparse_csr_tensor test
    # when sparse.sum with axis(-2) is implemented
    s_sum = paddle.sparse.sum(x, axis=-2)
    s_val = s_sum.values() / m  # per-column means (only stored columns)
    c = paddle.sparse.sparse_coo_tensor(
        s_sum.indices(), s_val, dtype=s_sum.dtype, place=s_sum.place
    )
    # Reshape the 1-D sparse mean vector into an (n, 1) sparse matrix.
    column_indices = c.indices()[0]
    indices = paddle.zeros((2, len(column_indices)), dtype=column_indices.dtype)
    indices[0] = column_indices
    C_t = paddle.sparse.sparse_coo_tensor(
        indices, c.values(), (n, 1), dtype=dtype, place=x.place
    )

    # M = transpose(means @ ones(1, m)) -> an (m, n) dense rank-1 matrix of
    # column means, subtracted implicitly inside svd_lowrank.
    ones_m1_t = paddle.ones(x.shape[:-2] + [1, m], dtype=dtype)
    M = transpose(paddle.matmul(C_t.to_dense(), ones_m1_t))
    return svd_lowrank(x, q, niter=niter, M=M)
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .linalg import cov # noqa: F401
from .linalg import corrcoef # noqa: F401
from .linalg import norm # noqa: F401
from .linalg import pca_lowrank # noqa: F401
from .linalg import cond # noqa: F401
from .linalg import transpose # noqa: F401
from .linalg import lstsq # noqa: F401
Expand Down Expand Up @@ -333,6 +334,7 @@
'mv',
'matrix_power',
'qr',
'pca_lowrank',
'eigvals',
'eigvalsh',
'abs',
Expand Down
Loading

0 comments on commit 4ebb476

Please sign in to comment.