jax/tests/lax_scipy_sparse_test.py segfaults on GPU; other GPU test failures #5713

Closed
skye opened this issue Feb 12, 2021 · 5 comments
@skye
Member

skye commented Feb 12, 2021

I'm unable to run all unit tests with jaxlib==0.1.60+cuda111. I suspect this is an issue for all GPU builds.

$ python3 -m pytest jax/tests/
================================================================================================================================== test session starts ===================================================================================================================================
platform linux -- Python 3.6.9, pytest-6.1.1, py-1.9.0, pluggy-0.13.1
rootdir: /home/skyewm/jax, configfile: pytest.ini
plugins: xdist-2.1.0, forked-1.3.0
collected 11931 items                                                                                                                                                                                                                                                                    

jax/tests/api_test.py ....................s.................ss.......................................................................................................................................s..............................s.......ss..........s................s........ [  2%]
...............................s......s................................s.ss.......s..........s...................                                                                                                                                                                  [  3%]
jax/tests/api_util_test.py ............                                                                                                                                                                                                                                            [  3%]
jax/tests/array_interoperability_test.py ..........sssssssssssssssssssssssssssssssssssssssssssssssssss                                                                                                                                                                             [  3%]
jax/tests/batching_test.py .......................................ssssssssssss...................................FF................................................................FF....................F...                                                                      [  5%]
jax/tests/callback_test.py .........                                                                                                                                                                                                                                               [  5%]
jax/tests/core_test.py ......................................................................................................................................................................................................................................                      [  7%]
jax/tests/custom_object_test.py ................                                                                                                                                                                                                                                   [  7%]
jax/tests/debug_nans_test.py ..........                                                                                                                                                                                                                                            [  7%]
jax/tests/doubledouble_test.py ..............................................................                                                                                                                                                                                      [  7%]
jax/tests/dtypes_test.py ......................................................................................................................................................................................................................................................... [  9%]
................................................................................................................................                                                                                                                                                   [ 11%]
jax/tests/errors_test.py ssss                                                                                                                                                                                                                                                      [ 11%]
jax/tests/fft_test.py ..........................................................................................                                                                                                                                                                   [ 11%]
jax/tests/generated_fun_test.py ........................                                                                                                                                                                                                                           [ 12%]
jax/tests/host_callback_to_tf_test.py ssssssssss                                                                                                                                                                                                                                   [ 12%]
jax/tests/image_test.py ssssssssssssssssssss..........................................                                                                                                                                                                                             [ 12%]
jax/tests/infeed_test.py ....                                                                                                                                                                                                                                                      [ 12%]
jax/tests/jax_jit_test.py ..............                                                                                                                                                                                                                                           [ 12%]
jax/tests/jax_to_hlo_test.py ..                                                                                                                                                                                                                                                    [ 12%]
jax/tests/jaxpr_util_test.py .....                                                                                                                                                                                                                                                 [ 12%]
jax/tests/jet_test.py ......s.......F........s..........ss...ss..s...s.ss..........sss.......s                                                                                                                                                                                     [ 13%]
jax/tests/lax_autodiff_test.py ..................................F....F.FFFFFFFFFFFFFFFFFFFFF..................................................................................................................................................................................... [ 15%]
.................................................................................................................................................................................................................................................................................. [ 17%]
............................                                                                                                                                                                                                                                                       [ 18%]
jax/tests/lax_control_flow_test.py .....................s......................................................................................................................................................................................................................... [ 20%]
...............................                                                                                                                                                                                                                                                    [ 20%]
jax/tests/lax_numpy_einsum_test.py ..............................................................................................................................................                                                                                                  [ 21%]
jax/tests/lax_numpy_indexing_test.py ............................................................................................................................................................................................................................................. [ 23%]
.................................................................................                                                                                                                                                                                                  [ 24%]
jax/tests/lax_numpy_test.py ..................................................................................s...............................................s.s................................................................................................................. [ 26%]
...................................................FFFFFFFFFF..............................................................................................................................................................................................................sssss.. [ 28%]
.......................................................................................................................................................................................s.........sssssss.......................................................................... [ 30%]
.................................................................................................................................................................................................................................................................................. [ 33%]
.................................................................................................................................................................................................................................................................................. [ 35%]
.................................................................................................................................................................................................................................................................................. [ 37%]
.................................................................................................................................................................................................................................................................................. [ 39%]
.................................................................................................................................................................................................................................................................................. [ 42%]
.......................ss..............................................................FFFFFFFFFF................................................................................................................................................................................. [ 44%]
.................................................................................................................................................................................................................................................................................. [ 46%]
..........................sssss......................................................................................................................................s............................................................................................................ [ 49%]
.......................................................................................................................................................s..ss..s..s..ss..s.s.sss.ss.s.s..ss..s..s..ss..s..s..ss..s..s..ss..s..s.s..s....s..ss..s..ss.ss..s..s..s........s...s..s..s [ 51%]
......s..s........s.s..s....s.s..s..s..ss..s..s..ss..s..s.s..s....s..ss..s..s..ss..s....s.s..s....s.s..s.s.sss.ss.s.s.s..ss..s.sss.ss.s.s..ss..s..s..ss..s..s..ss..s....s.s..s.s.sss.ss.s.s.sss.ss.s..s.s..s.s.sss.ss.ss.sss.ss.s.s.s..ss.....s.s..s.ss....s..ss.s.sss.ss.ssss.ss. [ 53%]
s.s.s..s....s.s..s......s...s..s..ss..s..s..ss..s..s..ss..s.......s..s..s..ss.s.ss..ss....s.s.sss.ss.ssss.ss.s.s..ss..s..s..ss..s..s..ss..s..s..ss..s..s.s..s....s.s..s...s.sss.ss.s.s..ss..s.s.sss.ss.ss.sss.ss.ss.sss.ss.ss.sss.ss.ss.sss.ss.s.s.s..ss.....s.s..s.s.sss.ss.ss... [ 56%]
.s.....s.s..s...s.sss.ss.ss....s....s.sss.ss.s...s.s..s....s.s..s..s..ss..s..s.s..ss..ss..ss....ss..ss..s.s....s.......s.s..s..s..ss..s..s..ss..s.ssssssssss.s..ss..s.ss..ss..s.ssssssss...s..ss..s..s..ss..s.s.sss.ss.s...s.s..s.                                                 [ 57%]
jax/tests/lax_numpy_vectorize_test.py ............................                                                                                                                                                                                                                 [ 58%]
jax/tests/lax_scipy_sparse_test.py ssssssssss......ssssssssssssssssss.......Fatal Python error: Segmentation fault

Thread 0x00007fb2a68ec740 (most recent call first):
  File "/home/skyewm/jax/jax/interpreters/xla.py", line 356 in backend_compile
  File "/home/skyewm/jax/jax/interpreters/xla.py", line 292 in xla_primitive_callable
  File "/home/skyewm/jax/jax/_src/util.py", line 191 in cached
  File "/home/skyewm/jax/jax/_src/util.py", line 198 in wrapper
  File "/home/skyewm/jax/jax/interpreters/xla.py", line 242 in apply_primitive
  File "/home/skyewm/jax/jax/core.py", line 628 in process_primitive
  File "/home/skyewm/jax/jax/core.py", line 282 in bind
  File "/home/skyewm/jax/jax/core.py", line 363 in eval_jaxpr
  File "/home/skyewm/jax/jax/core.py", line 152 in jaxpr_as_fun
  File "/home/skyewm/jax/jax/_src/lax/control_flow.py", line 2234 in _custom_linear_solve_impl
  File "/home/skyewm/jax/jax/core.py", line 628 in process_primitive
  File "/home/skyewm/jax/jax/core.py", line 282 in bind
  File "/home/skyewm/jax/jax/_src/lax/control_flow.py", line 2224 in custom_linear_solve
  File "/home/skyewm/jax/jax/_src/scipy/sparse/linalg.py", line 622 in gmres
  File "/home/skyewm/jax/tests/lax_scipy_sparse_test.py", line 271 in test_gmres_on_identity_system
[...]

Looks like there are other test failures too, but they didn't print due to the segfault.
cc @hawkinsp

@skye skye added the bug Something isn't working label Feb 12, 2021
@skye skye self-assigned this Feb 12, 2021
@hawkinsp hawkinsp self-assigned this Feb 18, 2021
@hawkinsp
Collaborator

I built jaxlib with debug symbols and grabbed a stack trace:

#0  raise (sig=<optimized out>) at ../sysdeps/unix/sysv/linux/raise.c:51
#1  <signal handler called>
#2  0x00007fd372341bd8 in __GI___pthread_timedjoin_ex (threadid=140545868896000, thread_return=0x0,
    abstime=0x0, block=true) at pthread_join_common.c:40
#3  0x00007fd36c7339bf in blas_thread_shutdown_ ()
   from /home/phawkins/.pyenv/versions/py3.8.6/lib/python3.8/site-packages/numpy/core/../../numpy.libs/libopenblasp-r0-5bebc122.3.13.dev.so
#4  0x00007fd37188789a in __libc_fork () at ../sysdeps/nptl/fork.c:96
#5  0x00007fd345661070 in tensorflow::SubProcess::Start (this=0x7fd1fdffd4a0)
    at external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:210
#6  0x00007fd345501393 in stream_executor::CompileGpuAsm (cc_major=7, cc_minor=0,
    ptx_contents=0x7fd1c809c670 "//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., options=...)
    at external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:249
#7  0x00007fd345500163 in stream_executor::CompileGpuAsm (device_ordinal=0,
    ptx_contents=0x7fd1c809c670 "//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., options=...)
    at external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:150
#8  0x00007fd33fb124e2 in xla::gpu::NVPTXCompiler::CompileGpuAsmOrGetCachedResult (
    this=0x563d97c2cba0, stream_exec=0x7fd218006a60,
    ptx="//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., cc_major=7, cc_minor=0, hlo_module_config=..., relocatable=true)
    at external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:377
#9  0x00007fd33fb11dc4 in xla::gpu::NVPTXCompiler::CompileTargetBinary (this=0x563d97c2cba0,
    module_config=..., llvm_module=0x7fd1c8020050, gpu_version=..., stream_exec=0x7fd218006a60,
    relocatable=true, debug_module=0x563dc0e82fa0)

This is probably an instance of OpenBLAS (used by NumPy) misbehaving in a process that also calls fork().

@hawkinsp
Collaborator

OpenMathLib/OpenBLAS#3111 should fix the underlying OpenBLAS problem, I think. I can't confirm that 100% because I was unable to reproduce the original issue with a self-built OpenBLAS, only the one that is bundled with NumPy.

However, since it will take some time for any OpenBLAS fix to make it into a NumPy release and for that fix to make it to users, I'll also look into avoiding calling pthread_atfork handlers when spawning a subprocess.
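
For readers unfamiliar with atfork handlers, here is a minimal sketch of the mechanism using Python's `os.register_at_fork`. This is a Python-level analogue of `pthread_atfork()`, not the C API that OpenBLAS uses, and the helper name `fork_and_record` is made up for illustration. It shows that a plain `fork()` runs registered handlers in the forking process, which is where `blas_thread_shutdown_` appears in the stack trace above.

```python
import os


def fork_and_record():
    """Fork once and return the at-fork events observed in the parent.

    Hypothetical helper: Python-level handlers stand in for C-level
    pthread_atfork() handlers such as the one OpenBLAS installs.
    """
    events = []
    # Handlers cannot be unregistered, so each call records into its own list.
    os.register_at_fork(
        before=lambda: events.append("before"),
        after_in_parent=lambda: events.append("after_in_parent"),
        after_in_child=lambda: events.append("after_in_child"),
    )
    pid = os.fork()
    if pid == 0:
        # Child: leave immediately without running interpreter cleanup.
        os._exit(0)
    os.waitpid(pid, 0)
    return events


if __name__ == "__main__":
    print(fork_and_record())
```

A handler that blocks or crashes in the "before" phase, as the OpenBLAS one apparently does here, takes the forking process down before the child ever starts.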

@hawkinsp
Collaborator

With an upcoming fix to TensorFlow to avoid calling pthread_atfork() handlers, I am down to only two failures:

=========================== short test summary info ============================
FAILED tests/lax_numpy_test.py::NumpySignaturesTest::testWrappedSignaturesMatch
FAILED tests/pmap_test.py::PmapTest::test_replicate_backend - ValueError: com...
========== 2 failed, 10854 passed, 1198 skipped in 955.43s (0:15:55) ===========

The former is related to NumPy 1.20 on my machine and unrelated to GPU specifically.

I am unsure about the latter: it does not appear when I run that one file in isolation, so I am guessing it has something to do with pytest running a particular combination of tests on one worker.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Feb 19, 2021
…/execvp() on non-Android POSIX platforms.

The goal of this change is to avoid calling pthread_atfork() handlers. Some libraries, in particular the version of OpenBLAS included in NumPy, have buggy pthread_atfork() handlers. See OpenMathLib/OpenBLAS#3111 and jax-ml/jax#5713 for details.

Now, while we can and have fixed the buggy atfork handlers, it will take some time for the fix to be deployed in a NumPy release and for users to update to a new NumPy release. So we also take an additional step: avoid running atfork handlers in Subprocess.

My copy of the glibc documentation says:
"
According to POSIX, it is unspecified whether fork handlers established with pthread_atfork(3)
are called when posix_spawn() is invoked. On glibc, fork handlers are called only if the
child is created using fork(2).
"
It appears glibc 2.24 and newer do not call pthread_atfork() handlers from posix_spawn().

Using posix_spawn() should be at least no worse than an explicit fork()/execvp() pair, and on glibc it should do the right thing.

PiperOrigin-RevId: 358317859
Change-Id: Ic1d95446706efa7c0db4e79bf8281f14b2bd99df
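
As a rough illustration of the two launch strategies this commit message contrasts, the sketch below starts a child process both ways from Python (3.8+, POSIX only). `os.posix_spawn`, `os.fork`, and `os.execvp` are real standard-library calls. The helper names are hypothetical, and this is not the TensorFlow SubProcess code, which is C++.

```python
import os
import sys


def _exit_code(status):
    """Translate a waitpid() status into an exit code (negative signal number if killed)."""
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    return -os.WTERMSIG(status)


def run_with_posix_spawn(argv):
    """Launch argv with posix_spawn() and wait for it."""
    # argv[0] must be a path to the executable; os.posix_spawnp would search PATH.
    pid = os.posix_spawn(argv[0], argv, dict(os.environ))
    _, status = os.waitpid(pid, 0)
    return _exit_code(status)


def run_with_fork_exec(argv):
    """Launch argv with an explicit fork()/execvp() pair and wait for it."""
    pid = os.fork()
    if pid == 0:
        try:
            os.execvp(argv[0], argv)  # replaces the child image on success
        finally:
            os._exit(127)  # only reached if execvp() failed
    _, status = os.waitpid(pid, 0)
    return _exit_code(status)


if __name__ == "__main__":
    cmd = [sys.executable, "-c", "raise SystemExit(3)"]
    print(run_with_posix_spawn(cmd), run_with_fork_exec(cmd))
```

Both helpers report the same exit code. The difference discussed in the commit message is which atfork handlers the C library runs along the way, which this sketch does not measure.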
@hawkinsp hawkinsp mentioned this issue Feb 19, 2021
@hawkinsp
Collaborator

The pmap_test.py failure looks like this:

self = <pmap_test.PmapTest testMethod=test_replicate_backend>

    @jtu.skip_on_devices("cpu")
    def test_replicate_backend(self):
      # https://github.com/google/jax/issues/4223
      def fn(indices):
        return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32)
      mapped_fn = jax.pmap(fn, axis_name='i', backend='cpu')
      mapped_fn = jax.pmap(mapped_fn, axis_name='j', backend='cpu')
      indices = np.array([[[2], [1]], [[0], [0]]])
>     mapped_fn(indices)  # doesn't crash
E     jax._src.traceback_util.FilteredStackTrace: ValueError: compiling computation that requires 4 logical devices, but only 1 XLA devices are available (num_replicas=4, num_partitions=1)
E
E     The stack trace above excludes JAX-internal frames.
E     The following is the original exception that occurred, unmodified.
E
E     --------------------

tests/pmap_test.py:1641: FilteredStackTrace
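
The `num_replicas=4` in the error follows from the input shape: each `pmap` level maps over one leading axis, so two nested `pmap`s over an array of shape (2, 2, 1) need 2 x 2 = 4 devices. A small sketch of that arithmetic in plain Python (the helper name is hypothetical, and nothing here calls JAX):

```python
import math


def replicas_needed(shape, pmap_depth):
    """Devices required by `pmap_depth` nested pmaps over an array of `shape`.

    Hypothetical helper: each nested pmap consumes one leading axis, and the
    replica count is the product of those axis sizes.
    """
    return math.prod(shape[:pmap_depth])


if __name__ == "__main__":
    # Shape of np.array([[[2], [1]], [[0], [0]]]) from the failing test.
    print(replicas_needed((2, 2, 1), pmap_depth=2))
```

With only one CPU device visible, that request cannot be satisfied. One commonly used way to expose several CPU devices is the XLA flag `--xla_force_host_platform_device_count`, set through `XLA_FLAGS` before JAX initializes its backends. Whether a missing or late flag is what went wrong in this run is not established in the thread.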

@hawkinsp hawkinsp removed their assignment Feb 19, 2021
@hawkinsp
Collaborator

I think all the issues identified here are already fixed at head.
