From 22c7ef7bee1a475fb6d5ce927176387560b79e13 Mon Sep 17 00:00:00 2001 From: dw_sjtu <46704444+sjtuWangDing@users.noreply.github.com> Date: Sat, 18 Jan 2020 06:37:29 +0800 Subject: [PATCH] * update source files (#17279) * fix - Invalid type for args: NDArray-or-Symbol-or-Scalar * fix - unix gpu: use cudaMalloc * fix - Not use CUDA_CALL * fix - not Free(handle) * fix - bug in c_lapack_api.h --- python/mxnet/ndarray/numpy/linalg.py | 74 +- python/mxnet/numpy/linalg.py | 70 +- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/linalg.py | 73 +- src/operator/c_lapack_api.cc | 13 + src/operator/c_lapack_api.h | 64 ++ src/operator/numpy/linalg/np_pinv-inl.h | 738 ++++++++++++++++++ src/operator/numpy/linalg/np_pinv.cc | 195 +++++ src/operator/numpy/linalg/np_pinv.cu | 41 + .../unittest/test_numpy_interoperability.py | 22 + tests/python/unittest/test_numpy_op.py | 79 ++ 11 files changed, 1367 insertions(+), 3 deletions(-) create mode 100644 src/operator/numpy/linalg/np_pinv-inl.h create mode 100644 src/operator/numpy/linalg/np_pinv.cc create mode 100644 src/operator/numpy/linalg/np_pinv.cu diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index e4fee158bea4..51be85851a9b 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -21,7 +21,79 @@ from . import _op as _mx_nd_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv'] + + +def pinv(a, rcond=1e-15, hermitian=False): + r""" + Compute the (Moore-Penrose) pseudo-inverse of a matrix. + + Calculate the generalized inverse of a matrix using its + singular-value decomposition (SVD) and including all + *large* singular values. + + Parameters + ---------- + a : (..., M, N) ndarray + Matrix or stack of matrices to be pseudo-inverted. + rcond : (...) {float or ndarray of float}, optional + Cutoff for small singular values. + Singular values less than or equal to + ``rcond * largest_singular_value`` are set to zero. + Broadcasts against the stack of matrices. + hermitian : bool, optional + If True, `a` is assumed to be Hermitian (symmetric if real-valued), + enabling a more efficient method for finding singular values. + Defaults to False. + + Returns + ------- + B : (..., N, M) ndarray + The pseudo-inverse of `a`. If `a` is a `matrix` instance, then so + is `B`. + + Raises + ------ + MXNetError + If the SVD computation does not converge. + + Notes + ----- + The pseudo-inverse of a matrix A, denoted :math:`A^+`, is + defined as: "the matrix that 'solves' [the least-squares problem] + :math:`Ax = b`," i.e., if :math:`\\bar{x}` is said solution, then + :math:`A^+` is that matrix such that :math:`\\bar{x} = A^+b`. + + It can be shown that if :math:`Q_1 \\Sigma Q_2^T = A` is the singular + value decomposition of A, then + :math:`A^+ = Q_2 \\Sigma^+ Q_1^T`, where :math:`Q_{1,2}` are + orthogonal matrices, :math:`\\Sigma` is a diagonal matrix consisting + of A's so-called singular values, (followed, typically, by + zeros), and then :math:`\\Sigma^+` is simply the diagonal matrix + consisting of the reciprocals of A's singular values + (again, followed by zeros). [1]_ + + References + ---------- + .. [1] G. Strang, *Linear Algebra and Its Applications*, 2nd Ed., Orlando, + FL, Academic Press, Inc., 1980, pp. 139-142. 
+ + Examples + -------- + The following example checks that ``a * a+ * a == a`` and + ``a+ * a * a+ == a+``: + >>> a = np.random.randn(2, 3) + >>> pinv_a = np.linalg.pinv(a) + >>> (a - np.dot(a, np.dot(pinv_a, a))).sum() + array(0.) + >>> (pinv_a - np.dot(pinv_a, np.dot(a, pinv_a))).sum() + array(0.) + """ + if hermitian is True: + raise NotImplementedError("hermitian is not supported yet...") + if _mx_nd_np._np.isscalar(rcond): + return _npi.pinv_scalar_rcond(a, rcond, hermitian) + return _npi.pinv(a, rcond, hermitian) def norm(x, ord=None, axis=None, keepdims=False): diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index 96fe1d311028..2e0da392ca5f 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -20,7 +20,75 @@ from __future__ import absolute_import from ..ndarray import numpy as _mx_nd_np -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv'] + + +def pinv(a, rcond=1e-15, hermitian=False): + r""" + Compute the (Moore-Penrose) pseudo-inverse of a matrix. + + Calculate the generalized inverse of a matrix using its + singular-value decomposition (SVD) and including all + *large* singular values. + + Parameters + ---------- + a : (..., M, N) ndarray + Matrix or stack of matrices to be pseudo-inverted. + rcond : (...) {float or ndarray of float}, optional + Cutoff for small singular values. + Singular values less than or equal to + ``rcond * largest_singular_value`` are set to zero. + Broadcasts against the stack of matrices. + hermitian : bool, optional + If True, `a` is assumed to be Hermitian (symmetric if real-valued), + enabling a more efficient method for finding singular values. + Defaults to False. + + Returns + ------- + B : (..., N, M) ndarray + The pseudo-inverse of `a`. If `a` is a `matrix` instance, then so + is `B`. + + Raises + ------ + MXNetError + If the SVD computation does not converge. + + Notes + ----- + The pseudo-inverse of a matrix A, denoted :math:`A^+`, is + defined as: "the matrix that 'solves' [the least-squares problem] + :math:`Ax = b`," i.e., if :math:`\\bar{x}` is said solution, then + :math:`A^+` is that matrix such that :math:`\\bar{x} = A^+b`. + + It can be shown that if :math:`Q_1 \\Sigma Q_2^T = A` is the singular + value decomposition of A, then + :math:`A^+ = Q_2 \\Sigma^+ Q_1^T`, where :math:`Q_{1,2}` are + orthogonal matrices, :math:`\\Sigma` is a diagonal matrix consisting + of A's so-called singular values, (followed, typically, by + zeros), and then :math:`\\Sigma^+` is simply the diagonal matrix + consisting of the reciprocals of A's singular values + (again, followed by zeros). [1]_ + + References + ---------- + .. [1] G. Strang, *Linear Algebra and Its Applications*, 2nd Ed., Orlando, + FL, Academic Press, Inc., 1980, pp. 139-142. + + Examples + -------- + The following example checks that ``a * a+ * a == a`` and + ``a+ * a * a+ == a+``: + >>> a = np.random.randn(2, 3) + >>> pinv_a = np.linalg.pinv(a) + >>> (a - np.dot(a, np.dot(pinv_a, a))).sum() + array(0.) + >>> (pinv_a - np.dot(pinv_a, np.dot(a, pinv_a))).sum() + array(0.) 
+ """ + return _mx_nd_np.linalg.pinv(a, rcond, hermitian) def norm(x, ord=None, axis=None, keepdims=False): diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 1156003045da..56944facac81 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -148,6 +148,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'linalg.solve', 'linalg.tensorinv', 'linalg.tensorsolve', + 'linalg.pinv', 'shape', 'trace', 'tril', diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 0bfbb6ee540f..979742001aa8 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -22,7 +22,78 @@ from . import _op as _mx_sym_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv'] + +def pinv(a, rcond=1e-15, hermitian=False): + r""" + Compute the (Moore-Penrose) pseudo-inverse of a matrix. + + Calculate the generalized inverse of a matrix using its + singular-value decomposition (SVD) and including all + *large* singular values. + + Parameters + ---------- + a : (..., M, N) ndarray + Matrix or stack of matrices to be pseudo-inverted. + rcond : (...) {float or ndarray of float}, optional + Cutoff for small singular values. + Singular values less than or equal to + ``rcond * largest_singular_value`` are set to zero. + Broadcasts against the stack of matrices. + hermitian : bool, optional + If True, `a` is assumed to be Hermitian (symmetric if real-valued), + enabling a more efficient method for finding singular values. + Defaults to False. + + Returns + ------- + B : (..., N, M) ndarray + The pseudo-inverse of `a`. If `a` is a `matrix` instance, then so + is `B`. + + Raises + ------ + MXNetError + If the SVD computation does not converge. + + Notes + ----- + The pseudo-inverse of a matrix A, denoted :math:`A^+`, is + defined as: "the matrix that 'solves' [the least-squares problem] + :math:`Ax = b`," i.e., if :math:`\\bar{x}` is said solution, then + :math:`A^+` is that matrix such that :math:`\\bar{x} = A^+b`. + + It can be shown that if :math:`Q_1 \\Sigma Q_2^T = A` is the singular + value decomposition of A, then + :math:`A^+ = Q_2 \\Sigma^+ Q_1^T`, where :math:`Q_{1,2}` are + orthogonal matrices, :math:`\\Sigma` is a diagonal matrix consisting + of A's so-called singular values, (followed, typically, by + zeros), and then :math:`\\Sigma^+` is simply the diagonal matrix + consisting of the reciprocals of A's singular values + (again, followed by zeros). [1]_ + + References + ---------- + .. [1] G. Strang, *Linear Algebra and Its Applications*, 2nd Ed., Orlando, + FL, Academic Press, Inc., 1980, pp. 139-142. + + Examples + -------- + The following example checks that ``a * a+ * a == a`` and + ``a+ * a * a+ == a+``: + >>> a = np.random.randn(2, 3) + >>> pinv_a = np.linalg.pinv(a) + >>> (a - np.dot(a, np.dot(pinv_a, a))).sum() + array(0.) + >>> (pinv_a - np.dot(pinv_a, np.dot(a, pinv_a))).sum() + array(0.) 
+ """ + if hermitian is True: + raise NotImplementedError("hermitian is not supported yet...") + if _symbol._np.isscalar(rcond): + return _npi.pinv_scalar_rcond(a, rcond, hermitian) + return _npi.pinv(a, rcond, hermitian) def norm(x, ord=None, axis=None, keepdims=False): diff --git a/src/operator/c_lapack_api.cc b/src/operator/c_lapack_api.cc index 73b6138df5ea..442789e95d13 100644 --- a/src/operator/c_lapack_api.cc +++ b/src/operator/c_lapack_api.cc @@ -78,6 +78,16 @@ return 1; \ } + #define MXNET_LAPACK_CWRAPPER9(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ + dtype *a, int lda, dtype *s, \ + dtype *u, int ldu, \ + dtype *vt, int ldvt, \ + dtype *work, int lwork, int *iwork) { \ + LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ + return 1; \ + } + #define MXNET_LAPACK_UNAVAILABLE(func) \ int mxnet_lapack_##func(...) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ @@ -111,4 +121,7 @@ MXNET_LAPACK_CWRAPPER7(sgesv, float) MXNET_LAPACK_CWRAPPER7(dgesv, double) + MXNET_LAPACK_CWRAPPER9(sgesdd, float) + MXNET_LAPACK_CWRAPPER9(dgesdd, double) + #endif // MSHADOW_USE_MKL == 0 diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h index 8a7cbc067feb..8029a71f61d3 100644 --- a/src/operator/c_lapack_api.h +++ b/src/operator/c_lapack_api.h @@ -163,6 +163,23 @@ extern "C" { MXNET_LAPACK_FSIG_GESV(sgesv, float) MXNET_LAPACK_FSIG_GESV(dgesv, double) + + #ifdef __ANDROID__ + #define MXNET_LAPACK_FSIG_GESDD(func, dtype) \ + int func##_(char *jobz, int *m, int *n, dtype *a, int *lda, dtype *s, \ + dtype *u, int *ldu, \ + dtype *vt, int *ldvt, \ + dtype *work, int *lwork, int *iwork, int *info); + #else + #define MXNET_LAPACK_FSIG_GESDD(func, dtype) \ + void func##_(char *jobz, int *m, int *n, dtype *a, int *lda, dtype *s, \ + dtype *u, int *ldu, \ + dtype *vt, int *ldvt, \ + dtype *work, int *lwork, int *iwork, int *info); + #endif + + MXNET_LAPACK_FSIG_GESDD(sgesdd, float) + MXNET_LAPACK_FSIG_GESDD(dgesdd, double) } #endif // MSHADOW_USE_MKL == 0 @@ -282,6 +299,24 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { MXNET_LAPACK_CWRAP_GESVD(s, float) MXNET_LAPACK_CWRAP_GESVD(d, double) + // Computes the singular value decomposition of a general rectangular matrix + // using a divide and conquer method. 
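The three Python front-ends above all describe the same computation in their Notes: an SVD, a cutoff at ``rcond * largest_singular_value``, and the reciprocal of the retained singular values. A minimal pure-NumPy sketch of that semantics (illustration only, not the MXNet kernel added by this patch), for a single matrix and a scalar ``rcond``:

```python
# Hedged reference sketch in plain NumPy: what the pinv front-ends above compute,
# shown for one 2-D matrix and a scalar rcond.
import numpy as np

def pinv_reference(a, rcond=1e-15):
    """Moore-Penrose pseudo-inverse via SVD, mirroring the docstring Notes."""
    u, s, vt = np.linalg.svd(a, full_matrices=False)  # a = u @ diag(s) @ vt
    cutoff = rcond * s.max()                          # drop small singular values
    s_inv = np.where(s > cutoff, 1.0 / s, 0.0)        # diagonal of Sigma^+
    return vt.T @ (s_inv[:, None] * u.T)              # A^+ = V Sigma^+ U^T

a = np.random.randn(2, 3)
assert np.allclose(a @ pinv_reference(a) @ a, a)      # a * a+ * a == a
```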
+ #define MXNET_LAPACK_CWRAP_GESDD(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gesdd(int matrix_layout, int m, int n, \ + dtype *a, int lda, dtype *s, \ + dtype *u, int ldu, \ + dtype *vt, int ldvt, \ + dtype *work, int lwork, int *iwork) { \ + if (lwork != -1) { \ + return LAPACKE_##prefix##gesdd(matrix_layout, 'O', m, n, a, lda, \ + s, u, ldu, vt, ldvt); \ + } \ + *work = 0; \ + return 0; \ + } + MXNET_LAPACK_CWRAP_GESDD(s, float) + MXNET_LAPACK_CWRAP_GESDD(d, double) + #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \ inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \ int *ipiv, dtype *work, int lwork) { \ @@ -421,6 +456,25 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { MXNET_LAPACK_CWRAP_GESVD(sgesvd, float) MXNET_LAPACK_CWRAP_GESVD(dgesvd, double) + #define MXNET_LAPACK_CWRAP_GESDD(func, dtype) \ + inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ + dtype *a, int lda, dtype *s, \ + dtype *u, int ldu, \ + dtype *vt, int ldvt, \ + dtype *work, int lwork, int *iwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \ + return 1; \ + } else { \ + int info(0); \ + char jobz('O'); \ + func##_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info); \ + return info; \ + } \ + } + MXNET_LAPACK_CWRAP_GESDD(sgesdd, float) + MXNET_LAPACK_CWRAP_GESDD(dgesdd, double) + #define MXNET_LAPACK // Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can only be called with col-major format @@ -506,6 +560,13 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \ int lda, int *ipiv, dtype *b, int ldb); \ + #define MXNET_LAPACK_CWRAPPER9(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ + dtype *a, int lda, dtype *s, \ + dtype *u, int ldu, \ + dtype *vt, int ldvt, \ + dtype *work, int lwork, int *iwork); + #define MXNET_LAPACK_UNAVAILABLE(func) \ int mxnet_lapack_##func(...); MXNET_LAPACK_CWRAPPER1(spotrf, float) @@ -536,6 +597,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { MXNET_LAPACK_CWRAPPER7(sgesv, float) MXNET_LAPACK_CWRAPPER7(dgesv, double) + MXNET_LAPACK_CWRAPPER9(sgesdd, float) + MXNET_LAPACK_CWRAPPER9(dgesdd, double) + #undef MXNET_LAPACK_CWRAPPER1 #undef MXNET_LAPACK_CWRAPPER2 #undef MXNET_LAPACK_CWRAPPER3 diff --git a/src/operator/numpy/linalg/np_pinv-inl.h b/src/operator/numpy/linalg/np_pinv-inl.h new file mode 100644 index 000000000000..76bcc9a1ab64 --- /dev/null +++ b/src/operator/numpy/linalg/np_pinv-inl.h @@ -0,0 +1,738 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file np_pinv-inl.h + * \brief Placeholder for pinv + */ +#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_ +#define MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_ + +#include +#include +#include +#include "../../operator_common.h" +#include "../../mshadow_op.h" +#include "../../tensor/elemwise_binary_op.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "../../tensor/la_op.h" +#include "../../tensor/la_op-inl.h" +#include "../../tensor/matrix_op-inl.h" + +namespace mxnet { +namespace op { + +using namespace mshadow; + +struct PinvParam : public dmlc::Parameter { + bool hermitian; + DMLC_DECLARE_PARAMETER(PinvParam) { + DMLC_DECLARE_FIELD(hermitian) + .set_default(false) + .describe("If True, A is assumed to be Hermitian (symmetric if real-valued)."); + } +}; + +struct PinvScalarRcondParam : public dmlc::Parameter { + double rcond; + bool hermitian; + DMLC_DECLARE_PARAMETER(PinvScalarRcondParam) { + DMLC_DECLARE_FIELD(rcond) + .set_default(1e-15) + .describe("Cutoff for small singular values."); + DMLC_DECLARE_FIELD(hermitian) + .set_default(false) + .describe("If True, A is assumed to be Hermitian (symmetric if real-valued)."); + } +}; + +template +int linalg_gesdd_workspace_query(const int m, const int n, + const Tensor& UT, + const Tensor& S, + const Tensor& V, + Stream* s = 0); + +template +void linalg_gesdd(const int m, const int n, + const Tensor& UT, + const Tensor& S, + const Tensor& V, + const Tensor& work, + Stream *s = 0); + +template +void BatchSVDImpl(const int m, const int n, + const Tensor& UT, + const Tensor& S, + const Tensor& V, + const Tensor& work, + Stream *s = 0); + +#define LINALG_CPU_GESDD_WORKSPACE_QUERY(func, DType) \ +template<> inline \ +int linalg_gesdd_workspace_query(const int m, const int n, \ + const Tensor& UT, \ + const Tensor& S, \ + const Tensor& V, \ + Stream *s) { \ + DType work(0.0); \ + std::vector iwork(8 * std::min(m, n), 0); \ + if (m > n) { \ + MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \ + UT.dptr_, UT.stride_, S.dptr_, \ + V.dptr_, V.stride_, \ + UT.dptr_, UT.stride_, \ + &work, -1, iwork.data()); \ + } else { \ + MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \ + V.dptr_, V.stride_, S.dptr_, \ + V.dptr_, V.stride_, \ + UT.dptr_, UT.stride_, \ + &work, -1, iwork.data()); \ + } \ + return static_cast(work); \ +} + +#define LINALG_CPU_GESDD(func, DType) \ +template<> inline \ +void linalg_gesdd(const int m, \ + const int n, \ + const Tensor& UT, \ + const Tensor& S, \ + const Tensor& V, \ + const Tensor& work, \ + Stream *s) { \ + std::vector iwork(8 * std::min(m, n), 0); \ + int res(0); \ + if (m > n) { \ + res = MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \ + UT.dptr_, UT.stride_, S.dptr_, \ + V.dptr_, V.stride_, \ + UT.dptr_, UT.stride_, \ + work.dptr_, work.shape_.Size(), iwork.data()); \ + } else { \ + res = MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \ + V.dptr_, V.stride_, S.dptr_, \ + V.dptr_, V.stride_, \ + UT.dptr_, UT.stride_, \ + work.dptr_, work.shape_.Size(), iwork.data()); \ + } \ + CHECK_GE(res, 0) << #func << ": the " << -res \ + << "-th argument had an illegal value"; \ + CHECK_LE(res, 0) << #func << " did not converge, updating process failed."; \ +} + +LINALG_CPU_GESDD_WORKSPACE_QUERY(sgesdd, float) +LINALG_CPU_GESDD_WORKSPACE_QUERY(dgesdd, double) + +LINALG_CPU_GESDD(sgesdd, float) +LINALG_CPU_GESDD(dgesdd, double) + +#ifdef __CUDACC__ + +#define LINALG_GPU_GESDD_WORKSPACE_QUERY(DType) \ +template<> inline \ +int 
linalg_gesdd_workspace_query(const int m, const int n, \ + const Tensor& U, \ + const Tensor& S, \ + const Tensor& VT, \ + Stream *s) { \ + LOG(FATAL) << "Lapack gesdd workspace query routines is unsupported in gpu!"; \ + return 0; \ +} + +#define LINALG_GPU_GESDD(DType) \ +template<> inline \ +void linalg_gesdd(const int m, const int n, \ + const Tensor& U, \ + const Tensor& S, \ + const Tensor& VT, \ + const Tensor& work, \ + Stream *s) { \ + LOG(FATAL) << "Lapack gesdd routines is unsupported in gpu!"; \ +} + +LINALG_GPU_GESDD_WORKSPACE_QUERY(float) +LINALG_GPU_GESDD_WORKSPACE_QUERY(double) + +LINALG_GPU_GESDD(float) +LINALG_GPU_GESDD(double) + +#endif // __CUDACC__ + +#define BATCH_SVD_IMPL_CPU(DType) \ +template<> inline \ +void BatchSVDImpl(const int m, const int n, \ + const Tensor& UT, \ + const Tensor& S, \ + const Tensor& V, \ + const Tensor& work, \ + Stream *s) { \ + for (index_t i = 0; i < S.size(0); ++i) { \ + linalg_gesdd(m, n, UT[i], S[i], V[i], work, s); \ + } \ +} + +BATCH_SVD_IMPL_CPU(float) +BATCH_SVD_IMPL_CPU(double) + +#ifdef __CUDACC__ + +#define BATCH_SVD_IMPL_GPU(DType) \ +template<> inline \ +void BatchSVDImpl(const int m, const int n, \ + const Tensor& UT, \ + const Tensor& S, \ + const Tensor& V, \ + const Tensor& work, \ + Stream *s) { \ + for (index_t i = 0; i < S.size(0); ++i) { \ + linalg_gesvd(UT[i], S[i], V[i], work, s); \ + } \ +} + +BATCH_SVD_IMPL_GPU(float) +BATCH_SVD_IMPL_GPU(double) + +#endif // __CUDACC__ + +struct SingularValSmax { + template + MSHADOW_XINLINE static void Map(int i, DType *smax_ptr, const DType *s_ptr, + const int length, const int lds) { + const DType *s_iptr = s_ptr + i * lds; + DType *smax_iptr = smax_ptr + i; + *smax_iptr = s_iptr[0]; + for (int j = 1; j < length; ++j) { + *smax_iptr = s_iptr[j] > *smax_iptr ? s_iptr[j] : *smax_iptr; + } + } +}; + +struct DiscardSmallSingularVal { + template + MSHADOW_XINLINE static void Map(int i, DType *s_ptr, const DType *large_ptr) { + if (large_ptr[i]) { + s_ptr[i] = DType(1) / s_ptr[i]; + } else { + s_ptr[i] = DType(0); + } + } +}; + +struct DiscardSmallSingularValWithScalarRcond { + template + MSHADOW_XINLINE static void Map(int i, DType *s_ptr, const int length, + const int lds, const double rcond) { + DType *s_iptr = s_ptr + i * lds; + DType smax_i = s_iptr[0]; + for (int j = 1; j < length; ++j) { + smax_i = s_iptr[j] > smax_i ? s_iptr[j] : smax_i; + } + for (int j = 0; j < length; ++j) { + s_iptr[j] = (s_iptr[j] > rcond * smax_i) ? (DType(1) / s_iptr[j]) : (DType(0)); + } + } +}; + +inline void GetPinvShape(const mxnet::TShape& a_shape, + mxnet::TShape *ut_shape, + mxnet::TShape *s_shape, + mxnet::TShape *v_shape, + mxnet::TShape *u_shape = nullptr, + mxnet::TShape *vt_shape = nullptr) { + const int a_ndim = a_shape.ndim(); + const int m = a_shape[a_ndim - 2]; + const int n = a_shape[a_ndim - 1]; + + // Calculate S shape. + std::vector s_shape_vec(a_ndim - 1, -1); + for (int i = 0; i < a_ndim - 2; ++i) { + s_shape_vec[i] = a_shape[i]; + } + s_shape_vec[a_ndim - 2] = std::min(m, n); + *s_shape = mxnet::TShape(s_shape_vec.begin(), s_shape_vec.end()); + + std::vector temp_shape_vec(a_ndim, -1); + for (int i = 0; i < a_ndim - 2; ++i) { + temp_shape_vec[i] = a_shape[i]; + } + temp_shape_vec[a_ndim - 2] = std::min(m, n); + temp_shape_vec[a_ndim - 1] = std::min(m, n); + if (m >= n) { + // UT must have same shape as A. 
+ *ut_shape = a_shape; + *v_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end()); + if (u_shape && vt_shape) { + *vt_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end()); + *u_shape = a_shape; + (*u_shape)[a_ndim - 2] = a_shape[a_ndim - 1]; + (*u_shape)[a_ndim - 1] = a_shape[a_ndim - 2]; + } + } else { + // V must have same shape as A. + *v_shape = a_shape; + *ut_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end()); + if (u_shape && vt_shape) { + *u_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end()); + *vt_shape = a_shape; + (*vt_shape)[a_ndim - 2] = a_shape[a_ndim - 1]; + (*vt_shape)[a_ndim - 1] = a_shape[a_ndim - 2]; + } + } +} + +inline void GetOrCheckCutoffAndLargeShape(const nnvm::NodeAttrs& attrs, + const mxnet::TShape& a_shape, + const mxnet::TShape& rcond_shape, + mxnet::TShape *cutoff_shape = nullptr, + mxnet::TShape *large_shape = nullptr) { + if (!shape_is_known(a_shape)) { return ; } + const int a_ndim = a_shape.ndim(); + const int rcond_ndim = rcond_shape.ndim(); + mxnet::TShape s_shape(a_ndim - 1, 1); + mxnet::TShape smax_shape(a_ndim - 1, 1); + mxnet::TShape new_rcond_shape(rcond_ndim + 1, 1); + // Get new rcond shape. + for (int i = 0; i < rcond_ndim; ++i) { + new_rcond_shape[i] = rcond_shape[i]; + } + // Get Smax shape. + for (int i = 0; i < a_ndim - 2; ++i) { + s_shape[i] = a_shape[i]; + smax_shape[i] = a_shape[i]; + } + s_shape[s_shape.ndim() - 1] = std::min(a_shape[a_ndim - 2], a_shape[a_ndim - 1]); + smax_shape[smax_shape.ndim() - 1] = 1; + // Check cutoff = rcond[..., newaxis] * smax. + mxnet::ShapeVector in_shape_vec1({ new_rcond_shape, smax_shape }); + mxnet::ShapeVector out_shape_vec1(1); + mxnet::op::BinaryBroadcastShape(attrs, &in_shape_vec1, &out_shape_vec1); + // Check large = s > cutoff. + mxnet::ShapeVector in_shape_vec2({ s_shape, out_shape_vec1[0] }); + mxnet::ShapeVector out_shape_vec2(1); + mxnet::op::BinaryBroadcastShape(attrs, &in_shape_vec2, &out_shape_vec2); + // Check s = divide(1, s, where=large, out=s). + if (s_shape != out_shape_vec2[0]) { + LOG(FATAL) << "Error: non-broadcastable output operand with shape " + << s_shape << " doesn't match the broadcast shape " << out_shape_vec2[0]; + } + if (cutoff_shape) { + *cutoff_shape = out_shape_vec1[0]; + } + if (large_shape) { + *large_shape = out_shape_vec2[0]; + } +} + +template +size_t SVDWorkspaceSize(const TBlob& a, + const TBlob& pinv_a, + const mxnet::TShape& u_shape, + const mxnet::TShape& s_shape, + const mxnet::TShape& v_shape, + const std::vector& req, + const OpContext& ctx) { + if (kNullOp == req[0]) { return 0U; } + + // Zero-size input, no need to launch kernel + if (0U == a.Size()) { return 0U; } + + size_t work_space_size = 0; + Stream *s = ctx.get_stream(); + MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, { + const int a_ndim = a.shape_.ndim(); + const int u_ndim = u_shape.ndim(); + const int s_ndim = s_shape.ndim(); + const int v_ndim = v_shape.ndim(); + mxnet::TShape u_shape2 = Shape2(u_shape[u_ndim - 2], u_shape[u_ndim - 1]); + mxnet::TShape s_shape1 = Shape1(s_shape[s_ndim - 1]); + mxnet::TShape v_shape2 = Shape2(v_shape[v_ndim - 2], v_shape[v_ndim - 1]); + if (xpu::kDevCPU) { + std::vector u_vec(u_shape2.Size(), 0); + std::vector s_vec(s_shape1.Size(), 0); + std::vector v_vec(v_shape2.Size(), 0); + // For workspace size in linalg_gesdd. 
+ work_space_size += linalg_gesdd_workspace_query( + a.shape_[a_ndim - 2], a.shape_[a_ndim - 1], + TBlob(u_vec.data(), u_shape2, a.dev_mask(), a.dev_id()).get(s), + TBlob(s_vec.data(), s_shape1, a.dev_mask(), a.dev_id()).get(s), + TBlob(v_vec.data(), v_shape2, a.dev_mask(), a.dev_id()).get(s), s); + } else { + Storage::Handle u_handle = + Storage::Get()->Alloc(sizeof(OType) * u_shape2.Size(), Context::GPU()); + Storage::Handle s_handle = + Storage::Get()->Alloc(sizeof(OType) * s_shape1.Size(), Context::GPU()); + Storage::Handle v_handle = + Storage::Get()->Alloc(sizeof(OType) * v_shape2.Size(), Context::GPU()); + TBlob u_data(static_cast(u_handle.dptr), u_shape2, a.dev_mask(), a.dev_id()); + TBlob s_data(static_cast(s_handle.dptr), s_shape1, a.dev_mask(), a.dev_id()); + TBlob v_data(static_cast(v_handle.dptr), v_shape2, a.dev_mask(), a.dev_id()); + // For workspace size in linalg_gesvd. + if (a.shape_[a_ndim - 2] >= a.shape_[a_ndim - 1]) { + work_space_size += linalg_gesvd_workspace_query(v_data.get(s), + s_data.get(s), + u_data.get(s), s); + } else { + work_space_size += linalg_gesvd_workspace_query(u_data.get(s), + s_data.get(s), + v_data.get(s), s); + } + Storage::Get()->Free(u_handle); + Storage::Get()->Free(s_handle); + Storage::Get()->Free(v_handle); + } + }); + return work_space_size; +} + +// Calculates workspace size of pinv op forward. +template +size_t PinvForwardWorkspaceSize(const TBlob& a, + const TBlob& rcond, + const TBlob& pinv_a, + const nnvm::NodeAttrs& attrs, + const std::vector& req, + const OpContext& ctx) { + if (kNullOp == req[0]) { return 0U; } + // Zero-size input, no need to launch kernel + if (0U == a.Size()) { return 0U; } + + size_t work_space_size = 0; + mxnet::TShape u_shape, s_shape, v_shape; + GetPinvShape(a.shape_, &u_shape, &s_shape, &v_shape); + + MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, { + mxnet::TShape smax_shape = s_shape; + smax_shape[s_shape.ndim() - 1] = 1; + mxnet::TShape cutoff_shape; + mxnet::TShape large_shape; + GetOrCheckCutoffAndLargeShape(attrs, a.shape_, rcond.shape_, &cutoff_shape, &large_shape); + work_space_size += // For #gesdd_ or #gesvd work space size. + SVDWorkspaceSize(a, pinv_a, u_shape, s_shape, v_shape, req, ctx); + work_space_size += rcond.shape_.Size(); // For rcond. + work_space_size += 2 * u_shape.Size(); // For UT. + work_space_size += s_shape.Size(); // For S. + work_space_size += 2 * v_shape.Size(); // For V. + work_space_size += smax_shape.Size(); // For Smax. + work_space_size += cutoff_shape.Size(); // For Cutoff. + work_space_size += large_shape.Size(); // For Large. 
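The workspace accounting above depends on the LAPACK convention that a ``*gesdd`` call with ``lwork = -1`` only reports the optimal workspace size, which is what ``linalg_gesdd_workspace_query`` relies on. As a hedged illustration of that convention (SciPy's low-level bindings are used here only as a stand-in and are not part of this patch):

```python
# Illustrates the LAPACK lwork-query convention used by the gesdd wrappers above;
# SciPy's bindings are assumed only for demonstration, they are not part of MXNet.
import numpy as np
from scipy.linalg import lapack

a = np.asfortranarray(np.random.randn(6, 5))
lwork, info = lapack.dgesdd_lwork(a.shape[0], a.shape[1])  # size query (lwork = -1 internally)
assert info == 0
u, s, vt, info = lapack.dgesdd(a, lwork=int(lwork))        # actual divide-and-conquer SVD
assert info == 0
assert np.allclose(u[:, :5] @ np.diag(s) @ vt, a)          # reconstructs the input
```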
+ return work_space_size * sizeof(OType); + }); + LOG(FATAL) << "InternalError: cannot reach here"; + return 0U; +} + +inline mxnet::TShape GetTransAxis(const mxnet::TShape& in_shape) { + const int in_ndim = in_shape.ndim(); + std::vector trans_axis(in_ndim, -1); + for (int i = 0; i < in_ndim - 2; ++i) { trans_axis[i] = i; } + trans_axis[in_ndim - 2] = in_ndim - 1; + trans_axis[in_ndim - 1] = in_ndim - 2; + return mxnet::TShape(trans_axis.begin(), trans_axis.end()); +} + +template +void PinvOpForwardImpl(const TBlob& a, + const TBlob& rcond, + const TBlob& pinv_a, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& req, + const Tensor& workspace) { + Stream *s = ctx.get_stream(); + const mxnet::TShape a_shape = a.shape_; + const mxnet::TShape rcond_shape = rcond.shape_; + const int a_ndim = a_shape.ndim(); + const int rcond_ndim = rcond_shape.ndim(); + mxnet::TShape rcond_shape_newaxis(rcond_ndim + 1, 1); + for (int i = 0; i < rcond_ndim; ++i) { + rcond_shape_newaxis[i] = rcond_shape[i]; + } + mxnet::TShape s_shape; + mxnet::TShape u_shape; + mxnet::TShape ut_shape; + mxnet::TShape v_shape; + mxnet::TShape vt_shape; + GetPinvShape(a_shape, &u_shape, &s_shape, &v_shape, &ut_shape, &vt_shape); + mxnet::TShape smax_shape = s_shape; + smax_shape[s_shape.ndim() - 1] = 1; + mxnet::TShape s_shape_newaxis(s_shape.ndim() + 1, 1); + for (int i = 0; i < s_shape.ndim(); ++i) { + s_shape_newaxis[i] = s_shape[i]; + } + mxnet::TShape cutoff_shape; + mxnet::TShape large_shape; + GetOrCheckCutoffAndLargeShape(attrs, a_shape, rcond_shape, &cutoff_shape, &large_shape); + + MSHADOW_SGL_DBL_TYPE_SWITCH(a.type_flag_, AType, { + MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, DType, { + const size_t workspace_size = (workspace.shape_.Size() + sizeof(DType) - 1) / sizeof(DType); + const size_t lwork = workspace_size - rcond_shape_newaxis.Size() + - 2 * u_shape.Size() - s_shape.Size() - 2 * v_shape.Size() - smax_shape.Size() + - cutoff_shape.Size() - large_shape.Size(); + DType *work_ptr = reinterpret_cast(workspace.dptr_); + DType *rcond_ptr = work_ptr + lwork; + DType *ut_ptr = rcond_ptr + rcond_shape_newaxis.Size(); + DType *u_ptr = ut_ptr + ut_shape.Size(); + DType *s_ptr = u_ptr + u_shape.Size(); + DType *v_ptr = s_ptr + s_shape.Size(); + DType *vt_ptr = v_ptr + v_shape.Size(); + DType *smax_ptr = vt_ptr + vt_shape.Size(); + DType *cutoff_ptr = smax_ptr + smax_shape.Size(); + DType *large_ptr = cutoff_ptr + cutoff_shape.Size(); + // Step1: Calculate SVD. + TBlob work_data(work_ptr, Shape1(lwork), a.dev_mask(), a.dev_id()); + TBlob u_data(u_ptr, u_shape, a.dev_mask(), a.dev_id()); + TBlob ut_data(ut_ptr, ut_shape, a.dev_mask(), a.dev_id()); + TBlob v_data(v_ptr, v_shape, a.dev_mask(), a.dev_id()); + TBlob vt_data(vt_ptr, vt_shape, a.dev_mask(), a.dev_id()); + TBlob s_data(s_ptr, s_shape, a.dev_mask(), a.dev_id()); + // Noet: Only a_shape[a_ndim - 2] > a_shape[a_ndim - 1], need transpose operation. 
+ if (a_shape[a_ndim - 2] > a_shape[a_ndim - 1]) { + mxnet_op::Kernel::Launch( + s, a.Size(), u_ptr, a.dptr()); + mxnet::op::TransposeImpl(ctx.run_ctx, u_data, ut_data, // u_data: src, ut_data: res + GetTransAxis(u_data.shape_)); + BatchSVDImpl(a_shape[a_ndim - 1], a_shape[a_ndim - 2], + vt_data.FlatToKD(s), + s_data.FlatToKD(s), + ut_data.FlatToKD(s), + work_data.FlatToKD(s), s); + } else { + mxnet_op::Kernel::Launch( + s, a.Size(), v_ptr, a.dptr()); + BatchSVDImpl(a_shape[a_ndim - 2], a_shape[a_ndim - 1], + u_data.FlatToKD(s), + s_data.FlatToKD(s), + v_data.FlatToKD(s), + work_data.FlatToKD(s), s); + } + TBlob smax_data(smax_ptr, smax_shape, a.dev_mask(), a.dev_id()); + TBlob cutoff_data(cutoff_ptr, cutoff_shape, a.dev_mask(), a.dev_id()); + TBlob large_data(large_ptr, large_shape, a.dev_mask(), a.dev_id()); + TBlob rcond_data(rcond_ptr, rcond_shape_newaxis, a.dev_mask(), a.dev_id()); + Tensor S = s_data.FlatToKD(s); + Tensor Smax = smax_data.FlatToKD(s); + mxnet_op::Kernel::Launch( + s, rcond_shape_newaxis.Size(), rcond_ptr, rcond.dptr()); + // Step2: Calculate Smax. + mxnet_op::Kernel::Launch( + s, S.size(0), Smax.dptr_, S.dptr_, S.size(1), S.stride_); + // Step3: Calculate Cutoff. + std::vector temp_req({kWriteTo}); + mxnet::op::BinaryBroadcastCompute(attrs, ctx, + {rcond_data, smax_data}, + temp_req, {cutoff_data}); + // Step4: Calculte Large. + mxnet::op::BinaryBroadcastCompute(attrs, ctx, + {s_data, cutoff_data}, + temp_req, {large_data}); + // Step5: Discard small singular values. + mxnet_op::Kernel::Launch( + s, s_data.Size(), s_data.dptr(), large_data.dptr()); + // Step6: Calculte matmul(transpose(v), multiply(s[..., newaxis], transpose(u))). + // Note: No need transpose when a_shape[a_ndim - 2] >= a_shape[a_ndim - 1] + if (a_shape[a_ndim - 2] <= a_shape[a_ndim - 1]) { + mxnet::op::TransposeImpl(ctx.run_ctx, u_data, ut_data, // u_data: src, ut_data: res + GetTransAxis(u_data.shape_)); + mxnet::op::TransposeImpl(ctx.run_ctx, v_data, vt_data, // v_data: src, vt_data: res + GetTransAxis(v_data.shape_)); + } + s_data = s_data.reshape(s_shape_newaxis); + u_data = ut_data.reshape(ut_shape); + mxnet::op::BinaryBroadcastCompute(attrs, ctx, {s_data, ut_data}, + temp_req, {u_data}); + gemm2::op(vt_data.FlatToKD(s), + u_data.FlatToKD(s), + pinv_a.FlatToKD(s), + DType(1), false, false, s); + }); + }); +} + +template +void PinvOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + Stream *s = ctx.get_stream(); + const TBlob& a = inputs[0]; + const TBlob& rcond = inputs[1]; + const TBlob& pinv_a = outputs[0]; + const mxnet::TShape a_shape = a.shape_; + + if (kNullOp == req[0]) { return; } + + // Zero-size output, no need to launch kernel + if (0U == a.Size()) { return; } + + size_t workspace_size = PinvForwardWorkspaceSize(a, rcond, pinv_a, attrs, req, ctx); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + PinvOpForwardImpl(a, rcond, pinv_a, attrs, ctx, req, workspace); +} + +// Calculates workspace size of pinv scalar rcond op forward. 
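The ``Step1``-``Step6`` comments in ``PinvOpForwardImpl`` above follow the same pipeline NumPy uses when ``rcond`` is an array that broadcasts against the stack of matrices. A hedged NumPy sketch of those steps (the math only, not the kernel code):

```python
# Hedged NumPy sketch of the Step1-Step6 pipeline in PinvOpForwardImpl above,
# for a stack of matrices and a broadcastable rcond (illustration only).
import numpy as np

def batched_pinv(a, rcond):
    u, s, vt = np.linalg.svd(a, full_matrices=False)     # Step 1: SVD
    smax = s.max(axis=-1, keepdims=True)                  # Step 2: largest singular value
    cutoff = np.asarray(rcond)[..., np.newaxis] * smax    # Step 3: cutoff = rcond * smax
    large = s > cutoff                                     # Step 4: values to keep
    s = np.where(large, 1.0 / s, 0.0)                      # Step 5: reciprocal or zero
    # Step 6: matmul(transpose(v), s[..., newaxis] * transpose(u))
    return np.matmul(np.swapaxes(vt, -1, -2), s[..., np.newaxis] * np.swapaxes(u, -1, -2))

a = np.random.randn(2, 4, 3)
rcond = np.full((2,), 1e-15)
assert np.allclose(batched_pinv(a, rcond), np.linalg.pinv(a, rcond))
```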
+template +size_t PinvScalarRcondForwardWorkspaceSize(const TBlob& a, + const TBlob& pinv_a, + const nnvm::NodeAttrs& attrs, + const std::vector& req, + const OpContext& ctx) { + if (kNullOp == req[0]) { return 0U; } + // Zero-size input, no need to launch kernel + if (0U == a.Size()) { return 0U; } + + size_t work_space_size = 0; + mxnet::TShape u_shape, s_shape, v_shape; + GetPinvShape(a.shape_, &u_shape, &s_shape, &v_shape); + + MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, { + mxnet::TShape smax_shape = s_shape; + smax_shape[s_shape.ndim() - 1] = 1; + work_space_size += // For #gesdd_ or #gesvd work space size. + SVDWorkspaceSize(a, pinv_a, u_shape, s_shape, v_shape, req, ctx); + work_space_size += 2 * u_shape.Size(); // For UT. + work_space_size += s_shape.Size(); // For S. + work_space_size += 2 * v_shape.Size(); // For V. + return work_space_size * sizeof(OType); + }); + LOG(FATAL) << "InternalError: cannot reach here"; + return 0U; +} + +template +void PinvScalarRcondOpForwardImpl(const TBlob& a, + const TBlob& pinv_a, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& req, + const Tensor& workspace) { + Stream *s = ctx.get_stream(); + const mxnet::TShape a_shape = a.shape_; + const int a_ndim = a_shape.ndim(); + + mxnet::TShape s_shape; + mxnet::TShape u_shape; + mxnet::TShape ut_shape; + mxnet::TShape v_shape; + mxnet::TShape vt_shape; + GetPinvShape(a_shape, &u_shape, &s_shape, &v_shape, &ut_shape, &vt_shape); + mxnet::TShape s_shape_newaxis(s_shape.ndim() + 1, 1); + for (int i = 0; i < s_shape.ndim(); ++i) { + s_shape_newaxis[i] = s_shape[i]; + } + MSHADOW_SGL_DBL_TYPE_SWITCH(a.type_flag_, AType, { + MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, DType, { + const double rcond = nnvm::get(attrs.parsed).rcond; + const size_t workspace_size = (workspace.shape_.Size() + sizeof(DType) - 1) / sizeof(DType); + const size_t lwork = workspace_size - 2 * u_shape.Size() - s_shape.Size() + - 2 * v_shape.Size(); + DType *work_ptr = reinterpret_cast(workspace.dptr_); + DType *u_ptr = work_ptr + lwork; + DType *ut_ptr = u_ptr + u_shape.Size(); + DType *s_ptr = ut_ptr + ut_shape.Size(); + DType *v_ptr = s_ptr + s_shape.Size(); + DType *vt_ptr = v_ptr + v_shape.Size(); + // Step1: Calculate SVD. + TBlob work_data(work_ptr, Shape1(lwork), a.dev_mask(), a.dev_id()); + TBlob u_data(u_ptr, u_shape, a.dev_mask(), a.dev_id()); + TBlob ut_data(ut_ptr, ut_shape, a.dev_mask(), a.dev_id()); + TBlob v_data(v_ptr, v_shape, a.dev_mask(), a.dev_id()); + TBlob vt_data(vt_ptr, vt_shape, a.dev_mask(), a.dev_id()); + TBlob s_data(s_ptr, s_shape, a.dev_mask(), a.dev_id()); + Tensor S = s_data.FlatToKD(s); + // Noet: Only a_shape[a_ndim - 2] > a_shape[a_ndim - 1], need transpose operation. + if (a_shape[a_ndim - 2] > a_shape[a_ndim - 1]) { + mxnet_op::Kernel::Launch( + s, a.Size(), u_ptr, a.dptr()); + mxnet::op::TransposeImpl(ctx.run_ctx, u_data, ut_data, // u_data: src, ut_data: res + GetTransAxis(u_data.shape_)); + BatchSVDImpl(a_shape[a_ndim - 1], a_shape[a_ndim - 2], + vt_data.FlatToKD(s), + s_data.FlatToKD(s), + ut_data.FlatToKD(s), + work_data.FlatToKD(s), s); + } else { + mxnet_op::Kernel::Launch( + s, a.Size(), v_ptr, a.dptr()); + BatchSVDImpl(a_shape[a_ndim - 2], a_shape[a_ndim - 1], + u_data.FlatToKD(s), + s_data.FlatToKD(s), + v_data.FlatToKD(s), + work_data.FlatToKD(s), s); + } + // Step2: Discard small singular values. 
+ mxnet_op::Kernel::Launch( + s, S.size(0), S.dptr_, S.size(1), S.stride_, rcond); + // Step3: Calculte matmul(transpose(v), multiply(s[..., newaxis], transpose(u))). + // Note: No need transpose when a_shape[a_ndim - 2] >= a_shape[a_ndim - 1] + if (a_shape[a_ndim - 2] <= a_shape[a_ndim - 1]) { + mxnet::op::TransposeImpl(ctx.run_ctx, u_data, ut_data, // u_data: src, ut_data: res + GetTransAxis(u_data.shape_)); + mxnet::op::TransposeImpl(ctx.run_ctx, v_data, vt_data, // v_data: src, vt_data: res + GetTransAxis(v_data.shape_)); + } + s_data = s_data.reshape(s_shape_newaxis); + u_data = ut_data.reshape(ut_shape); + mxnet::op::BinaryBroadcastCompute(attrs, ctx, {s_data, ut_data}, + {kWriteTo}, {u_data}); + gemm2::op(vt_data.FlatToKD(s), + u_data.FlatToKD(s), + pinv_a.FlatToKD(s), + DType(1), false, false, s); + }); + }); +} + +template +void PinvScalarRcondOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + Stream *s = ctx.get_stream(); + const TBlob& a = inputs[0]; + const TBlob& pinv_a = outputs[0]; + + if (kNullOp == req[0]) { return; } + // Zero-size output, no need to launch kernel + if (0U == a.Size()) { return; } + + // Calculate workspace size. + size_t workspace_size = PinvScalarRcondForwardWorkspaceSize(a, pinv_a, attrs, req, ctx); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + PinvScalarRcondOpForwardImpl(a, pinv_a, attrs, ctx, req, workspace); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_ diff --git a/src/operator/numpy/linalg/np_pinv.cc b/src/operator/numpy/linalg/np_pinv.cc new file mode 100644 index 000000000000..8f1968bb046a --- /dev/null +++ b/src/operator/numpy/linalg/np_pinv.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_pinv.cc + * \brief CPU implementation of the PINV Operator + */ + +#include "./np_pinv-inl.h" + +namespace mxnet { +namespace op { + +bool PinvOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const mxnet::TShape& a_shape = (*in_attrs)[0]; + const mxnet::TShape& rcond_shape = (*in_attrs)[1]; + const mxnet::TShape& pinv_shape = (*out_attrs)[0]; + const int a_ndim = a_shape.ndim(); + + if (shape_is_known(a_shape)) { + // Forward shape inference. + CHECK_GE(a_ndim, 2) + << "Array must be at least two-dimensional"; + // Calculte pinv shape. 
+ std::vector pinv_shape_vec(a_ndim, -1); + for (int i = 0; i < a_ndim - 2; ++i) { + pinv_shape_vec[i] = a_shape[i]; + } + pinv_shape_vec[a_ndim - 2] = a_shape[a_ndim - 1]; + pinv_shape_vec[a_ndim - 1] = a_shape[a_ndim - 2]; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(pinv_shape_vec.begin(), pinv_shape_vec.end())); + // Check rcond shape. + GetOrCheckCutoffAndLargeShape(attrs, a_shape, rcond_shape, nullptr, nullptr); + } else { + // Backward shape inference. + if (shape_is_known(pinv_shape)) { + const int pinv_ndim = pinv_shape.ndim(); + CHECK_GE(pinv_ndim, 2) + << "Array must be at least two-dimensional"; + // Calculte 'a' shape. + std::vector a_shape_vec(pinv_ndim, -1); + for (int i = 0; i < pinv_ndim - 2; ++i) { + a_shape_vec[i] = pinv_shape[i]; + } + a_shape_vec[pinv_ndim - 2] = pinv_shape[pinv_ndim - 1]; + a_shape_vec[pinv_ndim - 1] = pinv_shape[pinv_ndim - 2]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, mxnet::TShape(a_shape_vec.begin(), a_shape_vec.end())); + // Check rcond shape. + GetOrCheckCutoffAndLargeShape(attrs, (*in_attrs)[0], rcond_shape, nullptr, nullptr); + } + } + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +inline bool PinvOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + int a_type = in_attrs->at(0); + int rcond_type = in_attrs->at(1); + // unsupport float16 + CHECK_NE(a_type, mshadow::kFloat16) + << "array type float16 is unsupported in linalg."; + CHECK(rcond_type == mshadow::kFloat32 || rcond_type == mshadow::kFloat64) + << "rcond type should be float32 or float64."; + if (mshadow::kFloat32 == a_type) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64); + } + return out_attrs->at(0) != -1; +} + +DMLC_REGISTER_PARAMETER(PinvParam); + +NNVM_REGISTER_OP(_npi_pinv) +.describe(R"code()code" ADD_FILELINE) +.set_attr_parser(mxnet::op::ParamParser) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", [](const NodeAttrs& attrs){ + return std::vector{"A", "rcond"}; +}) +.set_attr("FInferShape", PinvOpShape) +.set_attr("FInferType", PinvOpType) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs){ + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FCompute", PinvOpForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("A", "NDArray-or-Symbol", "Tensor of matrix") +.add_argument("rcond", "NDArray-or-Symbol", "Cutoff for small singular values.") +.add_arguments(PinvParam::__FIELDS__()); + +bool PinvScalarRcondOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const mxnet::TShape& a_shape = (*in_attrs)[0]; + const mxnet::TShape& pinv_shape = (*out_attrs)[0]; + const int a_ndim = a_shape.ndim(); + + if (shape_is_known(a_shape)) { + // Forward shape inference. + CHECK_GE(a_ndim, 2) + << "Array must be at least two-dimensional"; + // Calculte pinv shape. + std::vector pinv_shape_vec(a_ndim, -1); + for (int i = 0; i < a_ndim - 2; ++i) { + pinv_shape_vec[i] = a_shape[i]; + } + pinv_shape_vec[a_ndim - 2] = a_shape[a_ndim - 1]; + pinv_shape_vec[a_ndim - 1] = a_shape[a_ndim - 2]; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(pinv_shape_vec.begin(), pinv_shape_vec.end())); + } else { + // Backward shape inference. 
+ if (shape_is_known(pinv_shape)) { + const int pinv_ndim = pinv_shape.ndim(); + CHECK_GE(pinv_ndim, 2) + << "Array must be at least two-dimensional"; + // Calculte 'a' shape. + std::vector a_shape_vec(pinv_ndim, -1); + for (int i = 0; i < pinv_ndim - 2; ++i) { + a_shape_vec[i] = pinv_shape[i]; + } + a_shape_vec[pinv_ndim - 2] = pinv_shape[pinv_ndim - 1]; + a_shape_vec[pinv_ndim - 1] = pinv_shape[pinv_ndim - 2]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, mxnet::TShape(a_shape_vec.begin(), a_shape_vec.end())); + } + } + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +inline bool PinvScalarRcondOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + int a_type = in_attrs->at(0); + // unsupport float16 + CHECK_NE(a_type, mshadow::kFloat16) + << "array type float16 is unsupported in linalg."; + if (mshadow::kFloat32 == a_type) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64); + } + return out_attrs->at(0) != -1; +} + +DMLC_REGISTER_PARAMETER(PinvScalarRcondParam); + +NNVM_REGISTER_OP(_npi_pinv_scalar_rcond) +.describe(R"code()code" ADD_FILELINE) +.set_attr_parser(mxnet::op::ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", [](const NodeAttrs& attrs){ + return std::vector{"A"}; +}) +.set_attr("FInferShape", PinvScalarRcondOpShape) +.set_attr("FInferType", PinvScalarRcondOpType) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs){ + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FCompute", PinvScalarRcondOpForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("A", "NDArray-or-Symbol", "Tensor of matrix") +.add_arguments(PinvScalarRcondParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/linalg/np_pinv.cu b/src/operator/numpy/linalg/np_pinv.cu new file mode 100644 index 000000000000..83ae5df6de2e --- /dev/null +++ b/src/operator/numpy/linalg/np_pinv.cu @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file np_pinv.cu + * \brief GPU implementation placeholder of Pinv Operator + */ + +#include "./np_pinv-inl.h" + +namespace mxnet { +namespace op { + +#if MXNET_USE_CUSOLVER == 1 + +NNVM_REGISTER_OP(_npi_pinv) +.set_attr("FCompute", PinvOpForward); + +NNVM_REGISTER_OP(_npi_pinv_scalar_rcond) +.set_attr("FCompute", PinvScalarRcondOpForward); + +#endif + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index c8c086ac2cbf..136913b6bfd9 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -502,6 +502,27 @@ def _add_workload_linalg_tensorsolve(): OpArgMngr.add_workload('linalg.tensorsolve', np.array(a_np, dtype=dtype), np.array(b_np, dtype=dtype), axes) +def _add_workload_linalg_pinv(): + shapes = [ + ((1, 1), ()), + ((5, 5), ()), + ((5, 6), ()), + ((6, 5), ()), + ((2, 3, 3), (1,)), + ((4, 6, 5), (4,)), + ((2, 2, 3, 4), (2, 2)), + ] + dtypes = (np.float32, np.float64) + for dtype in dtypes: + for a_shape, rcond_shape in shapes: + hermitian = False + a_np = _np.random.uniform(-10.0, 10.0, a_shape) + a_np = _np.array(a_np, dtype=dtype) + rcond_np = _np.random.uniform(0., 0.1, rcond_shape) + rcond_np = _np.array(rcond_np, dtype=dtype) + OpArgMngr.add_workload('linalg.pinv', np.array(a_np, dtype=dtype), np.array(rcond_np, dtype=dtype), hermitian) + + def _add_workload_linalg_slogdet(): OpArgMngr.add_workload('linalg.slogdet', np.array(_np.ones((2, 2)), dtype=np.float32)) OpArgMngr.add_workload('linalg.slogdet', np.array(_np.ones((0, 1, 1)), dtype=np.float64)) @@ -1690,6 +1711,7 @@ def _prepare_workloads(): _add_workload_linalg_det() _add_workload_linalg_tensorinv() _add_workload_linalg_tensorsolve() + _add_workload_linalg_pinv() _add_workload_linalg_slogdet() _add_workload_trace() _add_workload_tril() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index d952bf7f33dc..8d54a53b8bc4 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4480,6 +4480,85 @@ def newInvertibleMatrix_2D(shape, max_cond=4): check_tensorsolve(mx_out, a.asnumpy(), b.asnumpy(), axes) +@with_seed() +@use_np +def test_np_linalg_pinv(): + class TestPinv(HybridBlock): + def __init__(self, hermitian): + super(TestPinv, self).__init__() + self._hermitian = hermitian + + def hybrid_forward(self, F, a, rcond=1e-15): + return F.np.linalg.pinv(a, rcond, hermitian=self._hermitian) + + def check_pinv(x, a_np, rcond_np, hermitian, use_rcond): + try: + if use_rcond: + x_expected = _np.linalg.pinv(a_np, rcond_np, hermitian=hermitian) + else: + x_expected = _np.linalg.pinv(a_np, hermitian=hermitian) + except Exception as e: + print("a:", a_np) + print("a shape:", a_np.shape) + if use_rcond: + print("rcond_np", rcond_np) + print("b rcond_np:", rcond_np.shape) + print(e) + else: + assert x.shape == x_expected.shape + assert_almost_equal(x.asnumpy(), x_expected, rtol=rtol, atol=atol) + + shapes = [ + ((1, 1), ()), + ((5, 5), ()), + ((5, 6), ()), + ((6, 5), ()), + ((2, 3, 3), (1,)), + ((2, 3, 3), (2,)), + ((2, 3, 4), (2,)), + ((2, 4, 3), (1,)), + ((4, 5, 6), ()), + ((4, 5, 6), (1,)), + ((4, 6, 5), (4,)), + ((2, 2, 4, 3), (1,)), + ((2, 2, 4, 3), (2,)), + ((2, 2, 4, 3), (1, 1)), + ((2, 2, 4, 3), (1, 2)), + ((2, 2, 4, 3), (2, 1)), + ((2, 2, 4, 3), (2, 2)), + ((2, 2, 3, 4), (1,)), + ((2, 2, 3, 4), (2,)), + ((2, 2, 3, 4), (1, 1)), + ((2, 2, 3, 4), (1, 2)), + ((2, 2, 3, 4), 
(2, 1)), + ((2, 2, 3, 4), (2, 2)), + ] + dtypes = ['float32', 'float64'] + for dtype in dtypes: + for a_shape, rcond_shape in shapes: + for use_rcond, hybridize in itertools.product([True, False], [True, False]): + rtol = 1e-2 if dtype == 'float32' else 1e-3 + atol = 1e-4 if dtype == 'float32' else 1e-5 + hermitian = False + test_pinv = TestPinv(hermitian) + if hybridize: + test_pinv.hybridize() + + a_np = _np.random.uniform(-10.0, 10.0, a_shape) + a_np = _np.array(a_np, dtype=dtype) + rcond_np = _np.random.uniform(0., 0.1, rcond_shape) + rcond_np = _np.array(rcond_np, dtype=dtype) + a = np.array(a_np, dtype=dtype) + rcond = np.array(rcond_np, dtype=dtype) + if use_rcond: + mx_out = test_pinv(a, rcond) + else: + mx_out = test_pinv(a) + + # check tensorsolve validity + check_pinv(mx_out, a.asnumpy(), rcond.asnumpy(), hermitian, use_rcond) + + @with_seed() @use_np def test_np_linalg_det():
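The new unit test checks the operator against ``numpy.linalg.pinv`` over the shape/rcond combinations listed above. For completeness, a hedged end-to-end usage example (it assumes an MXNet build that already contains this patch) exercising both the scalar and the broadcast ``rcond`` paths:

```python
# Hedged usage example; requires an MXNet build containing this patch.
import numpy as onp
from mxnet import np, npx
npx.set_np()

a = np.array(onp.random.uniform(-10.0, 10.0, (4, 6, 5)), dtype='float64')

pinv_a = np.linalg.pinv(a, rcond=1e-15)                  # scalar rcond path
assert pinv_a.shape == (4, 5, 6)
# Moore-Penrose property: a @ a+ @ a == a (up to numerical tolerance).
assert onp.allclose(np.matmul(np.matmul(a, pinv_a), a).asnumpy(), a.asnumpy(), atol=1e-5)

rcond = np.array(onp.random.uniform(0., 0.1, (4,)), dtype='float64')
pinv_b = np.linalg.pinv(a, rcond)                        # broadcast rcond path
assert onp.allclose(pinv_b.asnumpy(),
                    onp.linalg.pinv(a.asnumpy(), rcond.asnumpy()),
                    rtol=1e-3, atol=1e-5)
```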