
[BUG] RMM-only context destroyed error with Random Forest in loop #2632

Closed
beckernick opened this issue Jul 31, 2020 · 40 comments
Labels
? - Needs Triage (Need team to review and classify) · bug (Something isn't working)

Comments

@beckernick
Member

beckernick commented Jul 31, 2020

It seems we may have an RMM-only memory leak with RandomForestRegressor. This could come up in a wide range of workloads, such as using RandomForestRegressor with RMM during hyper-parameter optimization.

In the following example:

  • Without an RMM pool, repeatedly fitting the model, predicting, and deleting the model/predictions results in a peak memory footprint of 1.2 GB
  • With an RMM pool, repeatedly fitting the model, predicting, and deleting the model/predictions causes memory to grow without bound. This can be triggered by uncommenting the RMM-related lines below. After 15-17 iterations, we exhaust the entire 5 GB pool.

Is it possible there is a place where RMM isn't getting visibility of a call to free memory?

import cudf
import cuml
import rmm
import cupy as cp
from dask.utils import parse_bytes
from sklearn.datasets import make_regression

# cudf.set_allocator(pool=True, initial_pool_size=parse_bytes("5GB"))
# cp.cuda.set_allocator(rmm.rmm_cupy_allocator)

NFEATURES = 20

X, y = make_regression(
    n_samples=10000,
    n_features=NFEATURES,
    random_state=12,
)

X = X.astype("float32")
X = cp.asarray(X)
y = cp.asarray(y)

for i in range(30):
    print(i)
    clf = cuml.ensemble.RandomForestRegressor(n_estimators=50)
    clf.fit(X, y)
    preds = clf.predict(X)
    del clf, preds
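
For anyone reproducing this, here is a minimal sketch of one way to watch device memory per iteration (hypothetical, using pynvml; not necessarily how the numbers above were measured, and it reuses cuml, X, and y from the script above). Note that with an RMM pool, NVML reports the whole pool as used, so this is most informative for the no-pool case:

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

def used_gb():
    # Device-wide used memory as reported by NVML (includes the entire RMM pool, if any)
    return pynvml.nvmlDeviceGetMemoryInfo(handle).used / 1e9

for i in range(30):
    clf = cuml.ensemble.RandomForestRegressor(n_estimators=50)
    clf.fit(X, y)
    preds = clf.predict(X)
    del clf, preds
    print(i, f"{used_gb():.2f} GB used")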

Environment: 2020-07-31 nightly at ~ 9AM EDT

beckernick added the bug and ? - Needs Triage labels on Jul 31, 2020
@Salonijain27
Contributor

I am able to reproduce the results mentioned in the issue. When switching predict_model to "CPU", the memory leak persists, so it seems the error is caused by some information/variable not being cleared while fitting the model.

@beckernick
Member Author

beckernick commented Aug 3, 2020

Thanks for digging into it @Salonijain27 ! Just to clarify, do you think it's memory actually not being freed around the fit, or memory being freed that RMM thinks isn't free (perhaps because it doesn't own it / have visibility)?

@Salonijain27
Contributor

Salonijain27 commented Aug 4, 2020

Hi Nick, I think that while running the fit function the memory is being freed, but RMM thinks it isn't free.
If we run the above example and remove the predict statement, we can still see the memory leak.
Furthermore, if we just run the predict function in a loop, we do not see any leak:

import cudf
import cuml
import rmm
import cupy as cp
from dask.utils import parse_bytes
from sklearn.datasets import make_regression

cudf.set_allocator(pool=True, initial_pool_size=parse_bytes("5GB"))
# cp.cuda.set_allocator(rmm.rmm_cupy_allocator)

NFEATURES = 20

X, y = make_regression(
    n_samples=10000,
    n_features=NFEATURES,
    random_state=12,
)

X = X.astype("float32")
X = cudf.DataFrame(X)
y = cudf.Series(y)
clf = cuml.ensemble.RandomForestRegressor(n_estimators=50)
clf.fit(X, y)

for i in range(30):
    print(i)
    preds = clf.predict(X)
    del preds

Also, I would like to add that cuML used to successfully run memory-leak tests for RF until the RMM library was updated and removed support for the function:
cuda.current_context().get_memory_info()
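
For context, a rough sketch of what such a leak check can look like (this is hypothetical, not the original cuml test; check_no_leak and fit_once are made-up names):

from numba import cuda

def check_no_leak(fit_once, n_iters=5, tol_bytes=1 << 20):
    # Compare free device memory before and after several fit iterations;
    # if free memory keeps shrinking, something is leaking.
    free_before, _ = cuda.current_context().get_memory_info()
    for _ in range(n_iters):
        fit_once()
    free_after, _ = cuda.current_context().get_memory_info()
    assert free_before - free_after < tol_bytes, "GPU memory appears to be leaking"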

@Salonijain27
Contributor

Another interesting find is that when we move clf = cuml.ensemble.RandomForestRegressor(n_estimators=50) out of the for loop, we don't see the memory leak anymore.

import cudf
import cuml
import rmm
import cupy as cp
from dask.utils import parse_bytes
from sklearn.datasets import make_regression

cudf.set_allocator(pool=True, initial_pool_size=parse_bytes("2GB"))
# cp.cuda.set_allocator(rmm.rmm_cupy_allocator)

NFEATURES = 20

X, y = make_regression(
    n_samples=10000,
    n_features=NFEATURES,
    random_state=12,
)

X = X.astype("float32")
y = y.astype("float32")
X = cudf.DataFrame(X)
y = cudf.Series(y)
clf = cuml.ensemble.RandomForestRegressor(n_estimators=50)

for i in range(30):
    print(i)
    clf.fit(X, y)
    preds = clf.predict(X)
    del preds

It seems like the RF fit function is okay, but when we delete the model and free the memory, RMM thinks the memory isn't freed.

@beckernick
Member Author

beckernick commented Aug 4, 2020

Nice find @Salonijain27. I wonder if it could be the Treelite handle. From a quick test, it looks like RF Regression sets a Treelite handle and thus enters the branch in _reset_forest_data that calls TreeliteModel.free_treelite_model(self.treelite_handle), but RF Classification does not.

EDIT: Nevermind. This isn't actually true.

@Salonijain27
Contributor

This is seen with RF Classification as well:

import cudf
import cuml
import rmm
import cupy as cp
from dask.utils import parse_bytes
from sklearn.datasets import make_regression, make_classification

cudf.set_allocator(pool=True, initial_pool_size=parse_bytes("2GB"))
# cp.cuda.set_allocator(rmm.rmm_cupy_allocator)

NFEATURES = 110

X, y = make_classification(
    n_samples=10000,
    n_features=NFEATURES,
    random_state=12,
    #n_classes=20,
)

X = X.astype("float32")
y = y.astype("int32")
X = cudf.DataFrame(X)
y = cudf.Series(y)
#clf = cuml.ensemble.RandomForestClassifier(n_estimators=50)
import pdb
pdb.set_trace()
for i in range(30):
    print(i)
    clf = cuml.ensemble.RandomForestClassifier(n_estimators=50)
    clf.fit(X, y)
    preds = clf.predict(X)
    del preds

Furthermore, just as with the RF regressor, when cuml.ensemble.RandomForestClassifier(n_estimators=50) is moved outside of the for loop, the memory leak is not seen.

@JohnZed
Contributor

JohnZed commented Aug 13, 2020

@Salonijain27 to test some other similar models to see if it's a systemic error

@beckernick
Member Author

beckernick commented Aug 14, 2020

@Salonijain27 @JohnZed, in the 2020-08-14 nightly (~9 AM EDT), the repeated RF Classifier and Regressor fit tests now cause a segfault rather than simply growing memory uncontrollably.

On iteration two of the example Saloni added above, we fail. It appears that, on the second call to fit, we're not cleanly creating a new buffer:

python rfc-test.py
> /raid/nicholasb/rfc-test.py(66)<module>()
-> for i in range(30):
(Pdb) c
0
1
Traceback (most recent call last):
  File "rfc-test.py", line 66, in <module>
    for i in range(30):
  File "cuml/ensemble/randomforestclassifier.pyx", line 411, in cuml.ensemble.randomforestclassifier.RandomForestClassifier.fit
  File "/raid/nicholasb/miniconda3/envs/rapids-everywhere/lib/python3.7/site-packages/cuml/common/memory_utils.py", line 56, in cupy_rmm_wrapper
    return func(*args, **kwargs)
  File "cuml/ensemble/randomforest_common.pyx", line 239, in cuml.ensemble.randomforest_common.BaseRandomForestModel._dataset_setup_for_fit
  File "/raid/nicholasb/miniconda3/envs/rapids-everywhere/lib/python3.7/site-packages/cupy/manipulation/add_remove.py", line 66, in unique
    ar = cupy.asarray(ar).flatten()
  File "cupy/core/core.pyx", line 554, in cupy.core.core.ndarray.flatten
  File "cupy/core/core.pyx", line 566, in cupy.core.core.ndarray.flatten
  File "cupy/core/_routines_manipulation.pyx", line 180, in cupy.core._routines_manipulation._ndarray_flatten
  File "cupy/core/core.pyx", line 436, in cupy.core.core.ndarray.copy
  File "cupy/core/core.pyx", line 392, in cupy.core.core.ndarray.astype
  File "cupy/core/core.pyx", line 134, in cupy.core.core.ndarray.__init__
  File "cupy/cuda/memory.pyx", line 544, in cupy.cuda.memory.alloc
  File "/raid/nicholasb/miniconda3/envs/rapids-everywhere/lib/python3.7/site-packages/rmm/rmm.py", line 270, in rmm_cupy_allocator
    buf = librmm.device_buffer.DeviceBuffer(size=nbytes)
  File "rmm/_lib/device_buffer.pyx", line 70, in rmm._lib.device_buffer.DeviceBuffer.__cinit__
RuntimeError: CUDA error at: ../include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp365: cudaErrorContextIsDestroyed context is destroyed

The second time we run it, we get a segfault rather than the stack trace:

python rfc-test.py
> /raid/nicholasb/rfc-test.py(66)<module>()
-> for i in range(30):
(Pdb) c
0
1
Segmentation fault (core dumped)

@beckernick
Member Author

Perhaps this is related to the RMM suballocator changes? cc @jakirkham as I believe he's been exploring this a bit

beckernick changed the title from "[BUG] RMM-only memory leak with RandomForestRegressor" to "[BUG] RMM-only memory leak with Random Forest" on Aug 14, 2020
@jakirkham
Member

I would double-check that the rmm package used includes Mark's recent fix in PR rapidsai/rmm#498.

@beckernick
Member Author

beckernick commented Aug 14, 2020

Thanks @jakirkham. Looks like my RMM was behind.

This still happens in a fresh environment with RMM at d8ad1eb (the last PR merged was rapidsai/rmm#499, so rapidsai/rmm#498 is included).

conda list | grep "rmm\|cudf\|cuml\|numba\|cupy\|rapids"
# packages in environment at /raid/nicholasb/miniconda3/envs/rapids-everywhere:
cudf                      0.16.0a200814   py37_g8e4f70e5a_636    rapidsai-nightly
cuml                      0.16.0a200812   cuda10.2_py37_g4c1136ee1_229    rapidsai-nightly
cupy                      7.7.0            py37h940342b_0    conda-forge
dask-cudf                 0.16.0a200814   py37_g8e4f70e5a_636    rapidsai-nightly
faiss-proc                1.0.0                      cuda    rapidsai-nightly
libcudf                   0.16.0a200814   cuda10.2_g8e4f70e5a_636    rapidsai-nightly
libcuml                   0.16.0a200812   cuda10.2_g4c1136ee1_229    rapidsai-nightly
libcumlprims              0.15.0a200720       cuda10.2_57    rapidsai-nightly
librmm                    0.16.0a200814   cuda10.2_gd8ad1eb_194    rapidsai-nightly
libxgboost                1.1.0dev.rapidsai0.15      cuda10.2_1    rapidsai-nightly
numba                     0.50.1           py37h0da4684_1    conda-forge
py-xgboost                1.1.0dev.rapidsai0.15  cuda10.2py37_1    rapidsai-nightly
rmm                       0.16.0a200814   py37_gd8ad1eb_194    rapidsai-nightly
ucx                       1.8.1+g6b29558       ha5db111_0    rapidsai-nightly
ucx-py                    0.16.0a200814+g6b29558         py37_57    rapidsai-nightly
xgboost                   1.1.0dev.rapidsai0.15  cuda10.2py37_1    rapidsai-nightly

@jakirkham
Member

Thanks Nick! 😄

Yeah, we are still seeing some issues with UCX-Py locally (though not on CI anymore) (rapidsai/ucx-py#578), so I suspect there are still some subtle things we need to identify.

@harrism
Member

harrism commented Aug 14, 2020

Can you guys turn on logging and share the logs?

You'll want to look for allocs without frees in the logs.

@jakirkham
Member

Just to add, this involves enabling logging and specifying the log filename.

@beckernick
Member Author

beckernick commented Aug 17, 2020

Thanks @harrism. I do see allocs without frees, but only if I include the model.predict call. Regardless of whether I use just fit or both fit and predict, the context appears to be destroyed. It also appears to occur only inside a Python loop, as Saloni noted above.

So far, I've tested KNN (Reg/Clf), Random Forest (Reg/Clf) and Logistic Regression with the following script. Only Random Forest appears to have this issue.

# to run: python rmm-model-logger.py rfr-logs.txt
import sys

import cudf
import cuml
import rmm
import numpy as np


logfilename = sys.argv[1]

# swap estimator class here
clf = cuml.ensemble.RandomForestClassifier

rmm.reinitialize(
    pool_allocator=True,
    managed_memory=False,
    initial_pool_size=2e9,
    logging=True,
    devices=0,
    log_file_name=logfilename,
)

X = cudf.DataFrame({"a": range(10), "b": range(10,20)}).astype("float32")
y = cudf.Series(np.random.choice([0, 1], 10))

for i in range(30):
    print(i)
    model = clf()
    model.fit(X, y)
    preds = model.predict(X)

Logs:

import pandas as pd

df = pd.read_csv("rfc-logs.dev0.txt")
print(df.Action.value_counts())

allocate    211
free        201
Name: Action, dtype: int64

df = pd.read_csv("rfr-logs.dev0.txt")
print(df.Action.value_counts())
allocate    204
free        189
Name: Action, dtype: int64

rfr-logs.dev0.txt
rfc-logs.dev0.txt
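
To go one step further and see which allocations are never freed, something like the following should work (a sketch; I am assuming the log also has a pointer/address column, and the exact column name may differ):

import pandas as pd

# Pair allocations with frees by pointer to find blocks that are never freed.
# "Pointer" is an assumed column name; adjust to match the actual log header.
df = pd.read_csv("rfr-logs.dev0.txt")
allocs = df[df.Action == "allocate"].groupby("Pointer").size()
frees = df[df.Action == "free"].groupby("Pointer").size()
leaked = allocs.sub(frees, fill_value=0)
print(leaked[leaked > 0])  # pointers allocated more times than freed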

Interestingly, if I run the script but comment out the preds = model.predict(X) line, I still get the destroyed context but the allocs match the frees.

import pandas as pd

df = pd.read_csv("rfc-fit-only-logs.dev0.txt")
print(df.Action.value_counts())

df = pd.read_csv("rfr-fit-only-logs.dev0.txt")
print(df.Action.value_counts())

free        185
allocate    185
Name: Action, dtype: int64

free        206
allocate    206
Name: Action, dtype: int64

rfc-fit-only-logs.dev0.txt
rfr-fit-only-logs.dev0.txt

Full traceback:

python rmm-model-logger.py rfr-fit-only-logs.txt
0
1
Traceback (most recent call last):
  File "rmm-model-logger.py", line 30, in <module>
    model.fit(X, y)
  File "cuml/ensemble/randomforestregressor.pyx", line 393, in cuml.ensemble.randomforestregressor.RandomForestRegressor.fit
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cuml/common/memory_utils.py", line 56, in cupy_rmm_wrapper
    return func(*args, **kwargs)
  File "cuml/ensemble/randomforest_common.pyx", line 251, in cuml.ensemble.randomforest_common.BaseRandomForestModel._dataset_setup_for_fit
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cuml/common/memory_utils.py", line 56, in cupy_rmm_wrapper
    return func(*args, **kwargs)
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cuml/common/input_utils.py", line 188, in input_to_cuml_array
    X = convert_dtype(X, to_dtype=convert_to_dtype)
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cuml/common/memory_utils.py", line 56, in cupy_rmm_wrapper
    return func(*args, **kwargs)
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cuml/common/input_utils.py", line 459, in convert_dtype
    would_lose_info = _typecast_will_lose_information(X, to_dtype)
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cuml/common/input_utils.py", line 504, in _typecast_will_lose_information
    (X < target_dtype_range.min) |
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cudf/core/series.py", line 1537, in __lt__
    return self._binaryop(other, "lt")
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cudf/core/series.py", line 1083, in _binaryop
    outcol = lhs._column.binary_operator(fn, rhs, reflect=reflect)
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cudf/core/column/numerical.py", line 100, in binary_operator
    lhs=self, rhs=rhs, op=binop, out_dtype=out_dtype, reflect=reflect
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817/lib/python3.7/site-packages/cudf/core/column/numerical.py", line 472, in _numeric_column_binop
    out = libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype)
  File "cudf/_lib/binaryop.pyx", line 200, in cudf._lib.binaryop.binaryop
  File "cudf/_lib/scalar.pyx", line 361, in cudf._lib.scalar.as_scalar
  File "cudf/_lib/scalar.pyx", line 81, in cudf._lib.scalar.Scalar.__init__
  File "cudf/_lib/scalar.pyx", line 174, in cudf._lib.scalar._set_numeric_from_np_scalar
RuntimeError: CUDA error at: ../include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp365: cudaErrorContextIsDestroyed context is destroyed

Environment:

conda list | grep "rmm\|cudf\|cuml\|numba\|cupy\|rapids"
# packages in environment at /raid/nicholasb/miniconda3/envs/rapids-tpcxbb-20200817:
cudf                      0.15.0a200817   py37_g1778921b0_4666    rapidsai-nightly
cuml                      0.15.0a200817   cuda10.2_py37_g1e5b7d348_1979    rapidsai-nightly
cupy                      7.7.0            py37h940342b_0    conda-forge
dask-cuda                 0.15.0a200817          py37_117    rapidsai-nightly
dask-cudf                 0.15.0a200817   py37_g1778921b0_4666    rapidsai-nightly
faiss-proc                1.0.0                      cuda    rapidsai-nightly
libcudf                   0.15.0a200817   cuda10.2_g1778921b0_4666    rapidsai-nightly
libcuml                   0.15.0a200817   cuda10.2_g1e5b7d348_1979    rapidsai-nightly
libcumlprims              0.15.0a200812       cuda10.2_61    rapidsai-nightly
librmm                    0.15.0a200817   cuda10.2_g17efc89_665    rapidsai-nightly
numba                     0.50.1           py37h0da4684_1    conda-forge
rmm                       0.15.0a200817   py37_g17efc89_665    rapidsai-nightly
ucx                       1.8.1+g6b29558       cuda10.2_0    rapidsai-nightly
ucx-proc                  1.0.0                       gpu    rapidsai-nightly
ucx-py                    0.15.0a200817+g6b29558        py37_203    rapidsai-nightly

cc @jakirkham @Salonijain27

beckernick changed the title from "[BUG] RMM-only memory leak with Random Forest" to "[BUG] RMM-only context destroyed error with Random Forest in loop" on Aug 17, 2020
@JohnZed
Contributor

JohnZed commented Aug 17, 2020

Note that this code works (using the same handle for each rep):

rmm.reinitialize(pool_allocator=True,
                 initial_pool_size=2e9)

outer_h = cuml.Handle(8)
for i in range(12):
    # inner_h = cuml.Handle(8)
    h = outer_h
    rf = cuml.RandomForestClassifier(handle=h)
    rf.fit(X, y)
    print(i)

print("Done")

which seems like more evidence that the previous handle (or the allocator on it) is being accessed after it goes out of scope. I'll look through the code more and see if there is any possibility of the allocator being cached by accident somewhere...

@harrism
Member

harrism commented Aug 18, 2020

clf = cuml.ensemble.RandomForestClassifier

rmm.reinitialize(
    pool_allocator=True,
    managed_memory=False,
    initial_pool_size=2e9,
    logging=True,
    devices=0,
    log_file_name=logfilename,
)

Does cuml.ensemble.RandomForestClassifier happen to use the default RMM resource to allocate device memory? If so, that memory may be orphaned or invalid after calling rmm.reinitialize.

What happens if you swap the order of these two statements?

@beckernick
Member Author


I believe swapping the order shouldn't matter, as the clf = line is just aliasing the constructor; it isn't called until further down with model = clf(), after the reinitialize.
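
In other words, just to illustrate the ordering (a small sketch reusing the same setup as the logging script above):

import cuml
import rmm

clf = cuml.ensemble.RandomForestClassifier   # alias only; nothing is constructed or allocated here

rmm.reinitialize(
    pool_allocator=True,
    initial_pool_size=2e9,
)

model = clf()   # the constructor actually runs here, after reinitialize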

@harrism
Member

harrism commented Aug 18, 2020

Ah, didn't catch that before. Thanks.

@JohnZed
Contributor

JohnZed commented Aug 18, 2020

Another oddity (as seen in the backtrace above) is that the error typically occurs in cupy code when cupy is attempting to create a DeviceBuffer via RMM. Still, that would be consistent with RF somehow deleting or corrupting the allocator, I suppose.

@harrism
Member

harrism commented Aug 18, 2020

Is this a single-process, single-threaded example? Or are there multiple processes? cudaErrorContextIsDestroyed is not an error I see very often -- when I have seen it, it was usually due to forking a new process after the CUDA context was already created, so the child process did not get a context.

@dantegd
Member

dantegd commented Aug 18, 2020

@harrism it would be a single-process, single-threaded example AFAIK. Another interesting finding: if we set the number of CUDA streams in RF to 1, the failure seems to go away completely (at least for me locally):

import sys

import cudf
import cuml
import rmm
import numpy as np


logfilename = sys.argv[1]

# swap estimator class here
clf = cuml.ensemble.RandomForestRegressor

rmm.reinitialize(
    pool_allocator=True,
    managed_memory=False,
    initial_pool_size=2e9,
    logging=True,
    devices=0,
    log_file_name=logfilename,
)

X = cudf.DataFrame({"a": range(10), "b": range(10,20)}).astype("float32")
y = cudf.Series(np.random.choice([0, 1], 10))

for i in range(100):
    model = clf(n_streams=1) # 2 or more and third iteration fails at fit
    print(i)
    model.fit(X, y)
    preds = model.predict(X, predict_model="GPU")

If I set it to anything greater than 1, the failures come back. Also, depending on whether you perform predict or change the order of things (for example, creating a cuml handle in the loop), you get either a bad_alloc, an invalid CUDA resource, or a context-is-destroyed error, so it's hard to read into the specific error yet.

@dantegd
Member

dantegd commented Aug 18, 2020

Adding to my comment above, RF is the only model that uses multiple streams if I remember correctly, so that would explain why the rest of the models don't show this.

@beckernick
Member Author

Interesting finding, @dantegd. Could this possibly be related to stream order?

@dantegd
Member

dantegd commented Aug 18, 2020

It definitely could, particularly because you can make it fail at different places with different errors depending on whether you call predict or not, call predict multiple times, or shuffle the commands above.

@beckernick
Member Author

beckernick commented Aug 18, 2020

With two streams, I'm not seeing failures consistently at the same place. Sometimes I get to the 2nd iteration, sometimes the 3rd iteration before the failure.

EDIT: You're faster :)

@JohnZed
Contributor

JohnZed commented Aug 18, 2020

Unfortunately, I still get failures with 1 stream. Running on CUDA 11:

import rmm
import cuml
import cudf
import numpy as np
import time
import cupy as cp

rmm.reinitialize(pool_allocator=True,
                 initial_pool_size=2e9)

X = cudf.DataFrame({"a": range(10), "b": range(10,20)}).astype("float32")
y = cudf.Series(np.random.choice([0, 1], 10)).astype(np.int32)


for i in range(12):
    inner_h = cuml.Handle(1)
    rf = cuml.RandomForestClassifier(handle=inner_h, n_streams=1)

    print("fit")
    rf.fit(X, y)
    print(i)

print("Done")

@JohnZed
Contributor

JohnZed commented Aug 18, 2020

I rolled back to an RMM from Aug 11 (before the cnmem deprecation). With that version, I can't repro this crash at all, even when adding back the predict. RMM was built at commit c3e3fe60672ec22b6219dd45bbad4b72d195ac10 (HEAD -> fix-debug-build), checked out from PR 479. Code:

import rmm
import cuml
import cudf
import numpy as np
import time
import cupy as cp

rmm.reinitialize(pool_allocator=True,
                 initial_pool_size=2e9)

X = cudf.DataFrame({"a": range(10), "b": range(10,20)}).astype("float32")
y = cudf.Series(np.random.choice([0, 1], 10)).astype(np.int32)

for i in range(20):
    inner_h = cuml.Handle(4)
    rf = cuml.RandomForestClassifier(handle=inner_h)

    print("fit")
    rf.fit(X, y)

    print("predict")
    y_out = rf.predict(X)
    print(i)

print("Done")

@JohnZed
Contributor

JohnZed commented Aug 18, 2020

Running @dantegd's code but with the older RMM and 8 streams, that test passes.

@dantegd
Member

dantegd commented Aug 18, 2020

Adding to the rabbit hole, with the latest RMM this loop passes:

for i in range(1000):
    model = cuml.RandomForestRegressor(n_streams=1)
    print(i)
    model.fit(X, y)
    preds = model.predict(X, predict_model="GPU")

but if we remove the predict it fails:

for i in range(12):
    model = cuml.RandomForestRegressor(n_streams=1)
    print(i)
    model.fit(X, y)

@beckernick
Member Author

beckernick commented Aug 18, 2020

We should also be careful to define what "passing" means. Does passing mean "doesn't error", or does it mean "doesn't error and doesn't leak memory"? Based on the earlier comments, before the cnmem deprecation these were not erroring, but they were leaking memory.

@dantegd
Member

dantegd commented Aug 18, 2020

Passing == works perfectly fine, with no memory leak that I can see.

@harrism
Member

harrism commented Aug 18, 2020

OK, sounds like there could be an issue with the stream free lists in the new pool resource. I will look later today. Anything you can do to make the repro as simple as possible will help.

@JohnZed
Contributor

JohnZed commented Aug 18, 2020

We're likely to allocate memory on one stream, then kill that stream, and then go back and allocate more from another stream. Does this trigger an issue in stream_ordered_memory_resource.hpp when it goes to do get_block_from_other_stream and synchronizes the now-dead stream that allocated the memory at the beginning? Maybe there is already something that handles this case...

The RF code will create a bunch of streams and allocate on each one, then kill all of them when the cuml handle gets cleaned up.

@JohnZed
Contributor

JohnZed commented Aug 18, 2020

Apologies that I don't know much about expected semantics in RMM, so maybe I'm misunderstanding the expected behavior, but a gtest like this in RMM fails:

  • Allocate device buffers in a loop, each on a different stream, all from a pool
  • Destroy the streams
  • Allocate some memory on the default stream

This is similar to the random forest example.

diff --git a/tests/mr/device/pool_mr_tests.cpp b/tests/mr/device/pool_mr_tests.cpp
index f28bcbd..db8abd5 100644
--- a/tests/mr/device/pool_mr_tests.cpp
+++ b/tests/mr/device/pool_mr_tests.cpp
@@ -19,6 +19,7 @@
 #include <rmm/mr/device/device_memory_resource.hpp>
 #include <rmm/mr/device/per_device_resource.hpp>
 #include <rmm/mr/device/pool_memory_resource.hpp>
+#include <rmm/device_buffer.hpp>
 
 #include <gtest/gtest.h>
 
@@ -69,6 +70,24 @@ TEST(PoolTest, ForceGrowth)
   EXPECT_NO_THROW(mr.allocate(1000));
 }
 
+TEST(PoolTest, MultiStream)
+{
+  Pool mr{rmm::mr::get_current_device_resource(), 0};
+  const int n_streams = 8;
+  cudaStream_t streams[n_streams];
+  const int size = 10000;
+  
+  for (int i=0; i < n_streams; i++) {
+    cudaStreamCreate(&streams[i]);
+    EXPECT_NO_THROW(rmm::device_buffer buff(size, streams[i], &mr));
+  }
+  for (int i=0; i < n_streams; i++) {
+    cudaStreamDestroy(streams[i]);
+  }
+  
+  EXPECT_NO_THROW(mr.allocate(size));
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace rmm

@dantegd
Member

dantegd commented Aug 18, 2020

Note that even though the error(s) happen at different places when running the code, the most common place where I can reproduce the error is around the get_block_from_other_stream function in stream_ordered_memory_resource, particularly:

https://github.com/rapidsai/rmm/blob/7e5a65e96df9a58974a899b1761b00c027eb32d7/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp#L365

@JohnZed
Contributor

JohnZed commented Aug 18, 2020

My little hack to get that new RMM test to pass: don't throw an exception if the error is cudaErrorInvalidResourceHandle, which indicates that the stream has been destroyed. It seems like it would be much nicer if we had a cudaStreamIsDestroyed API or something similar, but I don't think there's a better way to see if a stream is invalid than just doing something with it and looking for this error?

diff --git a/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp b/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp
index f7ffff6..b7735ec 100644
--- a/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp
+++ b/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp
@@ -354,16 +354,23 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
         block_type const b = blocks.get_block(size);  // get the best fit block
 
         if (b.is_valid()) {
+          cudaError_t err;
+          
           // Since we found a block associated with a different stream, we have to insert a wait
           // on the stream's associated event into the allocating stream.
           // TODO: could eliminate this ifdef and have the same behavior for PTDS and non-PTDS
           // But the cudaEventRecord() on every free_block reduces performance significantly
 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
-          RMM_CUDA_TRY(cudaStreamWaitEvent(stream_event.stream, blocks_event.event, 0));
+          err = cudaStreamWaitEvent(stream_event.stream, blocks_event.event, 0);
 
 #else
-          RMM_CUDA_TRY(cudaStreamSynchronize(blocks_event.stream));
+          err = cudaStreamSynchronize(blocks_event.stream);
 #endif
+          if (err == cudaErrorInvalidResourceHandle) {
+            // ignore the issue, may represent an already-defunct stream
+          } else {
+            RMM_CUDA_TRY(err);
+          }
           // Move all the blocks to the requesting stream, since it has waited on them
           stream_free_blocks_[stream_event].insert(std::move(blocks));
           stream_free_blocks_.erase(s);
diff --git a/tests/device_buffer_tests.cu b/tests/device_buffer_tests.cu
index 920e70a..4e6ccb0 100644
--- a/tests/device_buffer_tests.cu
+++ b/tests/device_buffer_tests.cu
@@ -483,3 +483,4 @@ TYPED_TEST(DeviceBufferTest, ResizeBigger)
   // Resizing bigger means the data should point to a new allocation
   EXPECT_NE(old_data, buff.data());
 }
+

@jrhemstad
Contributor

Thanks for the repro @JohnZed. I was able to simplify it even further. This repro will actually segfault.

TEST(PoolTest, TwoStreams)
{
  Pool mr{rmm::mr::get_current_device_resource(), 0};
  cudaStream_t stream;
  const int size = 10000;
  cudaStreamCreate(&stream);
  EXPECT_NO_THROW(rmm::device_buffer buff(size, stream, &mr));
  cudaStreamDestroy(stream);
  mr.allocate(size);
}

As you identified, when we try to reclaim a block from another stream, we attempt to synchronize a stream that was already destroyed. Unfortunately, this isn't guaranteed to return a cudaErrorInvalidResourceHandle and can actually segfault.

@harrism
Member

harrism commented Aug 19, 2020

Thanks @JohnZed and @jrhemstad for working to find gtest repros. Soooo much easier than Python repros. 👍

@JohnZed
Contributor

JohnZed commented Aug 20, 2020

Fixed via PR 510 to the RMM repo.

JohnZed closed this as completed on Aug 20, 2020
rapids-bot bot pushed a commit that referenced this issue Sep 8, 2022
- Fix for MNMG TSVD (similar issue to [cudaErrorContextIsDestroyed in RandomForest](#2632 (comment)))
- #4826
- MNMG Kmeans testing issue : modification of accuracy threshold
- MNMG KNNRegressor testing issue : modification of input for testing
- LabelEncoder documentation test issue : modification of pandas/cuDF display configuration
- RandomForest testing issue : adjust number of estimators to the number of workers

Authors:
  - Victor Lafargue (https://github.com/viclafargue)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #4840
jakirkham pushed a commit to jakirkham/cuml that referenced this issue Feb 27, 2023