Skip to content

Commit

Permalink
add wrap function which is the safe counterpart to unsafe_wrap. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter authored Dec 9, 2023
1 parent abeb68f commit 84cfe04
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 0 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ New library functions
* `copyuntil(out, io, delim)` and `copyline(out, io)` copy data into an `out::IO` stream ([#48273]).
* `eachrsplit(string, pattern)` iterates split substrings right to left.
* `Sys.username()` can be used to return the current user's username ([#51897]).
* `wrap(Array, m::Union{MemoryRef{T}, Memory{T}}, dims)` which is the safe counterpart to `unsafe_wrap` ([#52049]).

New library features
--------------------
Expand Down
33 changes: 33 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3039,3 +3039,36 @@ intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r)
_getindex(v, i)
end
end

"""
wrap(Array, m::Union{Memory{T}, MemoryRef{T}}, dims)
Create an array of size `dims` using `m` as the underlying memory. This can be thought of as a safe version
of [`unsafe_wrap`](@ref) utilizing `Memory` or `MemoryRef` instead of raw pointers.
"""
function wrap end

@eval @propagate_inbounds function wrap(::Type{Array}, ref::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N}
mem = ref.mem
mem_len = length(mem) + 1 - memoryrefoffset(ref)
len = Core.checked_dims(dims...)
@boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims, len)
if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len)
mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len)
ref = MemoryRef(mem)
end
$(Expr(:new, :(Array{T, N}), :ref, :dims))
end

@noinline invalid_wrap_err(len, dims, proddims) = throw(DimensionMismatch(
"Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $proddims > $len, so that the array would have more elements than the underlying memory can store."))

function wrap(::Type{Array}, m::Memory{T}, dims::NTuple{N, Integer}) where {T, N}
wrap(Array, MemoryRef(m), dims)
end
function wrap(::Type{Array}, m::MemoryRef{T}, l::Integer) where {T}
wrap(Array, m, (l,))
end
function wrap(::Type{Array}, m::Memory{T}, l::Integer) where {T}
wrap(Array, MemoryRef(m), (l,))
end
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ export
vcat,
vec,
view,
wrap,
zeros,

# search, find, match and related functions
Expand Down
19 changes: 19 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3170,3 +3170,22 @@ end
@test c + zero(c) == c
end
end

@testset "Wrapping Memory into Arrays" begin
mem = Memory{Int}(undef, 10) .= 1
memref = MemoryRef(mem)
@test_throws DimensionMismatch wrap(Array, mem, (10, 10))
@test wrap(Array, mem, (5,)) == ones(Int, 5)
@test wrap(Array, mem, 2) == ones(Int, 2)
@test wrap(Array, memref, 10) == ones(Int, 10)
@test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2)
@test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2)

memref2 = MemoryRef(mem, 3)
@test wrap(Array, memref2, (5,)) == ones(Int, 5)
@test wrap(Array, memref2, 2) == ones(Int, 2)
@test wrap(Array, memref2, (2,2,2)) == ones(Int,2,2,2)
@test wrap(Array, memref2, (3, 2)) == ones(Int, 3, 2)
@test_throws DimensionMismatch wrap(Array, memref2, 9)
@test_throws DimensionMismatch wrap(Array, memref2, 10)
end

0 comments on commit 84cfe04

Please sign in to comment.