Segfault in multithreaded Cython code in memoryview __dealloc__ in scikit-learn #50

Closed
ogrisel opened this issue Apr 14, 2022 · 6 comments

@ogrisel

ogrisel commented Apr 14, 2022

Here is the original reproducer:

pip install cython numpy scipy pytest
git clone https://github.com/scikit-learn/scikit-learn
cd scikit-learn
python setup.py develop
gdb --ex r --args python -m pytest -svlx -k "test_parallel[RandomForestClassifier]" sklearn/ensemble/tests/test_forest.py
Fatal Python error: Aborted

Stack (most recent call first):
  File "/home/ogrisel/code/scikit-learn/sklearn/tree/_classes.py", line 964 in fit
  File "/home/ogrisel/code/scikit-learn/sklearn/ensemble/_forest.py", line 189 in _parallel_build_trees
  File "/home/ogrisel/code/scikit-learn/sklearn/utils/fixes.py", line 117 in __call__
  File "/home/ogrisel/nogil-venv/lib/python3.9/site-packages/joblib/parallel.py", line 262 in <listcomp>
  File "/home/ogrisel/nogil-venv/lib/python3.9/site-packages/joblib/parallel.py", line 262 in __call__
  File "/home/ogrisel/nogil-venv/lib/python3.9/site-packages/joblib/_parallel_backends.py", line 595 in __call__
  File "/home/ogrisel/code/nogil/Lib/multiprocessing/pool.py", line 125 in worker
  File "/home/ogrisel/code/nogil/Lib/threading.py", line 886 in run
  File "/home/ogrisel/code/nogil/Lib/threading.py", line 935 in _bootstrap_inner
  File "/home/ogrisel/code/nogil/Lib/threading.py", line 906 in _bootstrap

Note that in this code, joblib is using simple Python-level threads instead of Python worker processes.
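
The Python traceback above shows the worker running in `multiprocessing/pool.py` underneath `threading.py`, which is a thread pool rather than a set of worker processes. A minimal stdlib-only sketch of that execution model follows; the names `describe_worker` and `run_in_threads` are illustrative and are not part of joblib or scikit-learn:

```python
import os
import threading
from multiprocessing.pool import ThreadPool


def describe_worker(task_id):
    """Return where a task ran: (task id, process id, thread name)."""
    return task_id, os.getpid(), threading.current_thread().name


def run_in_threads(n_tasks=8, n_threads=4):
    # ThreadPool runs tasks in threads of the current process, so every
    # task shares the same address space and the same C-level global state.
    with ThreadPool(processes=n_threads) as pool:
        return pool.map(describe_worker, range(n_tasks))


if __name__ == "__main__":
    for task_id, pid, thread_name in run_in_threads():
        print(task_id, pid, thread_name)
```

Every task reports the same process id, which is why any unsynchronized global state in an extension module is shared by all workers once the GIL no longer serializes them.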

The crash produces one of the following two GDB backtraces:

#0  __pthread_kill_implementation (no_tid=0, signo=6, threadid=140737175123520) at pthread_kill.c:44
#1  __pthread_kill_internal (signo=6, threadid=140737175123520) at pthread_kill.c:80
#2  __GI___pthread_kill (threadid=140737175123520, signo=signo@entry=6) at pthread_kill.c:91
#3  0x00007ffff7ce6476 in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007ffff7ccc7b7 in __GI_abort () at abort.c:79
#5  0x00007ffff7d2d606 in __libc_message (action=action@entry=do_abort, fmt=fmt@entry=0x7ffff7e7f13d "%s\n") at ../sysdeps/posix/libc_fatal.c:155
#6  0x00007ffff7d44afc in malloc_printerr (str=str@entry=0x7ffff7e81d18 "double free or corruption (fasttop)") at malloc.c:5543
#7  0x00007ffff7d463bb in _int_free (av=0x7ffff7ebdc60 <main_arena>, p=0x5555559e8c80, have_lock=0) at malloc.c:4426
#8  0x00007ffff7d48d05 in __GI___libc_free (mem=<optimized out>) at malloc.c:3278
#9  0x00007ffff0fbb0e2 in __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__ (__pyx_v_self=0x44e5aa10790) at sklearn/tree/_tree.cpp:28305
#10 __pyx_memoryview___dealloc__ (__pyx_v_self=<sklearn.tree._tree.memoryview at remote 0x44e5aa10790>) at sklearn/tree/_tree.cpp:28112
#11 __pyx_tp_dealloc_memoryview (o=<sklearn.tree._tree.memoryview at remote 0x44e5aa10790>) at sklearn/tree/_tree.cpp:39929
#12 0x00007ffff1070177 in _Py_DECREF (op=<optimized out>) at /home/ogrisel/code/nogil/Include/object.h:569
#13 __Pyx_XDEC_MEMVIEW (lineno=24909, have_gil=1, memslice=0x44e5aa00520) at sklearn/tree/_criterion.c:28987
#14 __pyx_tp_dealloc_7sklearn_4tree_10_criterion_Criterion (o=<sklearn.tree._criterion.Gini at remote 0x44e5aa00500>) at sklearn/tree/_criterion.c:24909
#15 __pyx_tp_dealloc_7sklearn_4tree_10_criterion_ClassificationCriterion (o=<sklearn.tree._criterion.Gini at remote 0x44e5aa00500>) at sklearn/tree/_criterion.c:25026

On another run I got the following backtrace:

#0  __pthread_kill_implementation (no_tid=0, signo=6, threadid=140737191908928) at pthread_kill.c:44
#1  __pthread_kill_internal (signo=6, threadid=140737191908928) at pthread_kill.c:80
#2  __GI___pthread_kill (threadid=140737191908928, signo=signo@entry=6) at pthread_kill.c:91
#3  0x00007ffff7ce6476 in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007ffff7ccc7b7 in __GI_abort () at abort.c:79
#5  0x00007ffff7d2d606 in __libc_message (action=action@entry=do_abort, fmt=fmt@entry=0x7ffff7e7f13d "%s\n") at ../sysdeps/posix/libc_fatal.c:155
#6  0x00007ffff7d44afc in malloc_printerr (str=str@entry=0x7ffff7e81cc0 "free(): double free detected in tcache 2") at malloc.c:5543
#7  0x00007ffff7d46a4f in _int_free (av=0x7ffff7ebdc60 <main_arena>, p=0x5555559e8520, have_lock=0) at malloc.c:4360
#8  0x00007ffff7d48d05 in __GI___libc_free (mem=<optimized out>) at malloc.c:3278
#9  0x00007ffff102f352 in __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__ (__pyx_v_self=0x4fb1d220850) at sklearn/tree/_splitter.c:16407
#10 __pyx_memoryview___dealloc__ (__pyx_v_self=<sklearn.tree._splitter.memoryview at remote 0x4fb1d220850>) at sklearn/tree/_splitter.c:16214
#11 __pyx_tp_dealloc_memoryview (o=<sklearn.tree._splitter.memoryview at remote 0x4fb1d220850>) at sklearn/tree/_splitter.c:27573
#12 0x00007ffff10318b5 in _Py_DECREF (op=<optimized out>) at /home/ogrisel/code/nogil/Include/object.h:569
#13 __Pyx_XDEC_MEMVIEW (lineno=26682, have_gil=1, memslice=0x4fb1d230400) at sklearn/tree/_splitter.c:29899
#14 __pyx_tp_dealloc_7sklearn_4tree_9_splitter_BaseDenseSplitter (o=<sklearn.tree._splitter.BestSplitter at remote 0x4fb1d230290>) at sklearn/tree/_splitter.c:26682
#15 0x000055555581315c in _PyEval_Fast (ts=0x555555ad8540, initial_acc=Register(as_int64 = 8), initial_pc=0x17 <error: Cannot access memory at address 0x17>) at Python/ceval.c:1125

So in both cases the crash occurs in the __dealloc__ method of a Cython-managed memoryview, but in two different Cython-generated extension modules of scikit-learn.

Note: I installed the nogil wheels for numpy, scipy and Cython with pip:

  • Cython 0.29.26
  • NumPy 1.22.3
  • SciPy 1.7.1
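
When comparing environments it can help to report the installed versions programmatically. A small sketch using only the standard library; the helper name `installed_versions` is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(["Cython", "numpy", "scipy"]).items():
        print(f"{name}: {ver or 'not installed'}")
```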

If you want, I can spend some time crafting a minimal reproducer that uses only Cython (and probably NumPy), without scikit-learn.

@colesbury
Owner

Thanks, I can reproduce the issue and see the problem (a global cache of locks in Cython). I should have a fix soon.
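
To illustrate the class of bug, here is a Python model of a shared pool of preallocated locks. It is a sketch with hypothetical names (`LockPool`, `stress`), not Cython's actual C code. The point is that taking a lock from the pool and returning it must each be atomic. With the GIL those steps were implicitly serialized; without it they need an explicit guard, otherwise two threads can claim or release the same slot, which would be consistent with the double-free reports in the backtraces.

```python
import threading


class LockPool:
    """Fixed-size pool of reusable locks; all bookkeeping happens under a guard."""

    def __init__(self, size=8):
        self._size = size
        self._free = [threading.Lock() for _ in range(size)]
        self._guard = threading.Lock()

    def take(self):
        # The guard makes "check the pool, then remove a lock" one atomic step.
        with self._guard:
            if self._free:
                return self._free.pop()
        # Pool exhausted: fall back to a freshly allocated lock.
        return threading.Lock()

    def give_back(self, lock):
        # Return the lock to the pool if there is room, otherwise drop it.
        with self._guard:
            if len(self._free) < self._size:
                self._free.append(lock)

    def available(self):
        with self._guard:
            return len(self._free)


def stress(pool, n_threads=8, n_rounds=1000):
    """Hammer the pool from several threads and report any lock handed out twice."""
    failures = []

    def worker():
        for _ in range(n_rounds):
            lock = pool.take()
            # A lock owned by exactly one caller is always free to acquire.
            if lock.acquire(blocking=False):
                lock.release()
            else:
                failures.append("lock handed out to two owners")
            pool.give_back(lock)

    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return failures
```

After a stress run with the guard in place, `failures` stays empty and the pool is back at its original size. Removing the `with self._guard:` blocks models the unsafe variant.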

@colesbury
Owner

Hi @ogrisel - the fix is taking longer than I originally expected. The Cython issue was easy enough to track down and fix, but it also exposed issues in the NumPy fork and nogil Python.

@colesbury
Owner

I've fixed issues in:

  • Cython
  • NumPy 1.19.3
  • nogil Python

I have not ported the fix to NumPy 1.22.3 yet.

You'll need to rebuild nogil Python from source and install the updated packages.

pip cache purge
pip install cython numpy==1.19.3 scipy

The bugs were:

  • Cython's memoryview implementation used a pool of locks that wasn't thread-safe without the GIL [patch]
  • NumPy's buffer protocol wasn't thread-safe without the GIL [patch]
  • The implementation of list slicing wasn't thread safe with concurrent modifications [patch]. This showed up infrequently with concurrent calls to warnings.catch_warnings in scikit-learn.
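
For the third item: in CPython 3.9, entering `warnings.catch_warnings` copies the global filter list with a slice (`filters[:]`), so many threads entering and leaving that context manager at once slice and rebind a shared list concurrently. The following is a small stress harness for that pattern, with hypothetical helper names. It is a smoke test under those assumptions, not a guaranteed reproducer; on a standard GIL build it should simply complete.

```python
import threading
import warnings


def enter_and_exit_catch_warnings(n_rounds):
    """Repeatedly enter and leave catch_warnings, as library code often does."""
    completed = 0
    for _ in range(n_rounds):
        with warnings.catch_warnings():
            # Mutates whichever filter list is currently installed.
            warnings.simplefilter("ignore")
        completed += 1
    return completed


def stress_catch_warnings(n_threads=8, n_rounds=500):
    """Run the pattern from several threads; return the rounds each completed."""
    results = [0] * n_threads

    def worker(index):
        results[index] = enter_and_exit_catch_warnings(n_rounds)

    # The outer context manager restores the original filter list afterwards,
    # whatever order the worker threads happened to exit in.
    with warnings.catch_warnings():
        threads = [
            threading.Thread(target=worker, args=(i,)) for i in range(n_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    return results


if __name__ == "__main__":
    print(stress_catch_warnings())
```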

@colesbury
Owner

I've now ported the fix to NumPy 1.22.3 as well.

  • I've fixed another bug in nogil Python that caused sklearn/compose/tests/test_target.py::test_transform_target_regressor_1d_transformer to fail [patch]
  • The k-means code needs modification to run correctly without the GIL [patch]

That's all the bugs I found. I think pytest passes now. Please let me know if you run into any other issues.

@ogrisel
Author

ogrisel commented Apr 19, 2022

> The k-means code needs modification to run correctly without the GIL [patch]

Yes, I had already noticed this but did not mention it because I knew it was on the scikit-learn side and was planning to investigate. Thanks for the patch, I will open a PR to get it upstream.

@ogrisel
Author

ogrisel commented Apr 19, 2022

I confirm the fixes seem to work as intended: all scikit-learn tests now pass with the latest nogil versions of CPython / Cython / numpy / scipy.

However, I might have observed a memory leak when running some tests under sklearn/ensemble/_hist_gradient_boosting/tests.

I will investigate and open a dedicated issue if needed.
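
One possible first step for that kind of investigation is to measure whether Python-level memory keeps growing across repeated calls. A rough sketch with `tracemalloc` follows; the helper `traced_growth` is hypothetical. Note that `tracemalloc` only sees allocations made through Python's allocators, so a leak from a raw C `malloc` inside an extension would not show up here and would need process-level (RSS) monitoring instead.

```python
import gc
import tracemalloc


def traced_growth(func, n_calls=50, warmup=5):
    """Return the change in traced Python memory (bytes) across n_calls of func."""
    for _ in range(warmup):
        func()  # let caches and lazy initialization settle first
    gc.collect()
    tracemalloc.start()
    try:
        before, _peak = tracemalloc.get_traced_memory()
        for _ in range(n_calls):
            func()
        gc.collect()
        after, _peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return after - before
```

A value that scales with `n_calls` suggests that something retains memory on every call; a small constant value does not.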

Thanks again for the quick feedback.
