Skip to content

Commit

Permalink
use TypeArithmetic trait in cumsum! implementation (#21666)
Browse files Browse the repository at this point in the history
* use TypeArithmetic trait in cumsum! implementation
  • Loading branch information
jw3126 authored and Sacha0 committed May 9, 2017
1 parent 25f241c commit d214d57
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
15 changes: 10 additions & 5 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -574,13 +574,18 @@ function accumulate_pairwise(op, v::AbstractVector{T}) where T
end

# Cumulative sum of the vector `v`, written into `out`.
#
# The summation algorithm is not chosen here: `TypeArithmetic(eltype(out))` is passed
# to `_cumsum!` as a trait argument, and the matching `_cumsum!` method decides how to
# accumulate. `axis` is forwarded unchanged.
function cumsum!(out, v::AbstractVector, axis::Integer=1)
    # we dispatch on the possibility of numerical stability issues
    _cumsum!(out, v, axis, TypeArithmetic(eltype(out)))
end

# Trait-dispatched kernels behind `cumsum!` for vectors. The last argument is an
# arithmetic-trait instance and is used only for method selection.

# Rounding arithmetic: along axis 1 accumulate with `accumulate_pairwise!`;
# for any other axis `out` simply receives a copy of `v`.
function _cumsum!(out, v, axis, ::ArithmeticRounds)
    axis == 1 ? accumulate_pairwise!(+, out, v) : copy!(out, v)
end

# Unknown arithmetic: handled exactly like rounding arithmetic.
function _cumsum!(out, v, axis, ::ArithmeticUnknown)
    _cumsum!(out, v, axis, ArithmeticRounds())
end

# Every other arithmetic trait: along axis 1 use the plain `accumulate!`;
# for any other axis `out` simply receives a copy of `v`.
function _cumsum!(out, v, axis, ::TypeArithmetic)
    axis == 1 ? accumulate!(+, out, v) : copy!(out, v)
end

"""
Expand Down
22 changes: 22 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2055,6 +2055,28 @@ end
@test accumulate(op, [10 20 30], 2) == [10 op(10, 20) op(op(10, 20), 30)] == [10 40 110]
end

# Test helper: wraps a single `Float32` value. The type parameter `T` is restricted
# to subtypes of `Base.TypeArithmetic` and is not used by any field.
struct F21666{T<:Base.TypeArithmetic}
    x::Float32
end

@testset "Exactness of cumsum # 21666" begin
    # test that cumsum uses more stable algorithm
    # for types with unknown/rounding arithmetic

    # The trait function must return an *instance* of the trait type carried in the
    # type parameter (hence `T()`, not `T`): the trait value is used for dispatch on
    # instances such as `ArithmeticRounds()`.
    Base.TypeArithmetic(::Type{F21666{T}}) where {T} = T()
    # Addition and Float64 conversion operate on the wrapped Float32 field.
    Base.:+(x::F, y::F) where {F <: F21666} = F(x.x + y.x)
    Base.convert(::Type{Float64}, x::F21666) = Float64(x.x)

    # we make v pretty large, because stable algorithm may have a large base case
    v = zeros(300); v[1] = 2; v[200:end] = eps(Float32)

    f_rounds = Float64.(cumsum(F21666{Base.ArithmeticRounds}.(v)))
    f_unknown = Float64.(cumsum(F21666{Base.ArithmeticUnknown}.(v)))
    # Float64 reference result
    f_truth = cumsum(v)
    # plain left-to-right accumulation in Float32, converted to Float64 for comparison
    f_inexact = Float64.(accumulate(+, Float32.(v)))

    # unknown arithmetic must take the same code path as rounding arithmetic
    @test f_rounds == f_unknown
    # ... and that path must differ from the plain accumulation
    @test f_rounds != f_inexact
    # ... and be closer to the Float64 reference
    @test norm(f_truth - f_rounds) < norm(f_truth - f_inexact)
end

@testset "zeros and ones" begin
@test ones([1,2], Float64, (2,3)) == ones(2,3)
@test ones(2) == ones(Int, 2) == ones([2,3], Float32, 2) == [1,1]
Expand Down

0 comments on commit d214d57

Please sign in to comment.