Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix non numeric accumulate(op, v0, x) (#25506) #25515

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -685,17 +685,6 @@ _iterable(v) = Iterators.repeated(v)
end
end

##

# see discussion in #18364 ... we try not to widen type of the resulting array
# from cumsum or cumprod, but in some cases (+, Bool) we may not have a choice.
rcum_promote_type(op, ::Type{T}, ::Type{S}) where {T,S<:Number} = promote_op(op, T, S)
rcum_promote_type(op, ::Type{T}) where {T<:Number} = rcum_promote_type(op, T,T)
rcum_promote_type(op, ::Type{T}) where {T} = T

# handle sums of Vector{Bool} and similar. it would be nice to handle
# any AbstractArray here, but it's not clear how that would be possible
rcum_promote_type(op, ::Type{Array{T,N}}) where {T,N} = Array{rcum_promote_type(op,T), N}

# accumulate_pairwise slightly slower then accumulate, but more numerically
# stable in certain situations (e.g. sums).
Expand Down Expand Up @@ -723,14 +712,14 @@ function accumulate_pairwise!(op::Op, result::AbstractVector, v::AbstractVector)
n = length(li)
n == 0 && return result
i1 = first(li)
@inbounds result[i1] = v1 = v[i1]
@inbounds result[i1] = v1 = reduce_first(op,v[i1])
n == 1 && return result
_accumulate_pairwise!(op, result, v, v1, i1+1, n-1)
return result
end

function accumulate_pairwise(op, v::AbstractVector{T}) where T
out = similar(v, rcum_promote_type(op, T))
out = similar(v, promote_op(op, T, T))
return accumulate_pairwise!(op, out, v)
end

Expand Down Expand Up @@ -775,7 +764,7 @@ julia> cumsum(a,2)
```
"""
function cumsum(A::AbstractArray{T}, dim::Integer) where T
out = similar(A, rcum_promote_type(+, T))
out = similar(A, promote_op(+, T, T))
cumsum!(out, A, dim)
end

Expand Down Expand Up @@ -909,7 +898,7 @@ julia> accumulate(+, fill(1, 3, 3), 2)
```
"""
function accumulate(op, A, dim::Integer)
out = similar(A, rcum_promote_type(op, eltype(A)))
out = similar(A, promote_op(op, eltype(A), eltype(A)))
accumulate!(op, out, A, dim)
end

Expand Down Expand Up @@ -977,7 +966,7 @@ function accumulate!(op, B, A, dim::Integer)
# register usage and will be slightly faster
ind1 = inds_t[1]
@inbounds for I in CartesianIndices(tail(inds_t))
tmp = convert(eltype(B), A[first(ind1), I])
tmp = reduce_first(op, A[first(ind1), I])
B[first(ind1), I] = tmp
for i_1 = first(ind1)+1:last(ind1)
tmp = op(tmp, A[i_1, I])
Expand Down Expand Up @@ -1027,7 +1016,7 @@ end
# Copy the initial element in each 1d vector along dimension `dim`
ii = first(ind)
@inbounds for J in R2, I in R1
B[I, ii, J] = A[I, ii, J]
B[I, ii, J] = reduce_first(op, A[I, ii, J])
end
# Accumulate
@inbounds for J in R2, i in first(ind)+1:last(ind), I in R1
Expand Down Expand Up @@ -1058,7 +1047,7 @@ julia> accumulate(min, 0, [1,2,-1])
```
"""
function accumulate(op, v0, x::AbstractVector)
T = rcum_promote_type(op, typeof(v0), eltype(x))
T = promote_op(op, typeof(v0), eltype(x))
out = similar(x, T)
accumulate!(op, out, v0, x)
end
Expand All @@ -1075,7 +1064,7 @@ function _accumulate1!(op, B, v1, A::AbstractVector, dim::Integer)
inds == linearindices(B) || throw(DimensionMismatch("linearindices of A and B don't match"))
dim > 1 && return copyto!(B, A)
i1 = inds[1]
cur_val = v1
cur_val = reduce_first(op, v1)
B[i1] = cur_val
@inbounds for i in inds[2:end]
cur_val = op(cur_val, A[i])
Expand Down
7 changes: 7 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2152,6 +2152,13 @@ end
op(x,y) = 2x+y
@test accumulate(op, [10,20, 30]) == [10, op(10, 20), op(op(10, 20), 30)] == [10, 40, 110]
@test accumulate(op, [10 20 30], 2) == [10 op(10, 20) op(op(10, 20), 30)] == [10 40 110]

#25506
@test accumulate((acc, x) -> acc+x[1], 0, [(1,2), (3,4), (5,6)]) == [1, 4, 9]
@test accumulate(*, ['a', 'b']) == ["a", "ab"]
@inferred accumulate(*, String[])
@test accumulate(*, ['a' 'b'; 'c' 'd'], 1) == ["a" "b"; "ac" "bd"]
@test accumulate(*, ['a' 'b'; 'c' 'd'], 2) == ["a" "ab"; "c" "cd"]
end

struct F21666{T <: Base.ArithmeticStyle}
Expand Down