[BUG] RMM-only context destroyed error with Random Forest in loop #2632
Comments
I am able to reproduce the results mentioned in the issue. On switching the predict model to "CPU", the memory leak persists. Therefore, it seems the error is caused by some information/variable not being cleared while fitting the model. |
Thanks for digging into it @Salonijain27 ! Just to clarify, do you think it's memory actually not being freed around the |
Hi Nick, I think that while running the fit function the memory is being freed, but RMM thinks the memory isn't free.
Also, I would like to add that cuml used to successfully pass memory leak tests for RF until the RMM library got updated and removed support for the function: |
Another interesting find is that on moving
It seems like the RF fit function is okay, but when we delete the model and free the memory, RMM thinks the memory isn't freed. |
Nice find @Salonijain27 . EDIT: Nevermind. This isn't actually true. |
This is seen with RF Classification as well:
Furthermore, just as seen in RF regressor, on moving the |
@Salonijain27 to test some other similar models to see if it's a systemic error |
@Salonijain27 @JohnZed , in the 2020-08-14 nightly (~9 AM EDT) the repeated RF Classifier and Regressor fit tests now cause a segfault rather than simply grow memory uncontrollably. On iteration two of the example Saloni added above, we fail. It appears that, on the second call to
The second time we segfault rather than get the stacktrace.
|
Perhaps this is related to the RMM suballocator changes? cc @jakirkham as I believe he's been exploring this a bit |
Would double check that the |
Thanks @jakirkham . Looks like my RMM was behind. This still happens in a fresh environment with RMM at
|
Thanks Nick! 😄 Yeah we are still seeing some issues with UCX-Py locally (though not on CI any more) ( rapidsai/ucx-py#578 ). So I suspect there are still some subtle things we need to identify. |
Can you guys turn on logging and share the logs? You'll want to look for allocs without frees in the logs. |
Just to add, this involves enabling logging and specifying the log filename. |
Thanks @harrism . I do see allocs without frees, but only if I include the predict.

So far, I've tested KNN (Reg/Clf), Random Forest (Reg/Clf), and Logistic Regression with the following script. Only Random Forest appears to have this issue.

# to run: python rmm-model-logger.py rfr-logs.txt
import sys

import cudf
import cuml
import rmm
import numpy as np

logfilename = sys.argv[1]

# swap estimator class here
clf = cuml.ensemble.RandomForestClassifier

rmm.reinitialize(
    pool_allocator=True,
    managed_memory=False,
    initial_pool_size=2e9,
    logging=True,
    devices=0,
    log_file_name=logfilename,
)

X = cudf.DataFrame({"a": range(10), "b": range(10, 20)}).astype("float32")
y = cudf.Series(np.random.choice([0, 1], 10))

for i in range(30):
    print(i)
    model = clf()
    model.fit(X, y)
    preds = model.predict(X)

Logs:

import pandas as pd

df = pd.read_csv("rfc-logs.dev0.txt")
print(df.Action.value_counts())
allocate    211
free        201
Name: Action, dtype: int64

df = pd.read_csv("rfr-logs.dev0.txt")
print(df.Action.value_counts())
allocate    204
free        189
Name: Action, dtype: int64

rfr-logs.dev0.txt

Interestingly, if I run the script but comment out the predict, the allocate and free counts balance:

import pandas as pd

df = pd.read_csv("rfc-fit-only-logs.dev0.txt")
print(df.Action.value_counts())
free        185
allocate    185
Name: Action, dtype: int64

df = pd.read_csv("rfr-fit-only-logs.dev0.txt")
print(df.Action.value_counts())
free        206
allocate    206
Name: Action, dtype: int64

rfc-fit-only-logs.dev0.txt

Full traceback:
Environment:
|
Note that this code works (using the same handle for each rep):
which seems like more evidence that the previous handle (or the allocator on it) is being accessed after it goes out of scope. I'll look through the code more and see if there is any possibility of the allocator being cached by accident somewhere... |
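For illustration, a minimal sketch of the handle-reuse pattern described above. This is an assumed reconstruction, not the author's exact code: it assumes the handle class is exposed as cuml.Handle (the import location varies across cuml versions) and that the estimator accepts a handle= constructor argument, as cuml estimators generally do.

import cudf
import cuml
import numpy as np

X = cudf.DataFrame({"a": range(10), "b": range(10, 20)}).astype("float32")
y = cudf.Series(np.random.choice([0, 1], 10)).astype("float32")

# A single handle (and therefore a single set of internal CUDA streams) is
# shared by every rep, instead of each model creating and later destroying
# its own handle and streams.
handle = cuml.Handle()

for i in range(30):
    model = cuml.ensemble.RandomForestRegressor(handle=handle)
    model.fit(X, y)
    preds = model.predict(X, predict_model="GPU")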
Does What happens if you swap the order of these two statements? |
I believe swapping the order shouldn't matter as the |
Ah, didn't catch that before. Thanks. |
Another oddity (as seen in the backtrace above) is that the error typically occurs in cupy code when cupy is attempting to create a DeviceBuffer via RMM. It would still be consistent with RF somehow deleting or corrupting the allocator though, I suppose. |
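For context, cupy ends up allocating through RMM via an allocator hook along these lines. This is a sketch of the standard rmm_cupy_allocator hookup, not necessarily the exact call site cuml uses:

import cupy as cp
import rmm

# Route cupy's device allocations through RMM so cupy-created buffers come
# from the same RMM pool the rest of cuml is using.
cp.cuda.set_allocator(rmm.rmm_cupy_allocator)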
Is this a single-process, single-threaded example? Or are there multiple processes? cudaErrorContextIsDestroyed is not an error I see very often -- when I have seen it, it was usually due to forking a new process after the CUDA context was already created, so the child process did not get a context. |
@harrism it would be a single-process, single-threaded example AFAIK. Another interesting finding: if we set the number of CUDA streams in RF to 1, then the failure seems to completely go away (at least for me locally):

import sys

import cudf
import cuml
import rmm
import numpy as np

logfilename = sys.argv[1]

# swap estimator class here
clf = cuml.ensemble.RandomForestRegressor

rmm.reinitialize(
    pool_allocator=True,
    managed_memory=False,
    initial_pool_size=2e9,
    logging=True,
    devices=0,
    log_file_name=logfilename,
)

X = cudf.DataFrame({"a": range(10), "b": range(10, 20)}).astype("float32")
y = cudf.Series(np.random.choice([0, 1], 10))

for i in range(100):
    model = clf(n_streams=1)  # 2 or more and the third iteration fails at fit
    print(i)
    model.fit(X, y)
    preds = model.predict(X, predict_model="GPU")

If I set it to anything greater than 1, the failures come back. Also, depending on whether you perform predict or change the order of things (for example, creating a cuml handle in the loop), you get either a bad_alloc, an invalid cuda resource, or a context-is-destroyed error, so it's hard to read into the specific error yet. |
Adding to my comment above, RF is the only model that uses multiple streams if I remember correctly, so that would explain why the rest of the models don't show this. |
Interesting finding @dantegd . Could this possibly be related to stream order? |
It definitely could, particularly because you can make it fail at different parts with different errors depending on whether you do predict or not, doing multiple predicts, and shuffling the commands above. |
With two streams, I'm not seeing failures consistently at the same place. Sometimes I get to the 2nd iteration, sometimes the 3rd iteration before the failure. EDIT: You're faster :) |
Unfortunately, I still get failures with 1 stream. Running on CUDA 11:
|
I rolled back to an RMM from Aug 11 (before the cnmem deprecation). With that version, I can't repro this crash at all, even when adding back the predict. RMM built at
|
Running @dantegd 's code but with the older RMM and 8 streams, that test passes. |
Adding to the rabbit hole, with the latest RMM this loop passes:

for i in range(1000):
    model = cuml.RandomForestRegressor(n_streams=1)
    print(i)
    model.fit(X, y)
    preds = model.predict(X, predict_model="GPU")

but if we remove the predict, it fails:

for i in range(12):
    model = cuml.RandomForestRegressor(n_streams=1)
    print(i)
    model.fit(X, y) |
We should also be careful to define what "passing" means. Does passing mean "doesn't error", or does passing mean "doesn't error and doesn't leak memory"? Based on the earlier comments, before the cnmem deprecation these runs were not erroring but were leaking memory. |
passing == works perfectly fine, with no memory leak that I can see |
OK sounds like there could be an issue with stream free lists in the new pool resource. Will look later today. Anything you can do to make the repro as simple as possible will help. |
We're likely to allocate memory from one stream, then kill that stream, and then go back and allocate more from another stream. Does this trigger an issue in stream_ordered_memory_resource.hpp when it goes to do that reclaim? The RF code will create a bunch of streams and allocate on each one, then kill all of them when the cuml handle gets cleaned up. |
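A rough Python-level sketch of the sequence described above. This is hypothetical and only illustrates the pattern (allocate on a temporary stream, destroy that stream, then allocate again from the same pool on a different stream), not the actual C++ path RF takes:

import cupy as cp
import rmm

rmm.reinitialize(pool_allocator=True, initial_pool_size=2e9)
cp.cuda.set_allocator(rmm.rmm_cupy_allocator)

# Allocate and free a block while a temporary stream is current; the freed
# block is recorded against that stream inside the pool.
stream = cp.cuda.Stream(non_blocking=True)
with stream:
    a = cp.zeros(1_000_000, dtype="float32")
    del a
del stream  # the underlying cudaStream_t is destroyed when the object is collected

# A later allocation on a different stream (here, the default stream) may need
# to reclaim that block from the now-destroyed stream's free list.
b = cp.zeros(2_000_000, dtype="float32")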
Apologies that I don't know much about expected semantics in RMM, so maybe I'm misunderstanding the expected behavior, but a gtest like this in RMM fails:
This is similar to the random forest example.
|
Note that even if the error(s) happen at different places when calling the code, the most common place where I am reproducing the error is around the |
My little hack to get that new RMM test to pass -- don't throw an exception if the error is a cudaErrorInvalidResourceHandle, which indicates that the stream has been destroyed. It seems like it would be much nicer if we had a cudaStreamIsDestroyed API or something similar, but I don't think there's a better way to see if a stream is invalid than just doing something with it and looking for this error?
|
Thanks for the repro @JohnZed. I was able to simplify it even further. This repro will actually segfault.
As you identified, when we try and reclaim a block from another stream, we attempt to synchronize a stream that was already destroyed. Unfortunately this isn't guaranteed to return a |
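A hypothetical Python snippet, just to illustrate why this is hard to detect: touching a destroyed stream handle is undefined behavior, and the runtime is not guaranteed to report an error for it (stream pointer and call pattern here are for illustration only, not RMM's code):

import cupy as cp
from cupy.cuda import runtime

# Create a raw stream, destroy it, then try to synchronize the stale handle.
# Whether an error is raised here is not guaranteed, which is the crux of the
# problem described above.
ptr = runtime.streamCreate()
runtime.streamDestroy(ptr)
try:
    runtime.streamSynchronize(ptr)
    print("no error reported for the destroyed stream")
except runtime.CUDARuntimeError as e:
    print("error:", e)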
Thanks @JohnZed and @jrhemstad for working to find gtest repros. Soooo much easier than Python repros. 👍 |
Fixed via PR 510 to the RMM repo. |
- Fix for MNMG TSVD (similar issue to [cudaErrorContextIsDestroyed in RandomForest](#2632 (comment))) - #4826
- MNMG Kmeans testing issue: modification of accuracy threshold
- MNMG KNNRegressor testing issue: modification of input for testing
- LabelEncoder documentation test issue: modification of pandas/cuDF display configuration
- RandomForest testing issue: adjust number of estimators to the number of workers

Authors:
- Victor Lafargue (https://github.com/viclafargue)

Approvers:
- Dante Gama Dessavre (https://github.com/dantegd)

URL: #4840
It seems we may have an RMM-only memory leak with RandomForestRegressor. This could come up in a wide range of workloads, such as using RandomForestRegressor with RMM during hyper-parameter optimization.
In the following example:
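The example itself wasn't captured above; a hypothetical sketch of the kind of loop being described (RandomForestRegressor fit/predict repeated with an RMM pool enabled, with toy data chosen here purely for illustration) would be:

import cudf
import cuml
import numpy as np
import rmm

# Enable an RMM pool, as in the reports above.
rmm.reinitialize(pool_allocator=True, initial_pool_size=2e9)

X = cudf.DataFrame({"a": range(10), "b": range(10, 20)}).astype("float32")
y = cudf.Series(np.random.rand(10)).astype("float32")

for i in range(100):
    model = cuml.ensemble.RandomForestRegressor()
    model.fit(X, y)
    preds = model.predict(X, predict_model="GPU")
    # With the pool active, GPU memory use grows across iterations even though
    # each model goes out of scope and should release its allocations.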
Is it possible there is a place where RMM isn't getting visibility of a call to free memory?
Environment: 2020-07-31 nightly at ~ 9AM EDT