From 570170a204cab099203946d5fb0c5ea2a535c4d7 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Mon, 30 Oct 2023 20:42:22 +0000 Subject: [PATCH] IO: fix API safety issue for Ptr Align the API for Ref with the new definition for AbstractArray (#49769) and ensure this API does not accept raw Ptr as input. Refs #42593 --- base/io.jl | 27 +++++++++++++++++++++------ test/read.jl | 18 ++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/base/io.jl b/base/io.jl index bc87f5df9be7ba..37255b069fcc22 100644 --- a/base/io.jl +++ b/base/io.jl @@ -780,10 +780,17 @@ end @noinline unsafe_write(s::IO, p::Ref{T}, n::Integer) where {T} = unsafe_write(s, unsafe_convert(Ref{T}, p)::Ptr, n) # mark noinline to ensure ref is gc-rooted somewhere (by the caller) unsafe_write(s::IO, p::Ptr, n::Integer) = unsafe_write(s, convert(Ptr{UInt8}, p), convert(UInt, n)) -write(s::IO, x::Ref{T}) where {T} = unsafe_write(s, x, Core.sizeof(T)) +function write(s::IO, x::Ref{T}) where {T} + x isa Ptr && error("write cannot copy from a Ptr") + if isbitstype(T) + unsafe_write(s, x, Core.sizeof(T)) + else + write(s, x[]) + end +end write(s::IO, x::Int8) = write(s, reinterpret(UInt8, x)) function write(s::IO, x::Union{Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128,Float16,Float32,Float64}) - return write(s, Ref(x)) + return unsafe_write(s, Ref(x), Core.sizeof(x)) end write(s::IO, x::Bool) = write(s, UInt8(x)) @@ -797,7 +804,7 @@ function write(s::IO, A::AbstractArray) r = Ref{eltype(A)}() for a in A r[] = a - nb += @noinline unsafe_write(s, r, sizeof(r)) # r must be heap-allocated + nb += @noinline unsafe_write(s, r, Core.sizeof(r)) # r must be heap-allocated end return nb end @@ -861,11 +868,19 @@ end @noinline unsafe_read(s::IO, p::Ref{T}, n::Integer) where {T} = unsafe_read(s, unsafe_convert(Ref{T}, p)::Ptr, n) # mark noinline to ensure ref is gc-rooted somewhere (by the caller) unsafe_read(s::IO, p::Ptr, n::Integer) = unsafe_read(s, convert(Ptr{UInt8}, p), convert(UInt, n)) -read!(s::IO, x::Ref{T}) where {T} = (unsafe_read(s, x, Core.sizeof(T)); x) +function read!(s::IO, x::Ref{T}) where {T} + x isa Ptr && error("read! cannot copy into a Ptr") + if isbitstype(T) + unsafe_read(s, x, Core.sizeof(T)) + else + x[] = read(s, T) + end + return x +end read(s::IO, ::Type{Int8}) = reinterpret(Int8, read(s, UInt8)) function read(s::IO, T::Union{Type{Int16},Type{UInt16},Type{Int32},Type{UInt32},Type{Int64},Type{UInt64},Type{Int128},Type{UInt128},Type{Float16},Type{Float32},Type{Float64}}) - return read!(s, Ref{T}(0))[]::T + return unsafe_read(s, Ref{T}(0), Core.sizeof(T))[]::T end read(s::IO, ::Type{Bool}) = (read(s, UInt8) != 0) @@ -878,7 +893,7 @@ function read!(s::IO, A::AbstractArray{T}) where {T} if isbitstype(T) r = Ref{T}() for i in eachindex(A) - @noinline unsafe_read(s, r, sizeof(r)) # r must be heap-allocated + @noinline unsafe_read(s, r, Core.sizeof(r)) # r must be heap-allocated A[i] = r[] end else diff --git a/test/read.jl b/test/read.jl index 3f22dd55a507a5..283381668c28a8 100644 --- a/test/read.jl +++ b/test/read.jl @@ -720,3 +720,21 @@ end @test isempty(r) && isempty(collect(r)) end end + +@testset "Ref API" begin + io = PipeBuffer() + @test write(io, Ref{Any}(0xabcd_1234)) === 4 + @test read(io, UInt32) === 0xabcd_1234 + @test_throws ErrorException("write cannot copy from a Ptr") invoke(write, Tuple{typeof(io), Ref{Cvoid}}, io, C_NULL) + @test_throws ErrorException("write cannot copy from a Ptr") invoke(write, Tuple{typeof(io), Ref{Int}}, io, Ptr{Int}(0)) + @test_throws ErrorException("write cannot copy from a Ptr") invoke(write, Tuple{typeof(io), Ref{Any}}, io, Ptr{Any}(0)) + @test_throws ErrorException("read! cannot copy into a Ptr") read!(io, C_NULL) + @test_throws ErrorException("read! cannot copy into a Ptr") read!(io, Ptr{Int}(0)) + @test_throws ErrorException("read! cannot copy into a Ptr") read!(io, Ptr{Any}(0)) + @test eof(io) + @test write(io, C_NULL) === sizeof(Int) + @test write(io, Ptr{Int}(4)) === sizeof(Int) + @test write(io, Ptr{Any}(5)) === sizeof(Int) + @test read!(io, Int[1, 2, 3]) == [0, 4, 5] + @test eof(io) +end