[Bug] DeviceQuantileDmatrix failing with memory errors on A100-80GB #6822

Closed · Fixed by #6826

ayushdg opened this issue Apr 2, 2021 · 5 comments
ayushdg commented Apr 2, 2021

On large datasets, DeviceQuantileDMatrix fails with memory errors that look like an integer overflow or pointer arithmetic bug.

Reproducer:

import cudf
import numpy as np
import xgboost as xgb

nrows=51130552
ncols=43
df = cudf.DataFrame()
for i in range(ncols-1):
    df[f"fea{i}"] = np.random.random(nrows)
    df[f"fea{i}"] = df[f"fea{i}"].astype("float32")
df['val'] = [0,1] * (nrows//2)
df['val'] = df['val'].astype("int8")

fea_cols = df.columns.difference(["val"])
dmat = xgb.DeviceQuantileDMatrix(data=df[fea_cols], label=df["val"])
del df

dxgb_gpu_params = {
    "max_depth": 8,
    "max_leaves": 2 ** 8,
    "alpha": 0.9,
    "eta": 0.1,
    "gamma": 0.1,
    "learning_rate": 0.1,
    "subsample": 1,
    "reg_lambda": 1,
    "tree_method": "gpu_hist",
    "objective": "binary:logistic",
}
NUM_BOOST_ROUND = 5


res = xgb.train(
    dxgb_gpu_params, dmat, num_boost_round=NUM_BOOST_ROUND,
)

The script above succeeds on a V100-32GB and an A100-40GB but fails on an A100-80GB.
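
As a rough sanity check (my own estimate, not part of the original report, assuming 42 float32 feature columns plus one int8 label as in the reproducer), the raw input data only needs on the order of 8-9 GB of device memory, so a genuine out-of-memory condition on any of these GPUs would be surprising:

nrows, ncols = 51_130_552, 43
feature_bytes = nrows * (ncols - 1) * 4      # 42 float32 feature columns
label_bytes = nrows * 1                      # one int8 label column
print((feature_bytes + label_bytes) / 1e9)   # ~8.64 GB of raw input data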

Stacktrace:

Traceback (most recent call last):
  File "xgb_error.py", line 19, in <module>
    dmat = xgb.DeviceQuantileDMatrix(data=df[fea_cols], label=df["val"])
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/xgboost/core.py", line 946, in __init__
    handle, feature_names, feature_types = init_device_quantile_dmatrix(
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/xgboost/data.py", line 787, in init_device_quantile_dmatrix
    _check_call(ret)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/xgboost/core.py", line 189, in _check_call
    raise XGBoostError(py_str(_LIB.XGBGetLastError()))
xgboost.core.XGBoostError: [15:44:36] /opt/conda/envs/rapids/conda-bld/xgboost_1616533850716/work/src/c_api/../data/../common/device_helpers.cuh:400: Memory allocation error on worker 0: Caching allocator
- Free memory: 58308362240
- Requested memory: 18446744073675998463

Stack trace:
  [bt] (0) /opt/conda/envs/rapids/lib/libxgboost.so(+0x14eb6f) [0x7fcf095a0b6f]
  [bt] (1) /opt/conda/envs/rapids/lib/libxgboost.so(dh::detail::ThrowOOMError(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned long)+0x3ad) [0x7fcf097da4bd]
  [bt] (2) /opt/conda/envs/rapids/lib/libxgboost.so(+0x413c76) [0x7fcf09865c76]
  [bt] (3) /opt/conda/envs/rapids/lib/libxgboost.so(thrust::detail::normal_iterator<thrust::device_ptr<xgboost::Entry> > thrust::cuda_cub::copy_if<thrust::detail::execute_with_allocator<dh::detail::XGBCachingDeviceAllocatorImpl<char>&, thrust::cuda_cub::execute_on_stream_base>, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::data::COOTuple, thrust::use_default>, xgboost::common::Range1d, float, unsigned long, unsigned long, int, xgboost::HostDeviceVector<unsigned long>*, thrust::device_vector<unsigned long, dh::detail::XGBCachingDeviceAllocatorImpl<unsigned long> >*, thrust::device_vector<xgboost::Entry, dh::detail::XGBDefaultDeviceAllocatorImpl<xgboost::Entry> >*), &(void xgboost::common::detail::MakeEntriesFromAdapter<xgboost::data::CudfAdapterBatch, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::data::COOTuple, thrust::use_default> >(xgboost::data::CudfAdapterBatch const&, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::data::COOTuple, thrust::use_default>, xgboost::common::Range1d, float, unsigned long, unsigned long, int, xgboost::HostDeviceVector<unsigned long>*, thrust::device_vector<unsigned long, dh::detail::XGBCachingDeviceAllocatorImpl<unsigned long> >*, thrust::device_vector<xgboost::Entry, dh::detail::XGBDefaultDeviceAllocatorImpl<xgboost::Entry> >*)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::Entry, thrust::use_default>, thrust::detail::normal_iterator<thrust::device_ptr<xgboost::Entry> >, xgboost::data::IsValidFunctor>(thrust::cuda_cub::execution_policy<thrust::detail::execute_with_allocator<dh::detail::XGBCachingDeviceAllocatorImpl<char>&, thrust::cuda_cub::execute_on_stream_base> >&, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void 
(*)(xgboost::data::CudfAdapterBatch const&, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::data::COOTuple, thrust::use_default>, xgboost::common::Range1d, float, unsigned long, unsigned long, int, xgboost::HostDeviceVector<unsigned long>*, thrust::device_vector<unsigned long, dh::detail::XGBCachingDeviceAllocatorImpl<unsigned long> >*, thrust::device_vector<xgboost::Entry, dh::detail::XGBDefaultDeviceAllocatorImpl<xgboost::Entry> >*), &(void xgboost::common::detail::MakeEntriesFromAdapter<xgboost::data::CudfAdapterBatch, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::data::COOTuple, thrust::use_default> >(xgboost::data::CudfAdapterBatch const&, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::data::COOTuple, thrust::use_default>, xgboost::common::Range1d, float, unsigned long, unsigned long, int, xgboost::HostDeviceVector<unsigned long>*, thrust::device_vector<unsigned long, dh::detail::XGBCachingDeviceAllocatorImpl<unsigned long> >*, thrust::device_vector<xgboost::Entry, dh::detail::XGBDefaultDeviceAllocatorImpl<xgboost::Entry> >*)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::Entry, thrust::use_default>, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, 
xgboost::data::COOTuple, thrust::use_default>, xgboost::common::Range1d, float, unsigned long, unsigned long, int, xgboost::HostDeviceVector<unsigned long>*, thrust::device_vector<unsigned long, dh::detail::XGBCachingDeviceAllocatorImpl<unsigned long> >*, thrust::device_vector<xgboost::Entry, dh::detail::XGBDefaultDeviceAllocatorImpl<xgboost::Entry> >*), &(void xgboost::common::detail::MakeEntriesFromAdapter<xgboost::data::CudfAdapterBatch, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::data::COOTuple, thrust::use_default> >(xgboost::data::CudfAdapterBatch const&, thrust::transform_iterator<__nv_dl_wrapper_t<__nv_dl_tag<void (*)(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int), &(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::data::COOTuple, thrust::use_default>, xgboost::common::Range1d, float, unsigned long, unsigned long, int, xgboost::HostDeviceVector<unsigned long>*, thrust::device_vector<unsigned long, dh::detail::XGBCachingDeviceAllocatorImpl<unsigned long> >*, thrust::device_vector<xgboost::Entry, dh::detail::XGBDefaultDeviceAllocatorImpl<xgboost::Entry> >*)), 1u>, xgboost::data::CudfAdapterBatch const>, thrust::counting_iterator<unsigned long long, thrust::use_default, thrust::use_default, thrust::use_default>, xgboost::Entry, thrust::use_default>, thrust::detail::normal_iterator<thrust::device_ptr<xgboost::Entry> >, xgboost::data::IsValidFunctor)+0x227) [0x7fcf09874f87]
  [bt] (4) /opt/conda/envs/rapids/lib/libxgboost.so(void xgboost::common::ProcessSlidingWindow<xgboost::data::CudfAdapterBatch>(xgboost::data::CudfAdapterBatch const&, int, unsigned long, unsigned long, unsigned long, float, xgboost::common::SketchContainer*, int)+0x2c5) [0x7fcf09881835]
  [bt] (5) /opt/conda/envs/rapids/lib/libxgboost.so(+0x414541) [0x7fcf09866541]
  [bt] (6) /opt/conda/envs/rapids/lib/libxgboost.so(xgboost::data::IterativeDeviceDMatrix::Initialize(void*, float, int)+0xcdc) [0x7fcf0986725c]
  [bt] (7) /opt/conda/envs/rapids/lib/libxgboost.so(xgboost::DMatrix* xgboost::DMatrix::Create<void*, void*, void (void*), int (void*)>(void*, void*, void (*)(void*), int (*)(void*), float, int, int)+0xaf) [0x7fcf0962eb9f]
  [bt] (8) /opt/conda/envs/rapids/lib/libxgboost.so(XGDeviceQuantileDMatrixCreateFromCallback+0x25) [0x7fcf095b08e5]

Additional Info

The requested memory value from the stack trace, Requested memory: 18446744073675998463, is extremely close to the unsigned 64-bit integer limit (2**64 - 1 = 18446744073709551615), which suggests a size that has wrapped around rather than a genuine allocation request.
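
For illustration (my own check, not from the original report), the gap between the request and 2**64 makes the wrap-around pattern visible:

requested = 18_446_744_073_675_998_463
print(2**64 - requested)   # 33553153: the request equals 2**64 minus a small count,
                           # consistent with a negative size being cast to an unsigned 64-bit type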

hcho3 (Collaborator) commented Apr 2, 2021

@ayushdg How did you install XGBoost? Did you build it from the source?

trivialfis (Member) commented:

Hmm... this is somewhere inside thrust. Suspecting thrust::sort.

trivialfis (Member) commented:

I just ran it with the master branch on an RTX 8000 and it works fine. What's your XGBoost version?

ayushdg (Author) commented Apr 4, 2021

@ayushdg How did you install XGBoost? Did you build it from the source?

This is the version shipped with the RAPIDS nightly containers: xgboost 1.3.3dev.rapidsai0.19.

I just ran it with the master branch on an RTX 8000 and it works fine. What's your XGBoost version?

I tested on a V100 and an A100-40GB and it worked fine on those GPUs as well. It was failing specifically on the A100-80GB (even though it has more memory). Not sure why.

trivialfis (Member) commented:

@ayushdg Thanks for sharing. I will test it on A100.
