cuml PCA shows cuda out of memory while GPU shows lots of vacancy. #4141

Closed
ShihengDuan opened this issue Aug 2, 2021 · 3 comments
Labels
0 - Blocked Cannot progress due to external reasons

@ShihengDuan

Hi all,

I'm new to cuML and want to use it to accelerate PCA on a large dataset. The input shape is (10492, 55296) and I set 150 as n_components. The GPU is a V100 with 32 GB of memory. The error is shown below.

MemoryError                               Traceback (most recent call last)
/glade/scratch/shiheng/ipykernel_138709/22359057.py in <module>
      1 import cuml
      2 
----> 3 ml_PCA = cuml.PCA(n_components=151).fit(df)
      4 np.cumsum(ml_PCA.explained_variance_ratio_*100)

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/cuml/internals/api_decorators.py in inner_with_setters(*args, **kwargs)
    407                                 target_val=target_val)
    408 
--> 409                 return func(*args, **kwargs)
    410 
    411         @wraps(func)

cuml/decomposition/pca.pyx in cuml.decomposition.pca.PCA.fit()

MemoryError: std::bad_alloc: CUDA error at: /glade/u/home/shiheng/miniconda3/envs/pytorch/include/rmm/mr/device/cuda_memory_resource.hpp:69: cudaErrorMemoryAllocation out of memory

My code is pretty simple here.

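import cudf
import cuml
import numpy as np  # needed for np.cumsum below

# psl is the dataset object, defined elsewhere; each of the 10492 samples becomes one row
df = cudf.DataFrame(psl.data.reshape(10492, -1))
ml_PCA = cuml.PCA(n_components=151).fit(df)
# cumulative explained variance, in percent
np.cumsum(ml_PCA.explained_variance_ratio_*100)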

When the error occurs, GPU memory usage is below 7 GB. Any ideas or suggestions for dealing with it?

Thanks!

@lowener
Contributor

lowener commented Aug 4, 2021

Hi @ShihengDuan.
I looked into this issue and was able to reproduce it. It is caused by the very high number of columns in your data.
Our PCA implementation uses an eigendecomposition, which requires the covariance matrix of size 55296 * 55296. That element count overflows int32, so the length computation suffers an integer overflow.
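The arithmetic behind that overflow can be checked with a short sketch. The helper `wrap_int32` is a hypothetical name introduced here to mimic 32-bit two's-complement wraparound; it is not part of cuML. The byte count assumes 4-byte float32 elements, and whether the wrapped value is exactly what reaches the allocator is an assumption the thread does not confirm.

```python
INT32_MAX = 2**31 - 1  # 2147483647


def wrap_int32(value):
    """Reduce an integer to signed 32-bit two's complement (illustrative helper)."""
    return (value + 2**31) % 2**32 - 2**31


n_cols = 55296                 # number of columns reported in the issue
true_len = n_cols * n_cols     # elements in the covariance matrix
wrapped = wrap_int32(true_len)

print(true_len)                # 3057647616
print(true_len > INT32_MAX)    # True: the count does not fit in int32
print(wrapped)                 # -1237319680
print(true_len * 4)            # 12230590464 bytes if stored as float32 (assumption)
```

A length that wraps to a negative number would explain an allocation failure unrelated to how much GPU memory is actually free, though that link is an inference rather than something stated above.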

I tried locally to fix this line:

size_t len = prms.n_cols * prms.n_cols;

but it then fails with an invalid value error in cuSOLVER:

RuntimeError: cuSOLVER error encountered at: file=~/miniconda3/envs/cuml_dev/include/raft/linalg/eig.cuh line=54: call='cusolverDnsyevd_bufferSize(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_rows, in, n_cols, eig_vals, &lwork)',
Reason=3:CUSOLVER_STATUS_INVALID_VALUE
Obtained 22 stack frames
#0 in ~/miniconda3/envs/cuml_dev/lib/libcuml++.so(_ZN4raft9exception18collect_call_stackEv+0x47) [0x7fd0fab27867]
#1 in ~/miniconda3//envs/cuml_dev/lib/libcuml++.so(_ZN4raft14cusolver_errorC2ERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE+0x6d) [0x7fd0fab7970d]
#2 in ~/miniconda3/envs/cuml_dev/lib/libcuml++.so(_ZN4raft6linalg5eigDCIfEEvRKNS_8handle_tEPKT_iiPS5_S8_P11CUstream_st+0xb59) [0x7fd0fac7c929]
#3 in ~/miniconda3//envs/cuml_dev/lib/libcuml++.so(_ZN2ML6calEigIfNS_6solverEEEvRKN4raft8handle_tEPT_S7_S7_RKNS_18paramsTSVDTemplateIT0_EEP11CUstream_st+0x94) [0x7fd0faeff984]
#4 in ~/miniconda3/envs/cuml_dev/lib/libcuml++.so(_ZN2ML16truncCompExpVarsIfNS_6solverEEEvRKN4raft8handle_tEPT_S7_S7_S7_NS_18paramsTSVDTemplateIT0_EEP11CUstream_st+0x1eb) [0x7fd0faeffd1b]

For the moment the only algorithm that can run a decomposition with this input size is cuml.IncrementalPCA. Supporting input of this size could be a future improvement.
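A minimal sketch of that workaround follows, assuming the class is imported as `cuml.decomposition.IncrementalPCA` and that `data` is a NumPy or CuPy array of shape (10492, 55296). The helpers `batch_slices` and `fit_incremental_pca` and the batch size of 2048 are illustrative choices, not values from this thread, and the GPU part has not been run against this dataset.

```python
def batch_slices(n_rows, batch_size):
    """Yield row slices that cover range(n_rows) in chunks of batch_size."""
    for start in range(0, n_rows, batch_size):
        yield slice(start, min(start + batch_size, n_rows))


def fit_incremental_pca(data, n_components=151, batch_size=2048):
    """Fit cuML's IncrementalPCA batch by batch (needs a GPU and cuML)."""
    # Imported lazily so the batching helper stays usable without cuML.
    from cuml.decomposition import IncrementalPCA

    ipca = IncrementalPCA(n_components=n_components, batch_size=batch_size)
    for rows in batch_slices(data.shape[0], batch_size):
        # Each batch should contain at least n_components rows.
        ipca.partial_fit(data[rows])
    return ipca


# Hypothetical usage:
#   ipca = fit_incremental_pca(data)
#   ipca.explained_variance_ratio_
```

With 10492 rows and a batch size of 2048, the last batch holds 252 rows, which is still above the 151 components requested. A batch size that leaves a shorter final batch would need adjusting.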

@lowener
Contributor

lowener commented Aug 16, 2021

I'm linking #1269 since it seems to be the same kind of error.
And also #2459.

@lowener lowener added the 0 - Blocked Cannot progress due to external reasons label Aug 16, 2021
@lowener
Contributor

lowener commented Dec 21, 2021

This is solved by #4255 on the cuML side, and CUDA 11.6 will solve it on the cuSOLVER side.
I'm closing the issue for now.

@lowener lowener closed this as completed Dec 21, 2021