Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Fix{N} for fixing a single positional argument at any position #829

Merged
merged 15 commits into from
Aug 7, 2024
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Compat"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.15.0"
version = "4.16.0"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ changes in `julia`.

## Supported features

* `Fix{N}` which fixes an argument at the `N`th position ([#54653]) (since Compat 4.16.0)
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved

* `chopprefix(s, prefix)` and `chopsuffix(s, suffix)` ([#40995]) (since Compat 4.15.0)

* `logrange(lo, hi; length)` is like `range` but with a constant ratio, not difference. ([#39071]) (since Compat 4.14.0) Note that on Julia 1.8 and earlier, the version from Compat has slightly lower floating-point accuracy than the one in Base (Julia 1.11 and later).
Expand Down Expand Up @@ -192,3 +194,4 @@ Note that you should specify the correct minimum version for `Compat` in the
[#47679]: https://github.com/JuliaLang/julia/pull/47679
[#48038]: https://github.com/JuliaLang/julia/issues/48038
[#50105]: https://github.com/JuliaLang/julia/issues/50105
[#54653]: https://github.com/JuliaLang/julia/issues/54653
62 changes: 62 additions & 0 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,68 @@ if VERSION < v"1.8.0-DEV.1016"
export chopprefix, chopsuffix
end

# https://github.com/JuliaLang/julia/pull/54653: add Fix
@static if !isdefined(Base, :Fix)
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
@static if !isdefined(Base, :_stable_typeof)
_stable_typeof(x) = typeof(x)
_stable_typeof(::Type{T}) where {T} = Type{T}
else
using Base: _stable_typeof
end

"""
Fix{N}(f, x)

A type representing a partially-applied version of a function `f`, with the argument
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.

!!! compat "Julia 1.12"
This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2`
are available earlier.

!!! note
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
available arguments, rather than an absolute ordering on the target function. For example,
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
fixes the first and third arg.
"""
struct Fix{N,F,T} <: Function
f::F
x::T

function Fix{N}(f::F, x) where {N,F}
if !(N isa Int)
throw(ArgumentError("expected type parameter in `Fix` to be `Int`, but got `$N::$(typeof(N))`"))
elseif N < 1
throw(ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got $N"))
end
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
end
end

function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
M < N-1 && throw(ArgumentError("expected at least $(N-1) arguments to `Fix{$N}`, but got $M"))
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
end

# Special cases for improved constant propagation
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)

"""
Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix).
"""
const Fix1{F,T} = Fix{1,F,T}

"""
Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).
"""
const Fix2{F,T} = Fix{2,F,T}
else
using Base: Fix, Fix1, Fix2
end

include("deprecated.jl")

end # module Compat
131 changes: 131 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -907,3 +907,134 @@ end
@test isa(chopsuffix(S("foo"), "oo"), SubString)
end
end

# https://github.com/JuliaLang/julia/pull/54653: add Fix
@testset "Fix" begin
function test_fix1(Fix1=Compat.Fix1)
increment = Fix1(+, 1)
@test increment(5) == 6
@test increment(-1) == 0
@test increment(0) == 1
@test map(increment, [1, 2, 3]) == [2, 3, 4]

concat_with_hello = Fix1(*, "Hello ")
@test concat_with_hello("World!") == "Hello World!"
# Make sure inference is good:
@inferred concat_with_hello("World!")

one_divided_by = Fix1(/, 1)
@test one_divided_by(10) == 1/10.0
@test one_divided_by(-5) == 1/-5.0

return nothing
end

function test_fix2(Fix2=Compat.Fix2)
return_second = Fix2((x, y) -> y, 999)
@test return_second(10) == 999
@inferred return_second(10)
@test return_second(-5) == 999

divide_by_two = Fix2(/, 2)
@test map(divide_by_two, (2, 4, 6)) == (1.0, 2.0, 3.0)
@inferred map(divide_by_two, (2, 4, 6))

concat_with_world = Fix2(*, " World!")
@test concat_with_world("Hello") == "Hello World!"
@inferred concat_with_world("Hello World!")

return nothing
end

# Test with normal Base.Fix1 and Base.Fix2
test_fix1()
test_fix2()

# Now, repeat the Fix1 and Fix2 tests, but
# with a Fix lambda function used in their place
test_fix1((op, arg) -> Compat.Fix{1}(op, arg))
test_fix2((op, arg) -> Compat.Fix{2}(op, arg))

# Now, we do more complex tests of Fix:
let Fix=Compat.Fix
@testset "Argument Fixation" begin
let f = (x, y, z) -> x + y * z
fixed_f1 = Fix{1}(f, 10)
@test fixed_f1(2, 3) == 10 + 2 * 3

fixed_f2 = Fix{2}(f, 5)
@test fixed_f2(1, 4) == 1 + 5 * 4

fixed_f3 = Fix{3}(f, 3)
@test fixed_f3(1, 2) == 1 + 2 * 3
end
end
@testset "Helpful errors" begin
let g = (x, y) -> x - y
# Test minimum N
fixed_g1 = Fix{1}(g, 100)
@test fixed_g1(40) == 100 - 40

# Test maximum N
fixed_g2 = Fix{2}(g, 100)
@test fixed_g2(150) == 150 - 100

# One over
fixed_g3 = Fix{3}(g, 100)
@test_throws ArgumentError("expected at least 2 arguments to `Fix{3}`, but got 1") fixed_g3(1)
end
end
@testset "Type Stability and Inference" begin
let h = (x, y) -> x / y
fixed_h = Fix{2}(h, 2.0)
@test @inferred(fixed_h(4.0)) == 2.0
end
end
@testset "Interaction with varargs" begin
vararg_f = (x, y, z...) -> x + 10 * y + sum(z; init=zero(x))
fixed_vararg_f = Fix{2}(vararg_f, 6)

# Can call with variable number of arguments:
@test fixed_vararg_f(1, 2, 3, 4) == 1 + 10 * 6 + sum((2, 3, 4))
if VERSION >= v"1.7.0"
@inferred fixed_vararg_f(1, 2, 3, 4)
end
@test fixed_vararg_f(5) == 5 + 10 * 6
if VERSION >= v"1.7.0"
@inferred fixed_vararg_f(5)
end
end
@testset "Errors should propagate normally" begin
error_f = (x, y) -> sin(x * y)
fixed_error_f = Fix{2}(error_f, Inf)
@test_throws DomainError fixed_error_f(10)
end
@testset "Chaining Fix together" begin
f1 = Fix{1}(*, "1")
f2 = Fix{1}(f1, "2")
f3 = Fix{1}(f2, "3")
@test f3() == "123"

g1 = Fix{2}(*, "1")
g2 = Fix{2}(g1, "2")
g3 = Fix{2}(g2, "3")
@test g3("") == "123"
end
@testset "Zero arguments" begin
f = Fix{1}(x -> x, 'a')
@test f() == 'a'
end
@testset "Dummy-proofing" begin
@test_throws ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got 0") Fix{0}(>, 1)
@test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `0.5::Float64`") Fix{0.5}(>, 1)
@test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `1::UInt64`") Fix{UInt64(1)}(>, 1)
end
@testset "Specialize to structs not in `Base`" begin
struct MyStruct
x::Int
end
f = Fix{1}(MyStruct, 1)
@test f isa Fix{1,Type{MyStruct},Int}
end
end
end
Loading