Skip to content

Commit

Permalink
use pairs in findmin and findmax, supporting all indexable coll…
Browse files Browse the repository at this point in the history
…ections

return `CartesianIndex` for n-d arrays in findmin, findmax, indmin, indmax

more compact printing of `CartesianIndex`

change sparse `_findr` macro to a function
  • Loading branch information
JeffBezanson committed Sep 1, 2017
1 parent 099de5b commit 5f7efae
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 119 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ This section lists changes that do not have deprecation warnings.
This avoids stack overflows in the common case of definitions like
`f(x, y) = f(promote(x, y)...)` ([#22801]).

* `findmin`, `findmax`, `indmin`, and `indmax` used to always return linear indices.
They now return `CartesianIndex`es for all but 1-d arrays, and in general return
the `keys` of indexed collections (e.g. dictionaries) ([#22907]).

Library improvements
--------------------

Expand Down
24 changes: 12 additions & 12 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2072,13 +2072,13 @@ function findmax(a)
if isempty(a)
throw(ArgumentError("collection must be non-empty"))
end
s = start(a)
mi = i = 1
m, s = next(a, s)
while !done(a, s)
p = pairs(a)
s = start(p)
(mi, m), s = next(p, s)
i = mi
while !done(p, s)
m != m && break
ai, s = next(a, s)
i += 1
(i, ai), s = next(p, s)
if ai != ai || isless(m, ai)
m = ai
mi = i
Expand Down Expand Up @@ -2113,13 +2113,13 @@ function findmin(a)
if isempty(a)
throw(ArgumentError("collection must be non-empty"))
end
s = start(a)
mi = i = 1
m, s = next(a, s)
while !done(a, s)
p = pairs(a)
s = start(p)
(mi, m), s = next(p, s)
i = mi
while !done(p, s)
m != m && break
ai, s = next(a, s)
i += 1
(i, ai), s = next(p, s)
if ai != ai || isless(ai, m)
m = ai
mi = i
Expand Down
3 changes: 2 additions & 1 deletion base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
module IteratorsMD
import Base: eltype, length, size, start, done, next, first, last, in, getindex,
setindex!, IndexStyle, min, max, zero, one, isless, eachindex,
ndims, iteratorsize, convert
ndims, iteratorsize, convert, show

import Base: +, -, *
import Base: simd_outer_range, simd_inner_length, simd_index
Expand Down Expand Up @@ -80,6 +80,7 @@ module IteratorsMD
@inline _flatten(i, I...) = (i, _flatten(I...)...)
@inline _flatten(i::CartesianIndex, I...) = (i.I..., _flatten(I...)...)
CartesianIndex(index::Tuple{Vararg{Union{Integer, CartesianIndex}}}) = CartesianIndex(index...)
show(io::IO, i::CartesianIndex) = (print(io, "CartesianIndex"); show(io, i.I))

# length
length(::CartesianIndex{N}) where {N} = N
Expand Down
24 changes: 13 additions & 11 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -635,20 +635,22 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
# Otherwise, keep the result in Rval/Rind so that we traverse A in storage order.
indsAt, indsRt = safe_tail(indices(A)), safe_tail(indices(Rval))
keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt)
k = 0
ks = keys(A)
k, kss = next(ks, start(ks))
zi = zero(eltype(ks))
if reducedim1(Rval, A)
i1 = first(indices1(Rval))
@inbounds for IA in CartesianRange(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
tmpRv = Rval[i1,IR]
tmpRi = Rind[i1,IR]
for i in indices(A,1)
k += 1
tmpAv = A[i,IA]
if tmpRi == 0 || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
if tmpRi == zi || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
tmpRv = tmpAv
tmpRi = k
end
k, kss = next(ks, kss)
end
Rval[i1,IR] = tmpRv
Rind[i1,IR] = tmpRi
Expand All @@ -657,14 +659,14 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
@inbounds for IA in CartesianRange(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
for i in indices(A, 1)
k += 1
tmpAv = A[i,IA]
tmpRv = Rval[i,IR]
tmpRi = Rind[i,IR]
if tmpRi == 0 || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
if tmpRi == zi || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
Rval[i,IR] = tmpAv
Rind[i,IR] = k
end
k, kss = next(ks, kss)
end
end
end
Expand All @@ -680,7 +682,7 @@ dimensions of `rval` and `rind`, and store the results in `rval` and `rind`.
"""
function findmin!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray;
init::Bool=true)
findminmax!(isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,0), A)
findminmax!(isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
end

"""
Expand Down Expand Up @@ -709,10 +711,10 @@ function findmin(A::AbstractArray{T}, region) where T
if prod(map(length, reduced_indices(A, region))) != 0
throw(ArgumentError("collection slices must be non-empty"))
end
(similar(A, ri), similar(dims->zeros(Int, dims), ri))
(similar(A, ri), similar(dims->zeros(eltype(keys(A)), dims), ri))
else
findminmax!(isless, fill!(similar(A, ri), first(A)),
similar(dims->zeros(Int, dims), ri), A)
similar(dims->zeros(eltype(keys(A)), dims), ri), A)
end
end

Expand All @@ -727,7 +729,7 @@ dimensions of `rval` and `rind`, and store the results in `rval` and `rind`.
"""
function findmax!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray;
init::Bool=true)
findminmax!(isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,0), A)
findminmax!(isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
end

"""
Expand Down Expand Up @@ -756,10 +758,10 @@ function findmax(A::AbstractArray{T}, region) where T
if prod(map(length, reduced_indices(A, region))) != 0
throw(ArgumentError("collection slices must be non-empty"))
end
similar(A, ri), similar(dims->zeros(Int, dims), ri)
similar(A, ri), similar(dims->zeros(eltype(keys(A)), dims), ri)
else
findminmax!(isgreater, fill!(similar(A, ri), first(A)),
similar(dims->zeros(Int, dims), ri), A)
similar(dims->zeros(eltype(keys(A)), dims), ri), A)
end
end

Expand Down
79 changes: 39 additions & 40 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1815,101 +1815,100 @@ function _findz(A::SparseMatrixCSC{Tv,Ti}, rows=1:A.m, cols=1:A.n) where {Tv,Ti}
row = 0
rowmin = rows[1]; rowmax = rows[end]
allrows = (rows == 1:A.m)
@inbounds for col in cols
@inbounds for col in cols
r1::Int = colptr[col]
r2::Int = colptr[col+1] - 1
if !allrows && (r1 <= r2)
r1 = searchsortedfirst(rowval, rowmin, r1, r2, Forward)
(r1 <= r2 ) && (r2 = searchsortedlast(rowval, rowmax, r1, r2, Forward))
end
row = rowmin

while (r1 <= r2) && (row == rowval[r1]) && (nzval[r1] != zval)
r1 += 1
row += 1
end
(row <= rowmax) && (return sub2ind(size(A), row, col))
(row <= rowmax) && (return CartesianIndex(row, col))
end
return 0
return CartesianIndex(0, 0)
end

macro _findr(op, A, region, Tv, Ti)
esc(quote
N = nnz($A)
L = length($A)
function _findr(op, A, region, Tv)
Ti = eltype(keys(A))
i1 = first(keys(A))
N = nnz(A)
L = length(A)
if L == 0
if prod(map(length, Base.reduced_indices($A, $region))) != 0
if prod(map(length, Base.reduced_indices(A, region))) != 0
throw(ArgumentError("array slices must be non-empty"))
else
ri = Base.reduced_indices0($A, $region)
return (similar($A, ri), similar(dims->zeros(Int, dims), ri))
ri = Base.reduced_indices0(A, region)
return (similar(A, ri), similar(dims->zeros(Ti, dims), ri))
end
end

colptr = $A.colptr; rowval = $A.rowval; nzval = $A.nzval; m = $A.m; n = $A.n
zval = zero($Tv)
szA = size($A)
colptr = A.colptr; rowval = A.rowval; nzval = A.nzval; m = A.m; n = A.n
zval = zero(Tv)
szA = size(A)

if $region == 1 || $region == (1,)
(N == 0) && (return (fill(zval,1,n), fill(convert($Ti,1),1,n)))
S = Vector{$Tv}(n); I = Vector{$Ti}(n)
if region == 1 || region == (1,)
(N == 0) && (return (fill(zval,1,n), fill(i1,1,n)))
S = Vector{Tv}(n); I = Vector{Ti}(n)
@inbounds for i = 1 : n
Sc = zval; Ic = _findz($A, 1:m, i:i)
if Ic == 0
Sc = zval; Ic = _findz(A, 1:m, i:i)
if Ic == CartesianIndex(0, 0)
j = colptr[i]
Ic = sub2ind(szA, rowval[j], i)
Ic = CartesianIndex(rowval[j], i)
Sc = nzval[j]
end
for j = colptr[i] : colptr[i+1]-1
if ($op)(nzval[j], Sc)
if op(nzval[j], Sc)
Sc = nzval[j]
Ic = sub2ind(szA, rowval[j], i)
Ic = CartesianIndex(rowval[j], i)
end
end
S[i] = Sc; I[i] = Ic
end
return(reshape(S,1,n), reshape(I,1,n))
elseif $region == 2 || $region == (2,)
(N == 0) && (return (fill(zval,m,1), fill(convert($Ti,1),m,1)))
S = Vector{$Tv}(m); I = Vector{$Ti}(m)
elseif region == 2 || region == (2,)
(N == 0) && (return (fill(zval,m,1), fill(i1,m,1)))
S = Vector{Tv}(m); I = Vector{Ti}(m)
@inbounds for row in 1:m
S[row] = zval; I[row] = _findz($A, row:row, 1:n)
if I[row] == 0
I[row] = sub2ind(szA, row, 1)
S[row] = zval; I[row] = _findz(A, row:row, 1:n)
if I[row] == CartesianIndex(0, 0)
I[row] = CartesianIndex(row, 1)
S[row] = A[row,1]
end
end
@inbounds for i = 1 : n, j = colptr[i] : colptr[i+1]-1
row = rowval[j]
if ($op)(nzval[j], S[row])
if op(nzval[j], S[row])
S[row] = nzval[j]
I[row] = sub2ind(szA, row, i)
I[row] = CartesianIndex(row, i)
end
end
return (reshape(S,m,1), reshape(I,m,1))
elseif $region == (1,2)
(N == 0) && (return (fill(zval,1,1), fill(convert($Ti,1),1,1)))
hasz = nnz($A) != length($A)
elseif region == (1,2)
(N == 0) && (return (fill(zval,1,1), fill(i1,1,1)))
hasz = nnz(A) != length(A)
Sv = hasz ? zval : nzval[1]
Iv::($Ti) = hasz ? _findz($A) : 1
@inbounds for i = 1 : $A.n, j = colptr[i] : (colptr[i+1]-1)
if ($op)(nzval[j], Sv)
Iv::(Ti) = hasz ? _findz(A) : i1
@inbounds for i = 1 : A.n, j = colptr[i] : (colptr[i+1]-1)
if op(nzval[j], Sv)
Sv = nzval[j]
Iv = sub2ind(szA, rowval[j], i)
Iv = CartesianIndex(rowval[j], i)
end
end
return (fill(Sv,1,1), fill(Iv,1,1))
else
throw(ArgumentError("invalid value for region; must be 1, 2, or (1,2)"))
end
end) #quote
end

_isless_fm(a, b) = b == b && ( a != a || isless(a, b) )
_isgreater_fm(a, b) = b == b && ( a != a || isless(b, a) )

findmin(A::SparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = @_findr(_isless_fm, A, region, Tv, Ti)
findmax(A::SparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = @_findr(_isgreater_fm, A, region, Tv, Ti)
findmin(A::SparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = _findr(_isless_fm, A, region, Tv)
findmax(A::SparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = _findr(_isgreater_fm, A, region, Tv)
findmin(A::SparseMatrixCSC) = (r=findmin(A,(1,2)); (r[1][1], r[2][1]))
findmax(A::SparseMatrixCSC) = (r=findmax(A,(1,2)); (r[1][1], r[2][1]))

Expand Down
7 changes: 6 additions & 1 deletion test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ end
@test indmin(5:-2:1) == 3

#23094
@test findmax(Set(["abc"])) === ("abc", 1)
@test_throws MethodError findmax(Set(["abc"]))
@test findmin(["abc", "a"]) === ("a", 2)
@test_throws MethodError findmax([Set([1]), Set([2])])
@test findmin([0.0, -0.0]) === (-0.0, 2)
Expand Down Expand Up @@ -1814,6 +1814,11 @@ s, si = findmax(S)
@test a == b == s
@test ai == bi == si

for X in (A, B, S)
@test findmin(X) == findmin(Dict(pairs(X)))
@test findmax(X) == findmax(Dict(pairs(X)))
end

fill!(B, 2)
@test all(x->x==2, B)

Expand Down
Loading

0 comments on commit 5f7efae

Please sign in to comment.