Skip to content

Commit

Permalink
Merge pull request #27160 from ninjin/nin/argmaxmin
Browse files Browse the repository at this point in the history
[RFC] Add argmin and argmax over given dimensions
  • Loading branch information
StefanKarpinski authored Jun 1, 2018
2 parents 333981f + 607b1f3 commit 166cbed
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
50 changes: 50 additions & 0 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -806,3 +806,53 @@ function _findmax(A, region)
end

reducedim1(R, A) = _length(indices1(R)) == 1

"""
argmin(A; dims) -> indices
For an array input, return the indices of the minimum elements over the given dimensions.
`NaN` is treated as less than all other values.
# Examples
```jldoctest
julia> A = [1.0 2; 3 4]
2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia> argmin(A, dims=1)
1×2 Array{CartesianIndex{2},2}:
CartesianIndex(1, 1) CartesianIndex(1, 2)
julia> argmin(A, dims=2)
2×1 Array{CartesianIndex{2},2}:
CartesianIndex(1, 1)
CartesianIndex(2, 1)
```
"""
argmin(A::AbstractArray; dims=:) = findmin(A; dims=dims)[2]

"""
argmax(A; dims) -> indices
For an array input, return the indices of the maximum elements over the given dimensions.
`NaN` is treated as greater than all other values.
# Examples
```jldoctest
julia> A = [1.0 2; 3 4]
2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia> argmax(A, dims=1)
1×2 Array{CartesianIndex{2},2}:
CartesianIndex(2, 1) CartesianIndex(2, 2)
julia> argmax(A, dims=2)
2×1 Array{CartesianIndex{2},2}:
CartesianIndex(1, 2)
CartesianIndex(2, 2)
```
"""
argmax(A::AbstractArray; dims=:) = findmax(A; dims=dims)[2]
8 changes: 7 additions & 1 deletion test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ end
end

end
## findmin/findmax/minumum/maximum
## findmin/findmax/minimum/maximum

A = [1.0 5.0 6.0;
5.0 2.0 4.0]
Expand Down Expand Up @@ -352,3 +352,9 @@ end
T <: Base.SmallUnsigned ? UInt :
T)
end

@testset "argmin/argmax" begin
B = reshape(3^3:-1:1, (3, 3, 3))
@test B[argmax(B, dims=[2, 3])] == maximum(B, dims=[2, 3])
@test B[argmin(B, dims=[2, 3])] == minimum(B, dims=[2, 3])
end

0 comments on commit 166cbed

Please sign in to comment.