diff --git a/NEWS.md b/NEWS.md index 30a4c1af21d89..f813db02819d7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -61,6 +61,7 @@ New library functions * `copyuntil(out, io, delim)` and `copyline(out, io)` copy data into an `out::IO` stream ([#48273]). * `eachrsplit(string, pattern)` iterates split substrings right to left. * `Sys.username()` can be used to return the current user's username ([#51897]). +* `wrap(Array, m::Union{MemoryRef{T}, Memory{T}}, dims)` which is the safe counterpart to `unsafe_wrap` ([#52049]). New library features -------------------- diff --git a/base/array.jl b/base/array.jl index 23a1a59052147..1dd7e66e8274a 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3039,3 +3039,36 @@ intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r) _getindex(v, i) end end + +""" + wrap(Array, m::Union{Memory{T}, MemoryRef{T}}, dims) + +Create an array of size `dims` using `m` as the underlying memory. This can be thought of as a safe version +of [`unsafe_wrap`](@ref) utilizing `Memory` or `MemoryRef` instead of raw pointers. +""" +function wrap end + +@eval @propagate_inbounds function wrap(::Type{Array}, ref::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N} + mem = ref.mem + mem_len = length(mem) + 1 - memoryrefoffset(ref) + len = Core.checked_dims(dims...) + @boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims, len) + if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len) + mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len) + ref = MemoryRef(mem) + end + $(Expr(:new, :(Array{T, N}), :ref, :dims)) +end + +@noinline invalid_wrap_err(len, dims, proddims) = throw(DimensionMismatch( + "Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $proddims > $len, so that the array would have more elements than the underlying memory can store.")) + +function wrap(::Type{Array}, m::Memory{T}, dims::NTuple{N, Integer}) where {T, N} + wrap(Array, MemoryRef(m), dims) +end +function wrap(::Type{Array}, m::MemoryRef{T}, l::Integer) where {T} + wrap(Array, m, (l,)) +end +function wrap(::Type{Array}, m::Memory{T}, l::Integer) where {T} + wrap(Array, MemoryRef(m), (l,)) +end diff --git a/base/exports.jl b/base/exports.jl index 398d828f9cf19..374016f22571f 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -458,6 +458,7 @@ export vcat, vec, view, + wrap, zeros, # search, find, match and related functions diff --git a/test/arrayops.jl b/test/arrayops.jl index 9c9aac987929d..8e33e209ee88b 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -3170,3 +3170,22 @@ end @test c + zero(c) == c end end + +@testset "Wrapping Memory into Arrays" begin + mem = Memory{Int}(undef, 10) .= 1 + memref = MemoryRef(mem) + @test_throws DimensionMismatch wrap(Array, mem, (10, 10)) + @test wrap(Array, mem, (5,)) == ones(Int, 5) + @test wrap(Array, mem, 2) == ones(Int, 2) + @test wrap(Array, memref, 10) == ones(Int, 10) + @test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2) + @test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2) + + memref2 = MemoryRef(mem, 3) + @test wrap(Array, memref2, (5,)) == ones(Int, 5) + @test wrap(Array, memref2, 2) == ones(Int, 2) + @test wrap(Array, memref2, (2,2,2)) == ones(Int,2,2,2) + @test wrap(Array, memref2, (3, 2)) == ones(Int, 3, 2) + @test_throws DimensionMismatch wrap(Array, memref2, 9) + @test_throws DimensionMismatch wrap(Array, memref2, 10) +end