Skip to content

Commit

Permalink
Merge pull request #23750 from JuliaLang/kf/reinterpretarray
Browse files Browse the repository at this point in the history
Implement ReinterpretArray
  • Loading branch information
Keno authored Oct 9, 2017
2 parents 19b3ca7 + 0054051 commit 8273529
Show file tree
Hide file tree
Showing 38 changed files with 767 additions and 225 deletions.
12 changes: 12 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ This section lists changes that do not have deprecation warnings.
* All command line arguments passed via `-e`, `-E`, and `-L` will be executed in the order
given on the command line ([#23665]).

* The return type of `reinterpret` has changed to `ReinterpretArray`. `reinterpret` on sparse
arrays has been discontinued.

Library improvements
--------------------

Expand Down Expand Up @@ -313,6 +316,10 @@ Library improvements
* New function `equalto(x)`, which returns a function that compares its argument to `x`
using `isequal` ([#23812]).

* `reinterpret` now works on any AbstractArray using the new `ReinterpretArray` type.
This supersedes the old behavior of reinterpret on Arrays. As a result, reinterpreting
arrays with different alignment requirements (removed in 0.6) is once again allowed ([#23750]).

Compiler/Runtime improvements
-----------------------------

Expand Down Expand Up @@ -511,6 +518,11 @@ Deprecated or removed
* `find` functions now operate only on booleans by default. To look for non-zeros, use
`x->x!=0` or `!iszero` ([#23120]).

* The ability of `reinterpret` to yield `Array`s of different type than the underlying storage
has been removed. The `reinterpret` function is still available, but now returns a
`ReinterpretArray`. The three argument form of `reinterpret` that implicitly reshapes
has been deprecated ([#23750]).

Command-line option changes
---------------------------

Expand Down
27 changes: 0 additions & 27 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,33 +218,6 @@ original.
"""
copy(a::T) where {T<:Array} = ccall(:jl_array_copy, Ref{T}, (Any,), a)

function reinterpret(::Type{T}, a::Array{S,1}) where T where S
nel = Int(div(length(a) * sizeof(S), sizeof(T)))
# TODO: maybe check that remainder is zero?
return reinterpret(T, a, (nel,))
end

function reinterpret(::Type{T}, a::Array{S}) where T where S
if sizeof(S) != sizeof(T)
throw(ArgumentError("result shape not specified"))
end
reinterpret(T, a, size(a))
end

function reinterpret(::Type{T}, a::Array{S}, dims::NTuple{N,Int}) where T where S where N
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
@_noinline_meta
throw(ArgumentError("cannot reinterpret Array{$(S)} to ::Type{Array{$(T)}}, type $(U) is not a bits type"))
end
isbits(T) || throwbits(S, T, T)
isbits(S) || throwbits(S, T, S)
nel = div(length(a) * sizeof(S), sizeof(T))
if prod(dims) != nel
_throw_dmrsa(dims, nel)
end
ccall(:jl_reshape_array, Array{T,N}, (Any, Any, Any), Array{T,N}, a, dims)
end

# reshaping to same # of dimensions
function reshape(a::Array{T,N}, dims::NTuple{N,Int}) where T where N
if prod(dims) != length(a)
Expand Down
4 changes: 4 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,10 @@ end
# also remove deprecation warnings in find* functions in array.jl, sparse/sparsematrix.jl,
# and sparse/sparsevector.jl.

# issue #22849
@deprecate reinterpret(::Type{T}, a::Array{S}, dims::NTuple{N,Int}) where {T, S, N} reshape(reinterpret(T, vec(a)), dims)
@deprecate reinterpret(::Type{T}, a::SparseMatrixCSC{S}, dims::NTuple{N,Int}) where {T, S, N} reinterpret(T, reshape(a, dims))

# END 0.7 deprecations

# BEGIN 1.0 deprecations
Expand Down
10 changes: 1 addition & 9 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,12 @@ unsafe_convert(::Type{P}, x::Ptr) where {P<:Ptr} = convert(P, x)
reinterpret(type, A)
Change the type-interpretation of a block of memory.
For arrays, this constructs an array with the same binary data as the given
For arrays, this constructs a view of the array with the same binary data as the given
array, but with the specified element type.
For example,
`reinterpret(Float32, UInt32(7))` interprets the 4 bytes corresponding to `UInt32(7)` as a
[`Float32`](@ref).
!!! warning
It is not allowed to `reinterpret` an array to an element type with a larger alignment then
the alignment of the array. For a normal `Array`, this is the alignment of its element type.
For a reinterpreted array, this is the alignment of the `Array` it was reinterpreted from.
For example, `reinterpret(UInt32, UInt8[0, 0, 0, 0])` is not allowed but
`reinterpret(UInt32, reinterpret(UInt8, Float32[1.0]))` is allowed.
# Examples
```jldoctest
julia> reinterpret(Float32, UInt32(7))
Expand Down
2 changes: 2 additions & 0 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ add_tfunc(sdiv_int, 2, 2, math_tfunc, 30)
add_tfunc(udiv_int, 2, 2, math_tfunc, 30)
add_tfunc(srem_int, 2, 2, math_tfunc, 30)
add_tfunc(urem_int, 2, 2, math_tfunc, 30)
add_tfunc(add_ptr, 2, 2, math_tfunc, 1)
add_tfunc(sub_ptr, 2, 2, math_tfunc, 1)
add_tfunc(neg_float, 1, 1, math_tfunc, 1)
add_tfunc(add_float, 2, 2, math_tfunc, 1)
add_tfunc(sub_float, 2, 2, math_tfunc, 1)
Expand Down
7 changes: 4 additions & 3 deletions base/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,16 @@ readlines(s=STDIN; chomp::Bool=true) = collect(eachline(s, chomp=chomp))

## byte-order mark, ntoh & hton ##

let endian_boms = reinterpret(UInt8, UInt32[0x01020304])
let a = UInt32[0x01020304]
endian_bom = @gc_preserve a unsafe_load(convert(Ptr{UInt8}, pointer(a)))
global ntoh, hton, ltoh, htol
if endian_boms == UInt8[1:4;]
if endian_bom == 0x01
ntoh(x) = x
hton(x) = x
ltoh(x) = bswap(x)
htol(x) = bswap(x)
const global ENDIAN_BOM = 0x01020304
elseif endian_boms == UInt8[4:-1:1;]
elseif endian_bom == 0x04
ntoh(x) = bswap(x)
hton(x) = bswap(x)
ltoh(x) = x
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ Base.isequal(F::T, G::T) where {T<:Factorization} = all(f -> isequal(getfield(F,
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
function (\)(F::Factorization{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
c2r = reshape(transpose(reinterpret(T, B, (2, length(B)))), size(B, 1), 2*size(B, 2))
c2r = reshape(transpose(reinterpret(T, reshape(B, (1, length(B))))), size(B, 1), 2*size(B, 2))
x = A_ldiv_B!(F, c2r)
return reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)), _ret_size(F, B))
return reshape(collect(reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)))), _ret_size(F, B))
end

for (f1, f2) in ((:\, :A_ldiv_B!),
Expand Down
6 changes: 3 additions & 3 deletions base/linalg/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ end
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
function (\)(F::LQ{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
c2r = reshape(transpose(reinterpret(T, B, (2, length(B)))), size(B, 1), 2*size(B, 2))
c2r = reshape(transpose(reinterpret(T, reshape(B, (1, length(B))))), size(B, 1), 2*size(B, 2))
x = A_ldiv_B!(F, c2r)
return reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)),
isa(B, AbstractVector) ? (size(F,2),) : (size(F,2), size(B,2)))
return reshape(collect(reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)))),
isa(B, AbstractVector) ? (size(F,2),) : (size(F,2), size(B,2)))
end


Expand Down
10 changes: 5 additions & 5 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ A_mul_B!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty})
Afl = reinterpret($elty,A,(2size(A,1),size(A,2)))
Afl = reinterpret($elty,A)
yfl = reinterpret($elty,y)
gemv!(yfl,'N',Afl,x)
return y
Expand Down Expand Up @@ -148,8 +148,8 @@ A_mul_B!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) wher
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
gemm_wrapper!(Cfl, 'N', 'N', Afl, B)
return C
end
Expand Down Expand Up @@ -190,8 +190,8 @@ A_mul_Bt!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) whe
for elty in (Float32,Float64)
@eval begin
function A_mul_Bt!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
gemm_wrapper!(Cfl, 'N', 'T', Afl, B)
return C
end
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -918,15 +918,15 @@ function (\)(A::Union{QR{T},QRCompactWY{T},QRPivoted{T}}, BIn::VecOrMat{Complex{
# |z2|z4| -> |y1|y2|y3|y4| -> |x2|y2| -> |x2|y2|x4|y4|
# |x3|y3|
# |x4|y4|
B = reshape(transpose(reinterpret(T, BIn, (2, length(BIn)))), size(BIn, 1), 2*size(BIn, 2))
B = reshape(transpose(reinterpret(T, reshape(BIn, (1, length(BIn))))), size(BIn, 1), 2*size(BIn, 2))

X = A_ldiv_B!(A, _append_zeros(B, T, n))

# |z1|z3| reinterpret |x1|x2|x3|x4| transpose |x1|y1| reshape |x1|y1|x3|y3|
# |z2|z4| <- |y1|y2|y3|y4| <- |x2|y2| <- |x2|y2|x4|y4|
# |x3|y3|
# |x4|y4|
XX = reinterpret(Complex{T}, transpose(reshape(X, div(length(X), 2), 2)), _ret_size(A, BIn))
XX = reshape(collect(reinterpret(Complex{T}, transpose(reshape(X, div(length(X), 2), 2)))), _ret_size(A, BIn))
return _cut_B(XX, 1:n)
end

Expand Down
4 changes: 2 additions & 2 deletions base/pointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ eltype(::Type{Ptr{T}}) where {T} = T
isless(x::Ptr, y::Ptr) = isless(UInt(x), UInt(y))
-(x::Ptr, y::Ptr) = UInt(x) - UInt(y)

+(x::Ptr, y::Integer) = oftype(x, (UInt(x) + (y % UInt) % UInt))
-(x::Ptr, y::Integer) = oftype(x, (UInt(x) - (y % UInt) % UInt))
+(x::Ptr, y::Integer) = oftype(x, Intrinsics.add_ptr(UInt(x), (y % UInt) % UInt))
-(x::Ptr, y::Integer) = oftype(x, Intrinsics.sub_ptr(UInt(x), (y % UInt) % UInt))
+(x::Integer, y::Ptr) = y + x

"""
Expand Down
11 changes: 6 additions & 5 deletions base/random/dSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ function dsfmt_jump(s::DSFMT_state, jp::AbstractString)
val = s.val
nval = length(val)
index = val[nval - 1]
work = zeros(UInt64, JN32 >> 1)
work = zeros(Int32, JN32)
rwork = reinterpret(UInt64, work)
dsfmt = Vector{UInt64}(nval >> 1)
ccall(:memcpy, Ptr{Void}, (Ptr{UInt64}, Ptr{Int32}, Csize_t),
dsfmt, val, (nval - 1) * sizeof(Int32))
Expand All @@ -113,17 +114,17 @@ function dsfmt_jump(s::DSFMT_state, jp::AbstractString)
for c in jp
bits = parse(UInt8,c,16)
for j in 1:4
(bits & 0x01) != 0x00 && dsfmt_jump_add!(work, dsfmt)
(bits & 0x01) != 0x00 && dsfmt_jump_add!(rwork, dsfmt)
bits = bits >> 0x01
dsfmt_jump_next_state!(dsfmt)
end
end

work[end] = index
return DSFMT_state(reinterpret(Int32, work))
rwork[end] = index
return DSFMT_state(work)
end

function dsfmt_jump_add!(dest::Vector{UInt64}, src::Vector{UInt64})
function dsfmt_jump_add!(dest::AbstractVector{UInt64}, src::Vector{UInt64})
dp = dest[end] >> 1
sp = src[end] >> 1
diff = ((sp - dp + N) % N)
Expand Down
134 changes: 134 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Gives a reinterpreted view (of element type T) of the underlying array (of element type S).
If the size of `T` differs from the size of `S`, the array will be compressed/expanded in
the first dimension.
"""
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
parent::A
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
@_noinline_meta
throw(ArgumentError("cannot reinterpret `$(S)` `$(T)`, type `$(U)` is not a bits type"))
end
function throwsize0(::Type{S}, ::Type{T})
@_noinline_meta
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size"))
end
function thrownonint(::Type{S}, ::Type{T}, dim)
@_noinline_meta
throw(ArgumentError("""
cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`.
The resulting array would have non-integral first dimension.
"""))
end
isbits(T) || throwbits(S, T, T)
isbits(S) || throwbits(S, T, S)
(N != 0 || sizeof(T) == sizeof(S)) || throwsize0(S, T)
if N != 0 && sizeof(S) != sizeof(T)
dim = size(a)[1]
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
end
new{T, N, S, A}(a)
end
end

parent(a::ReinterpretArray) = a.parent

eltype(a::ReinterpretArray{T}) where {T} = T
function size(a::ReinterpretArray{T,N,S} where {N}) where {T,S}
psize = size(a.parent)
size1 = div(psize[1]*sizeof(S), sizeof(T))
tuple(size1, tail(psize)...)
end

unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent))

@inline @propagate_inbounds getindex(a::ReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[])
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return reinterpret(T, a.parent[inds...])
else
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
t = Ref{T}()
s = Ref{S}()
@gc_preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
i = 1
nbytes_copied = 0
# This is a bit complicated to deal with partial elements
# at both the start and the end. LLVM will fold as appropriate,
# once it knows the data layout
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tail(inds)...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1)
sidx += 1
nbytes_copied += 1
end
sidx = 0
i += 1
end
end
return t[]
end
end

@inline @propagate_inbounds setindex!(a::ReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return setindex!(a.parent, reinterpret(S, v), inds...)
else
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
t = Ref{T}(v)
s = Ref{S}()
@gc_preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
nbytes_copied = 0
i = 1
# Deal with any partial elements at the start. We'll have to copy in the
# element from the original array and overwrite the relevant parts
if sidx != 0
s[] = a.parent[ind_start + i, tail(inds)...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
i += 1
sidx = 0
end
# Deal with the main body of elements
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
i += 1
sidx = 0
end
# Deal with trailing partial elements
if nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tail(inds)...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
a.parent[ind_start + i, tail(inds)...] = s[]
end
end
end
return a
end
6 changes: 6 additions & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1888,6 +1888,12 @@ function showarg(io::IO, r::ReshapedArray, toplevel)
toplevel && print(io, " with eltype ", eltype(r))
end

function showarg(io::IO, r::ReinterpretArray{T}, toplevel) where {T}
print(io, "reinterpret($T, ")
showarg(io, parent(r), false)
print(io, ')')
end

# n-dimensional arrays
function show_nd(io::IO, a::AbstractArray, print_matrix, label_slices)
limit::Bool = get(io, :limit, false)
Expand Down
Loading

2 comments on commit 8273529

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily benchmark build, I will reply here when finished:

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something went wrong when running your job:

NanosoldierError: failed to run benchmarks against primary commit: failed process: Process(`sudo cset shield -e su nanosoldier -- -c ./benchscript.sh`, ProcessExited(1)) [1]

Logs and partial data can be found here
cc @ararslan

Please sign in to comment.