Skip to content

Commit

Permalink
Refactor scalar range getindex (#50467)
Browse files Browse the repository at this point in the history
  • Loading branch information
LilithHafner authored Aug 30, 2023
1 parent fb76136 commit 6f026e3
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 49 deletions.
6 changes: 1 addition & 5 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1288,11 +1288,7 @@ end
# To avoid invalidations from multidimensional.jl: getindex(A::Array, i1::Union{Integer, CartesianIndex}, I::Union{Integer, CartesianIndex}...)
@propagate_inbounds getindex(A::Array, i1::Integer, I::Integer...) = A[to_indices(A, (i1, I...))...]

function unsafe_getindex(A::AbstractArray, I...)
@inline
@inbounds r = getindex(A, I...)
r
end
unsafe_getindex(A::AbstractArray, I...) = @inbounds getindex(A, I...)

struct CanonicalIndexError <: Exception
func::String
Expand Down
9 changes: 9 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2859,3 +2859,12 @@ function intersect(v::AbstractVector, r::AbstractRange)
return vectorfilter(T, _shrink_filter!(seen), common)
end
intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r)

# Here instead of range.jl for bootstrapping because `@propagate_inbounds` depends on Vectors.
@propagate_inbounds function getindex(v::AbstractRange, i::Integer)
if i isa Bool # Not via dispatch to avoid ambiguities
throw(ArgumentError("invalid index: $i of type Bool"))
else
_getindex(v, i)
end
end
63 changes: 21 additions & 42 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -910,11 +910,15 @@ function isassigned(r::AbstractRange, i::Integer)
firstindex(r) <= i <= lastindex(r)
end

# `_getindex` is like `getindex` but does not check if `i isa Bool`
function _getindex(v::AbstractRange, i::Integer)
@boundscheck checkbounds(v, i)
unsafe_getindex(v, i)
end

_in_unit_range(v::UnitRange, val, i::Integer) = i > 0 && val <= v.stop && val >= v.start

function getindex(v::UnitRange{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
function _getindex(v::UnitRange{T}, i::Integer) where T
val = convert(T, v.start + (i - oneunit(i)))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val
Expand All @@ -923,68 +927,38 @@ end
const OverflowSafe = Union{Bool,Int8,Int16,Int32,Int64,Int128,
UInt8,UInt16,UInt32,UInt64,UInt128}

function getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
function _getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
val = v.start + (i - oneunit(i))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val % T
end

function getindex(v::OneTo{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
@boundscheck ((i > 0) & (i <= v.stop)) || throw_boundserror(v, i)
convert(T, i)
end

function getindex(v::AbstractRange{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
@boundscheck checkbounds(v, i)
convert(T, first(v) + (i - oneunit(i))*step_hp(v))
end

let BitInteger64 = Union{Int8,Int16,Int32,Int64,UInt8,UInt16,UInt32,UInt64} # for bootstrapping
function checkbounds(::Type{Bool}, v::StepRange{<:BitInteger64, <:BitInteger64}, i::BitInteger64)
@inline
res = widemul(step(v), i-oneunit(i)) + first(v)
(0 < i) & ifelse(0 < step(v), res <= last(v), res >= last(v))
end
end

function getindex(r::Union{StepRangeLen,LinRange}, i::Integer)
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
@boundscheck checkbounds(r, i)
unsafe_getindex(r, i)
end

# This is separate to make it useful even when running with --check-bounds=yes
# unsafe_getindex is separate to make it useful even when running with --check-bounds=yes
# it assumes the index is inbounds but does not segfault even if the index is out of bounds.
# it does not check if the index isa bool.
unsafe_getindex(v::OneTo{T}, i::Integer) where T = convert(T, i)
unsafe_getindex(v::AbstractRange{T}, i::Integer) where T = convert(T, first(v) + (i - oneunit(i))*step_hp(v))
function unsafe_getindex(r::StepRangeLen{T}, i::Integer) where T
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = oftype(r.offset, i) - r.offset
T(r.ref + u*r.step)
end

function _getindex_hiprec(r::StepRangeLen, i::Integer) # without rounding by T
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = oftype(r.offset, i) - r.offset
r.ref + u*r.step
end

function unsafe_getindex(r::LinRange, i::Integer)
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
lerpi(i-oneunit(i), r.lendiv, r.start, r.stop)
end
unsafe_getindex(r::LinRange, i::Integer) = lerpi(i-oneunit(i), r.lendiv, r.start, r.stop)

function lerpi(j::Integer, d::Integer, a::T, b::T) where T
@inline
t = j/d # ∈ [0,1]
# compute approximately fma(t, b, -fma(t, a, a))
return T((1-t)*a + t*b)
end

# non-scalar indexing

getindex(r::AbstractRange, ::Colon) = copy(r)

function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integer}
Expand Down Expand Up @@ -1083,6 +1057,11 @@ function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
end
end

function _getindex_hiprec(r::StepRangeLen, i::Integer) # without rounding by T
u = oftype(r.offset, i) - r.offset
r.ref + u*r.step
end

function getindex(r::LinRange{T}, s::OrdinalRange{S}) where {T, S<:Integer}
@inline
@boundscheck checkbounds(r, s)
Expand Down
2 changes: 0 additions & 2 deletions base/twiceprecision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,6 @@ end
# This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact
function unsafe_getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, i::Integer) where T
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = oftype(r.offset, i) - r.offset
shift_hi, shift_lo = u*r.step.hi, u*r.step.lo
x_hi, x_lo = add12(r.ref.hi, shift_hi)
Expand Down

5 comments on commit 6f026e3

@vtjnash
Copy link
Member

@vtjnash vtjnash commented on 6f026e3 Sep 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("array", vs="@fb7613635cab77cf269790335e8121f513c9ea96")

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here.

@vtjnash
Copy link
Member

@vtjnash vtjnash commented on 6f026e3 Sep 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LilithHafner it seems maybe the removal of inbounds was a bit costly on some benchmarks. Could you see if those are real regressions, and if we should do something about them?

@LilithHafner
Copy link
Member Author

@LilithHafner LilithHafner commented on 6f026e3 Sep 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was the removal of inline, not inbounds; I can verify that at least some of them are real; and yes, should be an easy fix. Thanks for @ ing me.

@LilithHafner
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.