From ae895868926b79deab754a1cd8043fdc43d2f11f Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Thu, 3 Jun 2021 13:59:13 -0400 Subject: [PATCH] fix #39203, 2-arg `findmax` should return index instead of value --- base/reduce.jl | 48 ++++++++++++++++++++++++------------------------ test/reduce.jl | 24 ++++++++++++------------ 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/base/reduce.jl b/base/reduce.jl index 6490358214c7e..d8de9f8dae0b3 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -771,11 +771,11 @@ minimum(a; kw...) = mapreduce(identity, min, a; kw...) ## findmax, findmin, argmax & argmin """ - findmax(f, domain) -> (f(x), x) + findmax(f, domain) -> (f(x), index) -Returns a pair of a value in the codomain (outputs of `f`) and the corresponding -value in the `domain` (inputs to `f`) such that `f(x)` is maximised. If there -are multiple maximal points, then the first one will be returned. +Returns a pair of a value in the codomain (outputs of `f`) and the index of +the corresponding value in the `domain` (inputs to `f`) such that `f(x)` is maximised. +If there are multiple maximal points, then the first one will be returned. `domain` must be a non-empty iterable. @@ -788,20 +788,20 @@ Values are compared with `isless`. ```jldoctest julia> findmax(identity, 5:9) -(9, 9) +(9, 5) julia> findmax(-, 1:10) (-1, 1) -julia> findmax(first, [(1, :a), (2, :b), (2, :c)]) -(2, (2, :b)) +julia> findmax(first, [(1, :a), (3, :b), (3, :c)]) +(3, 2) julia> findmax(cos, 0:π/2:2π) -(1.0, 0.0) +(1.0, 1) ``` """ -findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain) -_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m) +findmax(f, domain) = mapfoldl( ((k, v),) -> (f(v), k), _rf_findmax, pairs(domain) ) +_rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im) """ findmax(itr) -> (x, index) @@ -826,14 +826,14 @@ julia> findmax([1, 7, 7, NaN]) ``` """ findmax(itr) = _findmax(itr, :) -_findmax(a, ::Colon) = mapfoldl( ((k, v),) -> (v, k), _rf_findmax, pairs(a) ) +_findmax(a, ::Colon) = findmax(identity, a) """ - findmin(f, domain) -> (f(x), x) + findmin(f, domain) -> (f(x), index) -Returns a pair of a value in the codomain (outputs of `f`) and the corresponding -value in the `domain` (inputs to `f`) such that `f(x)` is minimised. If there -are multiple minimal points, then the first one will be returned. +Returns a pair of a value in the codomain (outputs of `f`) and the index of +the corresponding value in the `domain` (inputs to `f`) such that `f(x)` is minimised. +If there are multiple minimal points, then the first one will be returned. `domain` must be a non-empty iterable. @@ -846,21 +846,21 @@ are multiple minimal points, then the first one will be returned. ```jldoctest julia> findmin(identity, 5:9) -(5, 5) +(5, 1) julia> findmin(-, 1:10) (-10, 10) -julia> findmin(first, [(1, :a), (1, :b), (2, :c)]) -(1, (1, :a)) +julia> findmin(first, [(2, :a), (2, :b), (3, :c)]) +(2, 1) julia> findmin(cos, 0:π/2:2π) -(-1.0, 3.141592653589793) +(-1.0, 3) ``` """ -findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain) -_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m) +findmin(f, domain) = mapfoldl( ((k, v),) -> (f(v), k), _rf_findmin, pairs(domain) ) +_rf_findmin((fm, im), (fx, ix)) = isgreater(fm, fx) ? (fx, ix) : (fm, im) """ findmin(itr) -> (x, index) @@ -885,7 +885,7 @@ julia> findmin([1, 7, 7, NaN]) ``` """ findmin(itr) = _findmin(itr, :) -_findmin(a, ::Colon) = mapfoldl( ((k, v),) -> (v, k), _rf_findmin, pairs(a) ) +_findmin(a, ::Colon) = findmin(identity, a) """ argmax(f, domain) @@ -909,7 +909,7 @@ julia> argmax(cos, 0:π/2:2π) 0.0 ``` """ -argmax(f, domain) = findmax(f, domain)[2] +argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2] """ argmax(itr) @@ -962,7 +962,7 @@ julia> argmin(acos, 0:0.1:1) 1.0 ``` """ -argmin(f, domain) = findmin(f, domain)[2] +argmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)[2] """ argmin(itr) diff --git a/test/reduce.jl b/test/reduce.jl index 5d49c47204484..1e136af11b68a 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -391,22 +391,22 @@ end @testset "findmin(f, domain)" begin @test findmin(-, 1:10) == (-10, 10) - @test findmin(identity, [1, 2, 3, missing]) === (missing, missing) - @test findmin(identity, [1, NaN, 3, missing]) === (missing, missing) - @test findmin(identity, [1, missing, NaN, 3]) === (missing, missing) - @test findmin(identity, [1, NaN, 3]) === (NaN, NaN) - @test findmin(identity, [1, 3, NaN]) === (NaN, NaN) - @test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π)) + @test findmin(identity, [1, 2, 3, missing]) === (missing, 4) + @test findmin(identity, [1, NaN, 3, missing]) === (missing, 4) + @test findmin(identity, [1, missing, NaN, 3]) === (missing, 2) + @test findmin(identity, [1, NaN, 3]) === (NaN, 2) + @test findmin(identity, [1, 3, NaN]) === (NaN, 3) + @test findmin(cos, 0:π/2:2π) == (-1.0, 3) end @testset "findmax(f, domain)" begin @test findmax(-, 1:10) == (-1, 1) - @test findmax(identity, [1, 2, 3, missing]) === (missing, missing) - @test findmax(identity, [1, NaN, 3, missing]) === (missing, missing) - @test findmax(identity, [1, missing, NaN, 3]) === (missing, missing) - @test findmax(identity, [1, NaN, 3]) === (NaN, NaN) - @test findmax(identity, [1, 3, NaN]) === (NaN, NaN) - @test findmax(cos, 0:π/2:2π) == (1.0, 0.0) + @test findmax(identity, [1, 2, 3, missing]) === (missing, 4) + @test findmax(identity, [1, NaN, 3, missing]) === (missing, 4) + @test findmax(identity, [1, missing, NaN, 3]) === (missing, 2) + @test findmax(identity, [1, NaN, 3]) === (NaN, 2) + @test findmax(identity, [1, 3, NaN]) === (NaN, 3) + @test findmax(cos, 0:π/2:2π) == (1.0, 1) end @testset "argmin(f, domain)" begin