Skip to content

Commit

Permalink
Support 2-arg argmin/argmax/findmin/findmax
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Mar 29, 2021
1 parent 12eec81 commit 7c2f4f9
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Compat"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.25.0"
version = "3.26.0"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand Down
18 changes: 18 additions & 0 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,24 @@ if VERSION < v"1.2.0-DEV.246"
end
end

if VERSION < v"1.7.0-DEV.119"
# Part of https://github.com/JuliaLang/julia/pull/35316
isunordered(x) = false
isunordered(x::AbstractFloat) = isnan(x)
isunordered(x::Missing) = true

isgreater(x, y) = isunordered(x) || isunordered(y) ? isless(x, y) : isless(y, x)

Base.findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)
_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m)

Base.findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)
_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m)

Base.argmax(f, domain) = findmax(f, domain)[2]
Base.argmin(f, domain) = findmin(f, domain)[2]
end

include("iterators.jl")
include("deprecated.jl")

Expand Down
33 changes: 33 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -833,3 +833,36 @@ end
@test endswith("abc", r"C"i)
@test endswith("abc", r"Bc"i)
end

# https://github.com/JuliaLang/julia/pull/35316
@testset "2arg" begin
@testset "findmin(f, domain)" begin
@test findmin(-, 1:10) == (-10, 10)
@test findmin(identity, [1, 2, 3, missing]) === (missing, missing)
@test findmin(identity, [1, NaN, 3, missing]) === (missing, missing)
@test findmin(identity, [1, missing, NaN, 3]) === (missing, missing)
@test findmin(identity, [1, NaN, 3]) === (NaN, NaN)
@test findmin(identity, [1, 3, NaN]) === (NaN, NaN)
@test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π))
end

@testset "findmax(f, domain)" begin
@test findmax(-, 1:10) == (-1, 1)
@test findmax(identity, [1, 2, 3, missing]) === (missing, missing)
@test findmax(identity, [1, NaN, 3, missing]) === (missing, missing)
@test findmax(identity, [1, missing, NaN, 3]) === (missing, missing)
@test findmax(identity, [1, NaN, 3]) === (NaN, NaN)
@test findmax(identity, [1, 3, NaN]) === (NaN, NaN)
@test findmax(cos, 0:π/2:2π) == (1.0, 0.0)
end

@testset "argmin(f, domain)" begin
@test argmin(-, 1:10) == 10
@test argmin(sum, Iterators.product(1:5, 1:5)) == (1, 1)
end

@testset "argmax(f, domain)" begin
@test argmax(-, 1:10) == 1
@test argmax(sum, Iterators.product(1:5, 1:5)) == (5, 5)
end
end

0 comments on commit 7c2f4f9

Please sign in to comment.