Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't allow reinterprets that would expose padding #27877

Merged
merged 1 commit into from
Jul 7, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions base/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ inttype(::Type{Float32}) = Int32
inttype(::Type{Float64}) = Int64


alignment(::Type{T}) where {T} = ccall(:jl_alignment, Cint, (Csize_t,), sizeof(T))
gc_alignment(::Type{T}) where {T} = ccall(:jl_alignment, Cint, (Csize_t,), sizeof(T))

# All atomic operations have acquire and/or release semantics, depending on
# whether the load or store values. Most of the time, this is what one wants
Expand All @@ -350,13 +350,13 @@ for typ in atomictypes
@eval getindex(x::Atomic{$typ}) =
llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
%rv = load atomic $rt %ptr acquire, align $(alignment(typ))
%rv = load atomic $rt %ptr acquire, align $(gc_alignment(typ))
ret $lt %rv
""", $typ, Tuple{Ptr{$typ}}, unsafe_convert(Ptr{$typ}, x))
@eval setindex!(x::Atomic{$typ}, v::$typ) =
llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
store atomic $lt %1, $lt* %ptr release, align $(alignment(typ))
store atomic $lt %1, $lt* %ptr release, align $(gc_alignment(typ))
ret void
""", Cvoid, Tuple{Ptr{$typ}, $typ}, unsafe_convert(Ptr{$typ}, x), v)

Expand Down
3 changes: 3 additions & 0 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,9 @@ mutable struct Stateful{T, VS}
# A bit awkward right now, but adapted to the new iteration protocol
nextvalstate::Union{VS, Nothing}
taken::Int
@inline function Stateful{<:Any, Any}(itr::T) where {T}
new{T, Any}(itr, iterate(itr), 0)
end
@inline function Stateful(itr::T) where {T}
VS = approx_iter_type(T)
new{T, VS}(itr, iterate(itr)::VS, 0)
Expand Down
123 changes: 122 additions & 1 deletion base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ the first dimension.
"""
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
parent::A
readable::Bool
writable::Bool
global reinterpret
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
Expand All @@ -31,7 +33,28 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
dim = size(a)[1]
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
end
new{T, N, S, A}(a)
readable = array_subpadding(T, S)
writable = array_subpadding(S, T)
new{T, N, S, A}(a, readable, writable)
end
end

function check_readable(a::ReinterpretArray{T, N, S} where N) where {T,S}
# See comment in check_writable
if !a.readable && !array_subpadding(T, S)
throw(PaddingError(T, S))
end
end

function check_writable(a::ReinterpretArray{T, N, S} where N) where {T,S}
# `array_subpadding` is relatively expensive (compared to a simple arrayref),
# so it is cached in the array. However, it is computable at compile time if,
# inference has the types available. By using this form of the check, we can
# get the best of both worlds for the success case. If the types were not
# available to inference, we simply need to check the field (relatively cheap)
# and if they were we should be able to fold this check away entirely.
if !a.writable && !array_subpadding(S, T)
throw(PaddingError(T, S))
end
end

Expand All @@ -53,10 +76,12 @@ unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} =
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
check_readable(a)
_getindex_ra(a, inds[1], tail(inds))
end

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, i::Int) where {T,N,S}
check_readable(a)
if isa(IndexStyle(a), IndexLinear)
return _getindex_ra(a, i, ())
end
Expand Down Expand Up @@ -102,10 +127,12 @@ end
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
check_writable(a)
_setindex_ra!(a, v, inds[1], tail(inds))
end

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, i::Int) where {T,N,S}
check_writable(a)
if isa(IndexStyle(a), IndexLinear)
return _setindex_ra!(a, v, i, ())
end
Expand Down Expand Up @@ -165,3 +192,97 @@ end
end
return a
end

# Padding
struct Padding
offset::Int
size::Int
end
function intersect(p1::Padding, p2::Padding)
start = max(p1.offset, p2.offset)
stop = min(p1.offset + p1.size, p2.offset + p2.size)
Padding(start, max(0, stop-start))
end

struct PaddingError
S::Type
T::Type
end

function showerror(io::IO, p::PaddingError)
print(io, "Padding of type $(p.S) is not compatible with type $(p.T).")
end

"""
CyclePadding(padding, total_size)

Cylces an iterator of `Padding` structs, restarting the padding at `total_size`.
E.g. if `padding` is all the padding in a struct and `total_size` is the total
aligned size of that array, `CyclePadding` will correspond to the padding in an
infinite vector of such structs.
"""
struct CyclePadding{P}
padding::P
total_size::Int
end
eltype(::Type{<:CyclePadding}) = Padding
IteratorSize(::Type{<:CyclePadding}) = IsInfinite()
isempty(cp::CyclePadding) = isempty(cp.padding)
function iterate(cp::CyclePadding)
y = iterate(cp.padding)
y === nothing && return nothing
y[1], (0, y[2])
end
function iterate(cp::CyclePadding, state::Tuple)
y = iterate(cp.padding, tail(state)...)
y === nothing && return iterate(cp, (state[1]+cp.total_size,))
Padding(y[1].offset+state[1], y[1].size), (state[1], tail(y)...)
end

"""
Compute the location of padding in a type.
"""
function padding(T)
padding = Padding[]
last_end::Int = 0
for i = 1:fieldcount(T)
offset = fieldoffset(T, i)
fT = fieldtype(T, i)
if offset != last_end
push!(padding, Padding(offset, offset-last_end))
end
last_end = offset + sizeof(fT)
end
padding
end

function CyclePadding(T::DataType)
a, s = datatype_alignment(T), sizeof(T)
as = s + (a - (s % a)) % a
pad = padding(T)
s != as && push!(pad, Padding(s, as - s))
CyclePadding(pad, as)
end

using .Iterators: Stateful
@pure function array_subpadding(S, T)
checked_size = 0
lcm_size = lcm(sizeof(S), sizeof(T))
s, t = Stateful{<:Any, Any}(CyclePadding(S)),
Stateful{<:Any, Any}(CyclePadding(T))
isempty(t) && return true
isempty(s) && return false
while checked_size < lcm_size
# Take padding in T
pad = popfirst!(t)
# See if there's corresponding padding in S
while true
ps = peek(s)
ps.offset > pad.offset && return false
intersect(ps, pad) == pad && break
popfirst!(s)
end
checked_size = pad.offset + pad.size
end
return true
end
3 changes: 1 addition & 2 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ include("array.jl")
include("abstractarray.jl")
include("subarray.jl")
include("views.jl")
include("reinterpretarray.jl")


# ## dims-type-converting Array constructors for convenience
# type and dimensionality specified, accepting dims as series of Integers
Expand Down Expand Up @@ -205,6 +203,7 @@ include("reduce.jl")

## core structures
include("reshapedarray.jl")
include("reinterpretarray.jl")
include("bitarray.jl")
include("bitset.jl")

Expand Down
19 changes: 19 additions & 0 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,22 @@ let a = fill(1.0, 5, 3)
@test all(a[1:2:5,:] .=== reinterpret(Float64, [Int64(4)])[1])
@test all(r .=== Int64(4))
end

# Error on reinterprets that would expose padding
struct S1
a::Int8
b::Int64
end

struct S2
a::Int16
b::Int64
end

A1 = S1[S1(0, 0)]
A2 = S2[S2(0, 0)]
@test reinterpret(S1, A2)[1] == S1(0, 0)
@test_throws Base.PaddingError (reinterpret(S1, A2)[1] = S2(1, 2))
@test_throws Base.PaddingError reinterpret(S2, A1)[1]
reinterpret(S2, A1)[1] = S2(1, 2)
@test A1[1] == S1(1, 2)