Skip to content

Commit

Permalink
Add a to_dlpack interface
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Feb 5, 2022
1 parent 64c4597 commit 608fb29
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 10 deletions.
70 changes: 60 additions & 10 deletions src/DLPack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,6 @@ using Requires
export DLArray, DLVector, DLMatrix, RowMajor, ColMajor


## Aliases and constants ##

const PYCAPSULE_NAME = Ref(
(0x64, 0x6c, 0x74, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x00)
)
const USED_PYCAPSULE_NAME = Ref(
(0x75, 0x73, 0x65, 0x64, 0x5f, 0x64, 0x6c, 0x74, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x00)
)


## Types ##

@enum DLDeviceType::Cint begin
Expand Down Expand Up @@ -77,6 +67,12 @@ mutable struct DLManagedTensor
manager_ctx::Ptr{Cvoid}
deleter::Ptr{Cvoid}

function DLManagedTensor(
dl_tensor::DLTensor, manager_ctx::Ptr{Cvoid}, deleter::Ptr{Cvoid}
)
return new(dl_tensor, manager_ctx, deleter)
end

function DLManagedTensor(dlptr::Ptr{DLManagedTensor})
manager = unsafe_load(dlptr)

Expand All @@ -89,6 +85,12 @@ mutable struct DLManagedTensor
end
end

struct DLCapsule
shape::Vector{Clonglong}
strides::Vector{Clonglong}
tensor::Ref{DLManagedTensor}
end

abstract type MemoryLayout end

struct ColMajor <: MemoryLayout end
Expand Down Expand Up @@ -140,9 +142,22 @@ function DLArray{T, N}(::Type{A}, ::Type{M}, manager::DLManagedTensor, foreign)
return DLArray(manager, foreign, data)
end


## Aliases and constants ##

const DLVector{T} = DLArray{T, 1}
const DLMatrix{T} = DLArray{T, 2}

const PYCAPSULE_NAME = Ref(
(0x64, 0x6c, 0x74, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x00)
)

const USED_PYCAPSULE_NAME = Ref(
(0x75, 0x73, 0x65, 0x64, 0x5f, 0x64, 0x6c, 0x74, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x00)
)

const DLCAPSULE_POOL = Dict{Ref{DLManagedTensor}, DLCapsule}()


## Utils ##

Expand Down Expand Up @@ -190,6 +205,8 @@ function jlarray_type(::Val{D}) where {D}
end
end

dldevice(::StridedArray) = DLDevice(kDLCPU, Cint(0))

device_type(ctx::DLDevice) = ctx.device_type
device_type(tensor::DLTensor) = device_type(tensor.ctx)
device_type(manager::DLManagedTensor) = device_type(manager.dl_tensor)
Expand Down Expand Up @@ -275,6 +292,39 @@ function Base.Broadcast.BroadcastStyle(::Type{D}) where {T, N, A, D <: DLArray{T
end


## DLPack wrapping ##

function dlfinalize(ref)
delete!(DLCAPSULE_POOL, ref)
return nothing
end

function to_dlpack(A::StridedArray{T, N}) where {T, N}
sh = [Clonglong(i) for i in size(A)]
st = [Clonglong(i) for i in strides(A)]

data = pointer(A)
ctx = dldevice(A)
ndim = Cint(N)
dtype = jltypes_to_dtypes()[T]
sh_ptr = pointer(sh)
st_ptr = pointer(st)
dl_tensor = DLTensor(data, ctx, ndim, dtype, sh_ptr, st_ptr, Culonglong(0))

tensor = Ref(DLManagedTensor(dl_tensor, C_NULL, C_NULL))
capsule = DLCapsule(sh, st, tensor)

DLCAPSULE_POOL[tensor] = capsule

finalizer(tensor) do ref
delete!(DLCAPSULE_POOL, ref)
return nothing
end

return tensor
end


## Module initialization ##

function __init__()
Expand Down
24 changes: 24 additions & 0 deletions src/pycall.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
using .PyCall


# We define a noop deleter to pass to new `DLManagedTensor` exported to python libraries
# as some of them (e.g. PyTorch) do not handle the case when the finalizer is `C_NULL`.
# Also, we have to define this here whithin `__init__` and `@require` or it segfaults when
# `ccall`ed (is this a julia or Requires.jl bug, or a world-age issue?)
const PYCALL_NOOP_DELETER = @cfunction(ptr -> nothing, Cvoid, (Ptr{DLManagedTensor},))


function DLManagedTensor(po::PyObject)
if !pyisinstance(po, PyCall.@pyglobalobj(:PyCapsule_Type))
throw(ArgumentError("PyObject must be a PyCapsule"))
Expand Down Expand Up @@ -43,3 +51,19 @@ end
function DLArray{T, N}(::Type{A}, ::Type{M}, o::PyObject, to_dlpack) where {T, N, A, M}
return DLArray{T, N}(A, M, DLManagedTensor(to_dlpack(o)), o)
end

function to_dlpack(A::StridedArray, foreign_from_dlpack::Union{PyObject, Function})
tensor = to_dlpack(A)
tensor[].deleter = PYCALL_NOOP_DELETER

o = GC.@preserve tensor begin
pycapsule = PyObject(PyCall.@pycheck ccall(
(@pysym :PyCapsule_New),
PyPtr, (Ptr{Cvoid}, Ptr{UInt8}, Ptr{Cvoid}),
tensor, PYCAPSULE_NAME, C_NULL
))
foreign_from_dlpack(pycapsule)
end

return PyCall.pyembed(o, A)
end

0 comments on commit 608fb29

Please sign in to comment.