Add share and wrap interfaces #17
Conversation
Force-pushed from 608fb29 to 4c4b3e4
I'm seeing intermittent segfaults like above with this branch, just so you are aware.
Do you have a reproducible example?
It's just intermittent unfortunately. With https://gist.github.com/rejuvyesh/0c0995ac81d8c75efada7797a292f611:

```
julia> include("test/stresstest_dlpack.jl")
[ Info: Precompiling Zygote [e88e6eb3-aa80-5325-afca-941959d7151f]
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")

signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:102
unknown function (ip: 0x7f7076f69e30)
Allocations: 62477905 (Pool: 62457694; Big: 20211); GC: 69
```

It's still quite a big example and I'm trying to reduce it to something smaller. But in case this is useful by itself.
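For intermittent crashes like this, one way to make a lifetime bug show up deterministically is to force garbage collections inside a tight round-trip loop. The sketch below is not the gist's actual test; it only reuses the `DLArray` call shown later in this thread, so treat the loop body as illustrative:

```julia
using DLPack
using PyCall

dlpack = pyimport("jax.dlpack")
jax = pyimport("jax")
key = jax.random.PRNGKey(0)

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject

# Repeatedly wrap a fresh JAX array and force collections, so a prematurely
# freed DLManagedTensor crashes quickly instead of intermittently.
for i in 1:1000
    jax_x = jax.random.normal(key, (32, 32))
    jl_x = DLPack.DLArray(jax_x, pyto_dlpack)  # zero-copy wrap
    s = sum(jl_x)
    GC.gc()  # run finalizers while the wrapped memory is still in use
    @assert isfinite(s)
end
```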
Also I haven't been able to reproduce the segfault with … EDIT: I have been able to get segfaults on …
Can you try again @rejuvyesh?
Yep, this fixes the issues! Amazing sleuthing!
To support both PyCall and PythonCall, I need to change the signature of …
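For context, one bridge-agnostic shape this could take is to leave both arguments untyped, so either bridge's object type works via duck typing. This is a sketch with a hypothetical name, not the PR's final signature:

```julia
# Hypothetical sketch, not DLPack.jl's actual code: `o` may be a
# PyCall.PyObject or a PythonCall.Py; since neither argument is typed,
# no hard dependency on either bridge is needed here.
function wrap_any(o, to_dlpack)
    capsule = to_dlpack(o)  # a PyCapsule carrying a DLManagedTensor
    return capsule          # a real implementation would unpack this into an array
end
```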
I think for proper …
```julia
using CUDA
using DLPack
using PyCall

CUDA.allowscalar(false)

dlpack = pyimport("jax.dlpack")
numpy = pyimport("numpy")

# Library-specific converters between Python objects and DLPack capsules.
pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

jax = pyimport("jax")
key = jax.random.PRNGKey(0)
jax_x = jax.random.normal(key, (2, 3))
jax_sum = jax.numpy.sum(jax_x)

# Wrap the JAX array as a Julia array (no copy) and reduce it on the Julia side.
jl_x = DLPack.DLArray(jax_x, pyto_dlpack)
jl_sum = sum(jl_x)

@assert isapprox(jax_sum.item(), jl_sum)
```

Results in: …
It seems we will need a separate wrapper for CUDA arrays than …
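One plausible shape for that, sketched with assumed function and enum names (the device codes come from the DLPack header; none of this is DLPack.jl's actual internals):

```julia
using CUDA

# Hypothetical dispatch on the capsule's device field: host capsules become
# Arrays, CUDA capsules become CuArrays, everything else is rejected.
function jl_array_type(device_type::Symbol)
    device_type === :kDLCPU  && return Array
    device_type === :kDLCUDA && return CuArray
    error("unsupported DLPack device type: $device_type")
end

@assert jl_array_type(:kDLCPU) === Array
@assert jl_array_type(:kDLCUDA) === CuArray
```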
Ok. I decided to write a … The only thing is that I dropped the … So now the interface looks like: …
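Going by the PR title alone, a hypothetical usage sketch of the wrap/share pair (the argument names and order here are assumptions, not the confirmed API):

```julia
using DLPack
using PyCall

dlpack = pyimport("jax.dlpack")
jax = pyimport("jax")

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject

pyv = jax.numpy.arange(10)
v = DLPack.wrap(pyv, pyto_dlpack)               # Python -> Julia, no copy
pyw = DLPack.share(2 .* v, dlpack.from_dlpack)  # Julia -> Python
```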
Once all Python libraries support exporting and importing via …
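For reference, under the standardized protocol from the Python array API, a consumer library's `from_dlpack` calls the producer's `__dlpack__` itself, so the per-library converter helpers above become unnecessary. A hedged example, assuming a JAX version that implements the standard:

```julia
using PyCall

jnp = pyimport("jax.numpy")

x = jnp.arange(3)
y = jnp.from_dlpack(x)  # consumes x.__dlpack__() internally
@assert jnp.array_equal(x, y).item()
```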
Let me know if you need any help with the tests, but this works quite well! We should ask on …
Just one last heads-up: I have removed the …
Thank you @rejuvyesh for testing this over at rejuvyesh/PyCallChainRules.jl#10! That guided the design of the interface quite well and helped make the PR more robust.
Fixes #10
TODO
- DLArray