From 3d99c24461d0cf9aac575bd6e6e90ddd32d8099b Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Sat, 3 Aug 2024 13:29:51 +0100 Subject: [PATCH] Create `Base.Fix` as general `Fix1`/`Fix2` for partially-applied functions (#54653) This PR generalises `Base.Fix1` and `Base.Fix2` to `Base.Fix{N}`, to allow fixing a single positional argument of a function. With this change, the implementation of these is simply ```julia const Fix1{F,T} = Fix{1,F,T} const Fix2{F,T} = Fix{2,F,T} ``` Along with the PR I also add a larger suite of unittests for all three of these functions to complement the existing tests for `Fix1`/`Fix2`. ### Context There are multiple motivations for this generalization. **By creating a more general `Fix{N}` type, there is no preferential treatment of certain types of functions:** - (i) No limitation that you can only fix positions 1-2. You can now fix any position `n`. - (ii) No asymmetry between 2-argument and n-argument functions. You can now fix an argument for functions with any number of arguments. Think of this like if `Base` only had `Vector{T}` and `Matrix{T}`, and you wished to generalise it to `Array{T,N}`. It is an analogous situation here: `Fix1` and `Fix2` are now *aliases* of `Fix{N}`. - **Convenience**: - `Base.Fix1` and `Base.Fix2` are useful shorthands for creating simple anonymous functions without compiling new functions. - They are common throughout the Julia ecosystem as a shorthand for filling arguments: - `Fix1` https://github.com/search?q=Base.Fix1+language%3Ajulia&type=code - `Fix2` https://github.com/search?q=Base.Fix2+language%3Ajulia&type=code - **Less Compilation**: - Using `Fix*` reduces the need for compilation of repeatedly-used anonymous functions (which can often trigger compilation of new functions). - **Type Stability**: - `Fix`, like `Fix1` and `Fix2`, captures variables in a struct, encouraging users to use a functional paradigm for closures, preventing any potential type instabilities from boxed variables within an anonymous function. - **Easier Functional Programming**: - Allows for a stronger functional programming paradigm by supporting partial functions with _any number of arguments_. Note that this refactors `Fix1` and `Fix2` to be equal to `Fix{1}` and `Fix{2}` respectively, rather than separate structs. This is backwards compatible. Also note that this does not constrain future generalisations of `Fix{n}` for multiple arguments. `Fix{1,F,T}` is the clear generalisation of `Fix1{F,T}`, so this isn't major new syntax choices. But in a future PR you could have, e.g., `Fix{(n1,n2)}` for multiple arguments, and it would still be backwards-compatible with this. --------- Co-authored-by: Dilum Aluthge Co-authored-by: Lilith Orion Hafner Co-authored-by: Alexander Plavin Co-authored-by: Neven Sajko --- NEWS.md | 1 + base/operators.jl | 57 +++++++++++------- base/public.jl | 1 + doc/src/base/base.md | 1 + stdlib/REPL/test/repl.jl | 4 +- test/functional.jl | 126 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 167 insertions(+), 23 deletions(-) diff --git a/NEWS.md b/NEWS.md index 6f12b8ff389cd..2e9b32befe342 100644 --- a/NEWS.md +++ b/NEWS.md @@ -73,6 +73,7 @@ New library functions * `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]). * `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]). * `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims` +* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]). New library features -------------------- diff --git a/base/operators.jl b/base/operators.jl index b7605ba004b7a..2c8070b44d704 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -1154,40 +1154,55 @@ julia> filter(!isletter, str) !(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f """ - Fix1(f, x) + Fix{N}(f, x) -A type representing a partially-applied version of the two-argument function -`f`, with the first argument fixed to the value "x". In other words, -`Fix1(f, x)` behaves similarly to `y->f(x, y)`. +A type representing a partially-applied version of a function `f`, with the argument +`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to +`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`. -See also [`Fix2`](@ref Base.Fix2). +!!! compat "Julia 1.12" + This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2` + are available earlier. + +!!! note + When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current + available arguments, rather than an absolute ordering on the target function. For example, + `Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)` + fixes the first and third arg. """ -struct Fix1{F,T} <: Function +struct Fix{N,F,T} <: Function f::F x::T - Fix1(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x) - Fix1(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x) + function Fix{N}(f::F, x) where {N,F} + if !(N isa Int) + throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`"))) + elseif N < 1 + throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N))) + end + new{N,_stable_typeof(f),_stable_typeof(x)}(f, x) + end end -(f::Fix1)(y) = f.f(f.x, y) +function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M} + M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M))) + return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...) +end -""" - Fix2(f, x) +# Special cases for improved constant propagation +(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...) +(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...) -A type representing a partially-applied version of the two-argument function -`f`, with the second argument fixed to the value "x". In other words, -`Fix2(f, x)` behaves similarly to `y->f(y, x)`. """ -struct Fix2{F,T} <: Function - f::F - x::T +Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix). +""" +const Fix1{F,T} = Fix{1,F,T} - Fix2(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x) - Fix2(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x) -end +""" +Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix). +""" +const Fix2{F,T} = Fix{2,F,T} -(f::Fix2)(y) = f.f(y, f.x) """ isequal(x) diff --git a/base/public.jl b/base/public.jl index c11c76c13053c..862aff48da63e 100644 --- a/base/public.jl +++ b/base/public.jl @@ -14,6 +14,7 @@ public AsyncCondition, CodeUnits, Event, + Fix, Fix1, Fix2, Generator, diff --git a/doc/src/base/base.md b/doc/src/base/base.md index 946f917682814..1a8cd29f91066 100644 --- a/doc/src/base/base.md +++ b/doc/src/base/base.md @@ -281,6 +281,7 @@ Base.:(|>) Base.:(∘) Base.ComposedFunction Base.splat +Base.Fix Base.Fix1 Base.Fix2 ``` diff --git a/stdlib/REPL/test/repl.jl b/stdlib/REPL/test/repl.jl index 05db88fa0d8ac..6f0c4a5c3d6ba 100644 --- a/stdlib/REPL/test/repl.jl +++ b/stdlib/REPL/test/repl.jl @@ -1216,9 +1216,9 @@ global some_undef_global @test occursin("does not exist", sprint(show, help_result(".."))) # test that helpmode is sensitive to contextual module @test occursin("No documentation found", sprint(show, help_result("Fix2", Main))) -@test occursin("A type representing a partially-applied version", # exact string may change +@test occursin("Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).", # exact string may change sprint(show, help_result("Base.Fix2", Main))) -@test occursin("A type representing a partially-applied version", # exact string may change +@test occursin("Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).", # exact string may change sprint(show, help_result("Fix2", Base))) diff --git a/test/functional.jl b/test/functional.jl index 3436fb8911cc1..84c4098308ebd 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -235,3 +235,129 @@ end let (:)(a,b) = (i for i in Base.:(:)(1,10) if i%2==0) @test Int8[ i for i = 1:2 ] == [2,4,6,8,10] end + +@testset "Basic tests of Fix1, Fix2, and Fix" begin + function test_fix1(Fix1=Base.Fix1) + increment = Fix1(+, 1) + @test increment(5) == 6 + @test increment(-1) == 0 + @test increment(0) == 1 + @test map(increment, [1, 2, 3]) == [2, 3, 4] + + concat_with_hello = Fix1(*, "Hello ") + @test concat_with_hello("World!") == "Hello World!" + # Make sure inference is good: + @inferred concat_with_hello("World!") + + one_divided_by = Fix1(/, 1) + @test one_divided_by(10) == 1/10.0 + @test one_divided_by(-5) == 1/-5.0 + + return nothing + end + + function test_fix2(Fix2=Base.Fix2) + return_second = Fix2((x, y) -> y, 999) + @test return_second(10) == 999 + @inferred return_second(10) + @test return_second(-5) == 999 + + divide_by_two = Fix2(/, 2) + @test map(divide_by_two, (2, 4, 6)) == (1.0, 2.0, 3.0) + @inferred map(divide_by_two, (2, 4, 6)) + + concat_with_world = Fix2(*, " World!") + @test concat_with_world("Hello") == "Hello World!" + @inferred concat_with_world("Hello World!") + + return nothing + end + + # Test with normal Base.Fix1 and Base.Fix2 + test_fix1() + test_fix2() + + # Now, repeat the Fix1 and Fix2 tests, but + # with a Fix lambda function used in their place + test_fix1((op, arg) -> Base.Fix{1}(op, arg)) + test_fix2((op, arg) -> Base.Fix{2}(op, arg)) + + # Now, we do more complex tests of Fix: + let Fix=Base.Fix + @testset "Argument Fixation" begin + let f = (x, y, z) -> x + y * z + fixed_f1 = Fix{1}(f, 10) + @test fixed_f1(2, 3) == 10 + 2 * 3 + + fixed_f2 = Fix{2}(f, 5) + @test fixed_f2(1, 4) == 1 + 5 * 4 + + fixed_f3 = Fix{3}(f, 3) + @test fixed_f3(1, 2) == 1 + 2 * 3 + end + end + @testset "Helpful errors" begin + let g = (x, y) -> x - y + # Test minimum N + fixed_g1 = Fix{1}(g, 100) + @test fixed_g1(40) == 100 - 40 + + # Test maximum N + fixed_g2 = Fix{2}(g, 100) + @test fixed_g2(150) == 150 - 100 + + # One over + fixed_g3 = Fix{3}(g, 100) + @test_throws ArgumentError("expected at least 2 arguments to `Fix{3}`, but got 1") fixed_g3(1) + end + end + @testset "Type Stability and Inference" begin + let h = (x, y) -> x / y + fixed_h = Fix{2}(h, 2.0) + @test @inferred(fixed_h(4.0)) == 2.0 + end + end + @testset "Interaction with varargs" begin + vararg_f = (x, y, z...) -> x + 10 * y + sum(z; init=zero(x)) + fixed_vararg_f = Fix{2}(vararg_f, 6) + + # Can call with variable number of arguments: + @test fixed_vararg_f(1, 2, 3, 4) == 1 + 10 * 6 + sum((2, 3, 4)) + @inferred fixed_vararg_f(1, 2, 3, 4) + @test fixed_vararg_f(5) == 5 + 10 * 6 + @inferred fixed_vararg_f(5) + end + @testset "Errors should propagate normally" begin + error_f = (x, y) -> sin(x * y) + fixed_error_f = Fix{2}(error_f, Inf) + @test_throws DomainError fixed_error_f(10) + end + @testset "Chaining Fix together" begin + f1 = Fix{1}(*, "1") + f2 = Fix{1}(f1, "2") + f3 = Fix{1}(f2, "3") + @test f3() == "123" + + g1 = Fix{2}(*, "1") + g2 = Fix{2}(g1, "2") + g3 = Fix{2}(g2, "3") + @test g3("") == "123" + end + @testset "Zero arguments" begin + f = Fix{1}(x -> x, 'a') + @test f() == 'a' + end + @testset "Dummy-proofing" begin + @test_throws ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got 0") Fix{0}(>, 1) + @test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `0.5::Float64`") Fix{0.5}(>, 1) + @test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `1::UInt64`") Fix{UInt64(1)}(>, 1) + end + @testset "Specialize to structs not in `Base`" begin + struct MyStruct + x::Int + end + f = Fix{1}(MyStruct, 1) + @test f isa Fix{1,Type{MyStruct},Int} + end + end +end