Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ReinterpretArray #23750

Merged
merged 3 commits into from
Oct 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ This section lists changes that do not have deprecation warnings.
* All command line arguments passed via `-e`, `-E`, and `-L` will be executed in the order
given on the command line ([#23665]).

* The return type of `reinterpret` has changed to `ReinterpretArray`. `reinterpret` on sparse
arrays has been discontinued.

Library improvements
--------------------

Expand Down Expand Up @@ -300,6 +303,10 @@ Library improvements
* New function `equalto(x)`, which returns a function that compares its argument to `x`
using `isequal` ([#23812]).

* `reinterpret` now works on any AbstractArray using the new `ReinterpretArray` type.
This supersedes the old behavior of reinterpret on Arrays. As a result, reinterpreting
arrays with different alignment requirements (removed in 0.6) is once again allowed ([#23750]).

Compiler/Runtime improvements
-----------------------------

Expand Down Expand Up @@ -498,6 +505,11 @@ Deprecated or removed
* `find` functions now operate only on booleans by default. To look for non-zeros, use
`x->x!=0` or `!iszero` ([#23120]).

* The ability of `reinterpret` to yield `Array`s of different type than the underlying storage
has been removed. The `reinterpret` function is still available, but now returns a
`ReinterpretArray`. The three argument form of `reinterpret` that implicitly reshapes
has been deprecated ([#23750]).

Command-line option changes
---------------------------

Expand Down
27 changes: 0 additions & 27 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,33 +218,6 @@ original.
"""
copy(a::T) where {T<:Array} = ccall(:jl_array_copy, Ref{T}, (Any,), a)

function reinterpret(::Type{T}, a::Array{S,1}) where T where S
nel = Int(div(length(a) * sizeof(S), sizeof(T)))
# TODO: maybe check that remainder is zero?
return reinterpret(T, a, (nel,))
end

function reinterpret(::Type{T}, a::Array{S}) where T where S
if sizeof(S) != sizeof(T)
throw(ArgumentError("result shape not specified"))
end
reinterpret(T, a, size(a))
end

function reinterpret(::Type{T}, a::Array{S}, dims::NTuple{N,Int}) where T where S where N
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a deprecation.

function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
@_noinline_meta
throw(ArgumentError("cannot reinterpret Array{$(S)} to ::Type{Array{$(T)}}, type $(U) is not a bits type"))
end
isbits(T) || throwbits(S, T, T)
isbits(S) || throwbits(S, T, S)
nel = div(length(a) * sizeof(S), sizeof(T))
if prod(dims) != nel
_throw_dmrsa(dims, nel)
end
ccall(:jl_reshape_array, Array{T,N}, (Any, Any, Any), Array{T,N}, a, dims)
end

# reshaping to same # of dimensions
function reshape(a::Array{T,N}, dims::NTuple{N,Int}) where T where N
if prod(dims) != length(a)
Expand Down
4 changes: 4 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1871,6 +1871,10 @@ end
# also remove deprecation warnings in find* functions in array.jl, sparse/sparsematrix.jl,
# and sparse/sparsevector.jl.

# issue #22849
@deprecate reinterpret(::Type{T}, a::Array{S}, dims::NTuple{N,Int}) where {T, S, N} reshape(reinterpret(T, vec(a)), dims)
@deprecate reinterpret(::Type{T}, a::SparseMatrixCSC{S}, dims::NTuple{N,Int}) where {T, S, N} reinterpret(T, reshape(a, dims))

# END 0.7 deprecations

# BEGIN 1.0 deprecations
Expand Down
10 changes: 1 addition & 9 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,12 @@ unsafe_convert(::Type{P}, x::Ptr) where {P<:Ptr} = convert(P, x)
reinterpret(type, A)

Change the type-interpretation of a block of memory.
For arrays, this constructs an array with the same binary data as the given
For arrays, this constructs a view of the array with the same binary data as the given
array, but with the specified element type.
For example,
`reinterpret(Float32, UInt32(7))` interprets the 4 bytes corresponding to `UInt32(7)` as a
[`Float32`](@ref).

!!! warning

It is not allowed to `reinterpret` an array to an element type with a larger alignment then
the alignment of the array. For a normal `Array`, this is the alignment of its element type.
For a reinterpreted array, this is the alignment of the `Array` it was reinterpreted from.
For example, `reinterpret(UInt32, UInt8[0, 0, 0, 0])` is not allowed but
`reinterpret(UInt32, reinterpret(UInt8, Float32[1.0]))` is allowed.

# Examples
```jldoctest
julia> reinterpret(Float32, UInt32(7))
Expand Down
2 changes: 2 additions & 0 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ add_tfunc(sdiv_int, 2, 2, math_tfunc, 30)
add_tfunc(udiv_int, 2, 2, math_tfunc, 30)
add_tfunc(srem_int, 2, 2, math_tfunc, 30)
add_tfunc(urem_int, 2, 2, math_tfunc, 30)
add_tfunc(add_ptr, 2, 2, math_tfunc, 1)
add_tfunc(sub_ptr, 2, 2, math_tfunc, 1)
add_tfunc(neg_float, 1, 1, math_tfunc, 1)
add_tfunc(add_float, 2, 2, math_tfunc, 1)
add_tfunc(sub_float, 2, 2, math_tfunc, 1)
Expand Down
7 changes: 4 additions & 3 deletions base/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,16 @@ readlines(s=STDIN; chomp::Bool=true) = collect(eachline(s, chomp=chomp))

## byte-order mark, ntoh & hton ##

let endian_boms = reinterpret(UInt8, UInt32[0x01020304])
let a = UInt32[0x01020304]
endian_bom = @gc_preserve a unsafe_load(convert(Ptr{UInt8}, pointer(a)))
global ntoh, hton, ltoh, htol
if endian_boms == UInt8[1:4;]
if endian_bom == 0x01
ntoh(x) = x
hton(x) = x
ltoh(x) = bswap(x)
htol(x) = bswap(x)
const global ENDIAN_BOM = 0x01020304
elseif endian_boms == UInt8[4:-1:1;]
elseif endian_bom == 0x04
ntoh(x) = bswap(x)
hton(x) = bswap(x)
ltoh(x) = x
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ Base.isequal(F::T, G::T) where {T<:Factorization} = all(f -> isequal(getfield(F,
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
function (\)(F::Factorization{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
c2r = reshape(transpose(reinterpret(T, B, (2, length(B)))), size(B, 1), 2*size(B, 2))
c2r = reshape(transpose(reinterpret(T, reshape(B, (1, length(B))))), size(B, 1), 2*size(B, 2))
x = A_ldiv_B!(F, c2r)
return reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)), _ret_size(F, B))
return reshape(collect(reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)))), _ret_size(F, B))
end

for (f1, f2) in ((:\, :A_ldiv_B!),
Expand Down
6 changes: 3 additions & 3 deletions base/linalg/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ end
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
function (\)(F::LQ{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
c2r = reshape(transpose(reinterpret(T, B, (2, length(B)))), size(B, 1), 2*size(B, 2))
c2r = reshape(transpose(reinterpret(T, reshape(B, (1, length(B))))), size(B, 1), 2*size(B, 2))
x = A_ldiv_B!(F, c2r)
return reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)),
isa(B, AbstractVector) ? (size(F,2),) : (size(F,2), size(B,2)))
return reshape(collect(reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)))),
isa(B, AbstractVector) ? (size(F,2),) : (size(F,2), size(B,2)))
end


Expand Down
10 changes: 5 additions & 5 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ A_mul_B!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty})
Afl = reinterpret($elty,A,(2size(A,1),size(A,2)))
Afl = reinterpret($elty,A)
yfl = reinterpret($elty,y)
gemv!(yfl,'N',Afl,x)
return y
Expand Down Expand Up @@ -148,8 +148,8 @@ A_mul_B!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) wher
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
gemm_wrapper!(Cfl, 'N', 'N', Afl, B)
return C
end
Expand Down Expand Up @@ -190,8 +190,8 @@ A_mul_Bt!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) whe
for elty in (Float32,Float64)
@eval begin
function A_mul_Bt!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
gemm_wrapper!(Cfl, 'N', 'T', Afl, B)
return C
end
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,15 +923,15 @@ function (\)(A::Union{QR{T},QRCompactWY{T},QRPivoted{T}}, BIn::VecOrMat{Complex{
# |z2|z4| -> |y1|y2|y3|y4| -> |x2|y2| -> |x2|y2|x4|y4|
# |x3|y3|
# |x4|y4|
B = reshape(transpose(reinterpret(T, BIn, (2, length(BIn)))), size(BIn, 1), 2*size(BIn, 2))
B = reshape(transpose(reinterpret(T, reshape(BIn, (1, length(BIn))))), size(BIn, 1), 2*size(BIn, 2))

X = A_ldiv_B!(A, _append_zeros(B, T, n))

# |z1|z3| reinterpret |x1|x2|x3|x4| transpose |x1|y1| reshape |x1|y1|x3|y3|
# |z2|z4| <- |y1|y2|y3|y4| <- |x2|y2| <- |x2|y2|x4|y4|
# |x3|y3|
# |x4|y4|
XX = reinterpret(Complex{T}, transpose(reshape(X, div(length(X), 2), 2)), _ret_size(A, BIn))
XX = reshape(collect(reinterpret(Complex{T}, transpose(reshape(X, div(length(X), 2), 2)))), _ret_size(A, BIn))
return _cut_B(XX, 1:n)
end

Expand Down
4 changes: 2 additions & 2 deletions base/pointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ eltype(::Type{Ptr{T}}) where {T} = T
isless(x::Ptr, y::Ptr) = isless(UInt(x), UInt(y))
-(x::Ptr, y::Ptr) = UInt(x) - UInt(y)

+(x::Ptr, y::Integer) = oftype(x, (UInt(x) + (y % UInt) % UInt))
-(x::Ptr, y::Integer) = oftype(x, (UInt(x) - (y % UInt) % UInt))
+(x::Ptr, y::Integer) = oftype(x, Intrinsics.add_ptr(UInt(x), (y % UInt) % UInt))
-(x::Ptr, y::Integer) = oftype(x, Intrinsics.sub_ptr(UInt(x), (y % UInt) % UInt))
+(x::Integer, y::Ptr) = y + x

"""
Expand Down
11 changes: 6 additions & 5 deletions base/random/dSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ function dsfmt_jump(s::DSFMT_state, jp::AbstractString)
val = s.val
nval = length(val)
index = val[nval - 1]
work = zeros(UInt64, JN32 >> 1)
work = zeros(Int32, JN32)
rwork = reinterpret(UInt64, work)
dsfmt = Vector{UInt64}(nval >> 1)
ccall(:memcpy, Ptr{Void}, (Ptr{UInt64}, Ptr{Int32}, Csize_t),
dsfmt, val, (nval - 1) * sizeof(Int32))
Expand All @@ -113,17 +114,17 @@ function dsfmt_jump(s::DSFMT_state, jp::AbstractString)
for c in jp
bits = parse(UInt8,c,16)
for j in 1:4
(bits & 0x01) != 0x00 && dsfmt_jump_add!(work, dsfmt)
(bits & 0x01) != 0x00 && dsfmt_jump_add!(rwork, dsfmt)
bits = bits >> 0x01
dsfmt_jump_next_state!(dsfmt)
end
end

work[end] = index
return DSFMT_state(reinterpret(Int32, work))
rwork[end] = index
return DSFMT_state(work)
end

function dsfmt_jump_add!(dest::Vector{UInt64}, src::Vector{UInt64})
function dsfmt_jump_add!(dest::AbstractVector{UInt64}, src::Vector{UInt64})
dp = dest[end] >> 1
sp = src[end] >> 1
diff = ((sp - dp + N) % N)
Expand Down
134 changes: 134 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Gives a reinterpreted view (of element type T) of the underlying array (of element type S).
If the size of `T` differs from the size of `S`, the array will be compressed/expanded in
the first dimension.
"""
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
parent::A
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
@_noinline_meta
throw(ArgumentError("cannot reinterpret `$(S)` `$(T)`, type `$(U)` is not a bits type"))
end
function throwsize0(::Type{S}, ::Type{T})
@_noinline_meta
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size"))
end
function thrownonint(::Type{S}, ::Type{T}, dim)
@_noinline_meta
throw(ArgumentError("""
cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`.
The resulting array would have non-integral first dimension.
"""))
end
isbits(T) || throwbits(S, T, T)
isbits(S) || throwbits(S, T, S)
(N != 0 || sizeof(T) == sizeof(S)) || throwsize0(S, T)
if N != 0 && sizeof(S) != sizeof(T)
dim = size(a)[1]
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
end
new{T, N, S, A}(a)
end
end

parent(a::ReinterpretArray) = a.parent

eltype(a::ReinterpretArray{T}) where {T} = T
function size(a::ReinterpretArray{T,N,S} where {N}) where {T,S}
psize = size(a.parent)
size1 = div(psize[1]*sizeof(S), sizeof(T))
tuple(size1, tail(psize)...)
end

unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent))

@inline @propagate_inbounds getindex(a::ReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[])
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return reinterpret(T, a.parent[inds...])
else
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
t = Ref{T}()
s = Ref{S}()
@gc_preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
i = 1
nbytes_copied = 0
# This is a bit complicated to deal with partial elements
# at both the start and the end. LLVM will fold as appropriate,
# once it knows the data layout
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tail(inds)...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1)
sidx += 1
nbytes_copied += 1
end
sidx = 0
i += 1
end
end
return t[]
end
end

@inline @propagate_inbounds setindex!(a::ReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return setindex!(a.parent, reinterpret(S, v), inds...)
else
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
t = Ref{T}(v)
s = Ref{S}()
@gc_preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
nbytes_copied = 0
i = 1
# Deal with any partial elements at the start. We'll have to copy in the
# element from the original array and overwrite the relevant parts
if sidx != 0
s[] = a.parent[ind_start + i, tail(inds)...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
i += 1
sidx = 0
end
# Deal with the main body of elements
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
i += 1
sidx = 0
end
# Deal with trailing partial elements
if nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tail(inds)...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
end
end
end
return a
end
6 changes: 6 additions & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1888,6 +1888,12 @@ function showarg(io::IO, r::ReshapedArray, toplevel)
toplevel && print(io, " with eltype ", eltype(r))
end

function showarg(io::IO, r::ReinterpretArray{T}, toplevel) where {T}
print(io, "reinterpret($T, ")
showarg(io, parent(r), false)
print(io, ')')
end

# n-dimensional arrays
function show_nd(io::IO, a::AbstractArray, print_matrix, label_slices)
limit::Bool = get(io, :limit, false)
Expand Down
Loading