RuntimeError: Resource temporarily unavailable due to running out of threads (ulimit -u) #2685
We're going to have a hard time debugging this without being able to reproduce it, and reproducing it might itself be hard. Does it need a particular MPI setup, or does it reproduce in, say, a cloud VM? I'm not aware of any change in jaxlib that would have intentionally triggered this problem, so we probably need to debug this the hard way. I think the first step is to work out which resource we are running out of. Some guesses are …
If I were debugging this myself, I might see whether I can reproduce the behavior under …
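One way to start finding out which resource is exhausted is to compare the number of threads owned by the current user with the `RLIMIT_NPROC` limit that `ulimit -u` reports. The following is a minimal, Linux-only sketch. It assumes the limit in the error message is this per-user limit, which on Linux counts threads rather than only processes. It reads the `Uid:` and `Threads:` fields of `/proc/<pid>/status`. The function names are illustrative and not part of JAX.

```python
import os
import resource


def count_user_threads(uid=None):
    """Sum the Threads: field of /proc/<pid>/status over processes whose real UID is uid (Linux only)."""
    if uid is None:
        uid = os.getuid()  # real UID, which is what RLIMIT_NPROC is charged against
    total = 0
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue  # skip non-process entries such as /proc/cpuinfo
        try:
            with open(f"/proc/{entry}/status") as f:
                fields = dict(line.split(":", 1) for line in f if ":" in line)
        except OSError:
            continue  # the process exited meanwhile, or its status is not readable
        # The first number on the Uid: line is the real UID.
        if int(fields["Uid"].split()[0]) == uid:
            total += int(fields["Threads"])
    return total


def nproc_limits():
    """Return (soft, hard) for RLIMIT_NPROC, the limit shown by `ulimit -u`."""
    return resource.getrlimit(resource.RLIMIT_NPROC)


if __name__ == "__main__":
    soft, hard = nproc_limits()
    print(f"threads owned by this user: {count_user_threads()}")
    print(f"RLIMIT_NPROC soft={soft} hard={hard} (RLIM_INFINITY={resource.RLIM_INFINITY})")
```

Running this from inside one of the MPI ranks just before the failure point shows how close the job gets to the soft limit.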
@hawkinsp I'd be happy to try out different suggestions. I can try to reproduce this behavior on a different cluster and see whether I get it there. It might also be important to understand what it tells us that the error does not appear for … How do I try to reproduce the behavior under the …?
I checked two more Linux platforms, plus macOS on my Mac. The upshot is that two of the three Linux platforms give the same error, while the remaining Linux platform and macOS do not. How can I get the details of the failing Linux platforms so I can share them here? Is it expected that the pip version of JAX has problems on some Linux platforms? The good news is that reproducing the error does not seem to require a node with a large number of cores: all one needs is to run the minimal code snippet above with more than 24 MPI processes: …
Could you share the output of …? I'm unable to reproduce this problem on a Google Cloud Platform VM under Debian 10. I created an n1-standard-64 instance (64 vCPUs) and installed Debian 10. I installed Python 3.7 and mpi4py using …. If you were able to reproduce this on a reasonably standard Linux platform, it would be helpful if you could share the details. I need to be able to reproduce this to debug it, and the best case would be a configuration we can both access (e.g., a cloud VM).
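The sketch below gathers similar information from inside a job. It reports kernel details plus a few of the limits `ulimit -a` prints. Running it inside the job reflects the limits the MPI ranks actually inherit, which may differ from those of a login shell. It uses only the standard `platform` and `resource` modules, and the selection of limits is an assumption. `resource` reports size limits in bytes, whereas `ulimit` prints blocks or kilobytes.

```python
import platform
import resource

# A subset of what `ulimit -a` reports, keyed by a readable label.
LIMITS = {
    "core file size (RLIMIT_CORE, bytes)": resource.RLIMIT_CORE,
    "max user processes (RLIMIT_NPROC)": resource.RLIMIT_NPROC,
    "open files (RLIMIT_NOFILE)": resource.RLIMIT_NOFILE,
    "stack size (RLIMIT_STACK, bytes)": resource.RLIMIT_STACK,
}


def fmt(value):
    """Render RLIM_INFINITY the way ulimit does."""
    return "unlimited" if value == resource.RLIM_INFINITY else str(value)


def system_report():
    """One line of uname-style kernel info followed by one line per selected limit."""
    u = platform.uname()
    lines = [f"{u.system} {u.node} {u.release} {u.version} {u.machine}"]
    for label, res in LIMITS.items():
        soft, hard = resource.getrlimit(res)
        lines.append(f"{label}: soft={fmt(soft)} hard={fmt(hard)}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(system_report())
```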
Linux cori01 4.12.14-150.47-default #1 SMP Wed Dec 18 15:05:52 UTC 2019 (8162e25) x86_64 x86_64 x86_64 GNU/Linux
core file size (blocks, -c) 0
Linux scc1 3.10.0-1062.12.1.el7.x86_64 #1 SMP Tue Feb 4 23:02:59 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux
core file size (blocks, -c) 0
Linux cuiw-ubuntu 4.15.0-96-generic #97-Ubuntu SMP Wed Apr 1 03:25:46 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux
core file size (blocks, -c) 0
Linux fe001 4.4.0-31-generic #50-Ubuntu SMP Wed Jul 13 00:07:12 UTC 2016 x86_64 x86_64 x86_64 GNU/Linux
core file size (blocks, -c) 0
I think the issue is the value of …; e.g., on my …
If I set … JAX uses threading internally in some fairly fundamental ways, so we won't be able to eliminate threading. However, we can certainly provide an option to reduce the size of some of the internal thread pools. Alternatively, we could detect this kind of MPI configuration and choose a more appropriate thread pool size automatically. You might be able to work around the problem for now by raising …
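Assume the limit in question is the per-user process limit from the issue title (`ulimit -u`). In that case, a process can raise its own soft limit up to the hard limit without special privileges. Going beyond the hard limit needs an administrator. The following is a minimal sketch using the standard `resource` module. The helper name is hypothetical. It has to run before JAX creates its thread pools for the higher limit to help.

```python
import resource


def raise_nproc_soft_limit():
    """Raise the soft RLIMIT_NPROC (`ulimit -u`) to the hard limit and return the new (soft, hard)."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NPROC)
    if soft != hard:
        # An unprivileged process may raise its soft limit up to, but not past, the hard limit.
        resource.setrlimit(resource.RLIMIT_NPROC, (hard, hard))
    return resource.getrlimit(resource.RLIMIT_NPROC)


if __name__ == "__main__":
    # Call this before `import jax`, so the raised limit is already in effect
    # when JAX's thread pools are created.
    print(raise_nproc_soft_limit())
```

Setting the limit with `ulimit -u` in the job script before launching MPI may be simpler. Whether ranks on remote nodes inherit it depends on how the launcher starts them.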
@hawkinsp I see. So what is the proper way to manually set the number of threads that JAX creates? I feel like I should be able to tell it to use a single thread if needed. This is particularly helpful when large simulations are launched and need to fit within the available resources. I currently have …
but it doesn't seem to help.
I see; then it must be this intrinsic threading in JAX that was updated when …
Happily, it turns out there's an easy solution to this: the libraries underlying JAX respect CPU affinity, and when creating a thread pool they size it according to the task's CPU affinity mask. So telling MPI to assign, say, one core per process via the appropriate options to, e.g., …
worked fine for me. (I gather you should probably use ….) Does that help?
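You can check the affinity mechanism described above from Python on Linux. Compare `os.cpu_count()` with the size of the process's affinity mask. The sketch below also pins the calling process to one CPU with `os.sched_setaffinity`, as an in-process alternative to the launcher's binding options. Two assumptions apply here. First, the CPU is chosen from a per-node rank such as Open MPI's `OMPI_COMM_WORLD_LOCAL_RANK`. Second, pinning must happen before JAX creates its thread pools (e.g., before `import jax`) for the smaller mask to take effect.

```python
import os


def visible_cpus():
    """CPUs this process may run on (its affinity mask), which can be fewer than os.cpu_count()."""
    return sorted(os.sched_getaffinity(0))  # Linux-only


def pin_to_one_cpu(local_rank):
    """Restrict the calling process to a single CPU selected by a per-node rank.

    If the thread pools size themselves from the affinity mask, pools created
    after this call should see only one CPU.
    """
    cpus = visible_cpus()
    cpu = cpus[local_rank % len(cpus)]  # wrap around if there are more ranks than CPUs
    os.sched_setaffinity(0, {cpu})
    return cpu


if __name__ == "__main__":
    print("os.cpu_count():", os.cpu_count())
    print("affinity before:", visible_cpus())
    # Assumption: running under Open MPI, which sets OMPI_COMM_WORLD_LOCAL_RANK.
    local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
    pin_to_one_cpu(local_rank)
    print("affinity after:", visible_cpus())
```

Letting `mpirun` or `srun` do the binding is usually preferable, because the launcher knows the node topology and coordinates placement across ranks.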
Unfortunately, neither of …
etc. works for me (I have …). However, the idea is spot on, because the SLURM command …
@hawkinsp it might be helpful to mention this somewhere in the JAX documentation.
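Because the SLURM route worked, it can help to confirm inside each task that the binding actually took effect, whatever options are used. A quick check is to compare the CPUs requested from SLURM with the task's affinity mask. The sketch assumes SLURM's `SLURM_CPUS_PER_TASK` variable (set only when `--cpus-per-task` is given) and `SLURM_PROCID`. The report format is hypothetical.

```python
import os


def binding_report():
    """Compare the CPUs requested from SLURM with the CPUs this task may actually run on."""
    requested = os.environ.get("SLURM_CPUS_PER_TASK")  # set only when --cpus-per-task is given
    requested_cpus = int(requested) if requested is not None else None
    visible = len(os.sched_getaffinity(0))  # Linux-only
    return {
        "rank": os.environ.get("SLURM_PROCID"),  # global task ID; None when not launched by SLURM
        "requested_cpus": requested_cpus,
        "visible_cpus": visible,
        # None when nothing was requested; False means the mask is wider than requested,
        # so thread pools sized from the mask would oversubscribe the allocation.
        "binding_effective": None if requested_cpus is None else visible <= requested_cpus,
    }


if __name__ == "__main__":
    print(binding_report())
```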
I have exactly the same issue! I am training a Flax model, and after 210-230 steps I get: … I am not using mpirun or anything similar.
Python packages:
I am running the experiments on:
I don't even know where the error comes from.
jax=0.1.62
jaxlib=0.1.43
I ran into an issue running multiple processes that call JAX in parallel using Open MPI. I was able to distill the error as follows.
I'm running the following code snippet with the CPU backend:
in parallel over 26 independent processes using
mpi4py,
by executing the command … This mini-program runs on a large compute node with 128 GB of memory and 30 Haswell processors (more on the Haswell node specifics here). All processors are reserved exclusively for this calculation.
Every time I run this job, several randomly chosen processes terminate with the JAX RuntimeError.
I also tried distributing the same 26 MPI processes over two or more identical compute nodes and saw the same behavior.
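A back-of-the-envelope estimate may explain why only larger process counts fail. If each process builds a thread pool sized to every CPU it can see, the total thread count grows roughly as processes × visible CPUs. That total is what counts against the per-user limit. The sketch assumes one pool per process sized to the 30 visible CPUs, plus one main thread. Real JAX processes may create additional threads, so the numbers are illustrative only, and the function names are hypothetical.

```python
def estimated_threads(n_procs, pool_threads_per_proc, other_threads_per_proc=1):
    """Rough thread count for a job, assuming each process has one pool of
    pool_threads_per_proc threads plus other_threads_per_proc extra threads (e.g. main)."""
    return n_procs * (pool_threads_per_proc + other_threads_per_proc)


def exceeds_limit(n_threads, nproc_soft_limit, already_running=0):
    """True if the job would push the user's thread count past RLIMIT_NPROC (`ulimit -u`)."""
    return already_running + n_threads > nproc_soft_limit


if __name__ == "__main__":
    print(estimated_threads(26, 30))  # each process sees all 30 CPUs
    print(estimated_threads(26, 1))   # each process pinned to a single CPU
```

For 26 processes that see 30 CPUs each, this gives 26 × 31 = 806 threads. With each process pinned to one core, it drops to 26 × 2 = 52.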
Occasionally, I see the same error raised inside the
PRNGKey
function. I also found that this error does NOT occur with
jax=0.1.58, jaxlib=0.1.37.
Does anyone have an idea what might be causing this? It must have been introduced in the change from jaxlib=0.1.37 to jaxlib=0.1.38.