Skip to content

Commit

Permalink
Merge pull request #37 from mcabbott/biwalk
Browse files Browse the repository at this point in the history
Make `fmap(f, x, y)` useful
  • Loading branch information
mcabbott authored Feb 9, 2022
2 parents 4a834ce + 077d6e5 commit 0cf7942
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 82 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ jobs:
fail-fast: false
matrix:
version:
- '1.5' # Replace this with the minimum Julia version that your package supports.
# - '1' # automatically expands to the latest stable 1.x release of Julia
- '1.0'
- '1.6' # Replace this with the minimum Julia version that your package supports.
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'nightly'
os:
- ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
name = "Functors"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.2.7"
version = "0.2.8"

[compat]
julia = "1"
Documenter = "0.27"
julia = "1"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Expand Down
12 changes: 12 additions & 0 deletions src/Functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ usually using the macro [@functor](@ref).
"""
functor

@static if VERSION >= v"1.5" # var"@functor" doesn't work on 1.0, temporarily disable
"""
@functor T
@functor T (x,)
Expand Down Expand Up @@ -65,6 +66,7 @@ TwoThirds(Foo(10, 20), Foo(3, 4), 560)
```
"""
var"@functor"
end # VERSION

"""
Functors.isleaf(x)
Expand Down Expand Up @@ -182,6 +184,16 @@ This function walks (maps) over `xs` calling the continuation `f'` to continue t
julia> fmap(x -> 10x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x))
Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7))))
```
The behaviour when the same node appears twice can be altered by giving a value
to the `prune` keyword, which is then used in place of all but the first:
```jldoctest
julia> twice = [1, 2];
julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing)
(x = [1.0, 2.0], y = [1.0, 2.0], z = missing)
```
"""
fmap

Expand Down
36 changes: 11 additions & 25 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ functor(T, x) = (), _ -> x
functor(x) = functor(typeof(x), x)

functor(::Type{<:Tuple}, x) = x, y -> y
functor(::Type{<:NamedTuple}, x) = x, y -> y
functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity

functor(::Type{<:AbstractArray}, x) = x, y -> y
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x
Expand Down Expand Up @@ -43,12 +43,11 @@ function _default_walk(f, x)
re(map(f, func))
end

function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
haskey(cache, x) && return cache[x]
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x)
cache[x] = y
struct NoKeyword end

return y
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
cache[x] = exclude(x) ? f(x) : walk(x -> fmap(f, x; exclude=exclude, walk=walk, cache=cache, prune=prune), x)
end

###
Expand All @@ -74,27 +73,16 @@ end
### Vararg forms
###

function fmap(f, x, dx...; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...)
function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
cache[x] = exclude(x) ? f(x, ys...) : walk((xy...,) -> fmap(f, xy...; exclude=exclude, walk=walk, cache=cache, prune=prune), x, ys...)
end

function functor_tuple(f, x::Tuple, dx::Tuple)
map(x, dx) do x, x̄
_default_walk(f, x, x̄)
end
end
functor_tuple(f, x, dx) = f(x, dx)
functor_tuple(f, x, ::Nothing) = x

function _default_walk(f, x, dx)
function _default_walk(f, x, ys...)
func, re = functor(x)
map(func, dx) do x, x̄
# functor_tuple(f, x, x̄)
f(x, x̄)
end |> re
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
re(map(f, func, yfuncs...))
end
_default_walk(f, ::Nothing, ::Nothing) = nothing

###
### FlexibleFunctors.jl
Expand All @@ -112,9 +100,7 @@ function makeflexiblefunctor(m::Module, T, pfield)
func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields))
return func, re
end

end

end

function flexiblefunctorm(T, pfield = :params)
Expand Down
186 changes: 157 additions & 29 deletions test/basics.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
struct Foo
x
y
end

using Functors: functor

struct Foo; x; y; end
@functor Foo

struct Bar
x
end
struct Bar; x; end
@functor Bar

struct Baz
x
y
z
end
@functor Baz (y,)
struct OneChild3; x; y; z; end
@functor OneChild3 (y,)

struct NoChildren
x
y
end
struct NoChildren2; x; y; end

@static if VERSION >= v"1.6"
@testset "ComposedFunction" begin
Expand All @@ -31,6 +22,10 @@ end
end
end

###
### Basic functionality
###

@testset "Nested" begin
model = Bar(Foo(1, [1, 2, 3]))

Expand All @@ -53,20 +48,80 @@ end
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
end

@testset "Property list" begin
model = OneChild3(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

@testset "cache" begin
shared = [1,2,3]
m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3])))
m1f = fmap(float, m1)
@test m1f.x === m1f.y.y.x
@test m1f.x !== m1f.y.x
m1p = fmapstructure(identity, m1; prune = nothing)
@test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3])))

# A non-leaf node can also be repeated:
m2 = Foo(Foo(shared, 4), Foo(shared, 4))
@test m2.x === m2.y
m2f = fmap(float, m2)
@test m2f.x.x === m2f.y.x
m2p = fmapstructure(identity, m2; prune = Bar(0))
@test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0))

# Repeated isbits types should not automatically be regarded as shared:
m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared))
m3p = fmapstructure(identity, m3; prune = 0)
@test m3p.y.y == 0
@test_broken m3p.y.x == 1:3
end

@testset "functor(typeof(x), y) from @functor" begin
nt1, re1 = functor(Foo, (x=1, y=2, z=3))
@test nt1 == (x = 1, y = 2)
@test re1((x = 10, y = 20)) == Foo(10, 20)
re1((y = 22, x = 11)) # gives Foo(22, 11), is that a bug?

nt2, re2 = functor(Foo, (z=33, x=1, y=2))
@test nt2 == (x = 1, y = 2)
@test re2((x = 10, y = 20)) == Foo(10, 20)

@test_throws Exception functor(Foo, (z=33, x=1)) # type NamedTuple has no field y

nt3, re3 = functor(OneChild3, (x=1, y=2, z=3))
@test nt3 == (y = 2,)
@test re3((y = 20,)) == OneChild3(1, 20, 3)
re3(22) # gives OneChild3(1, 22, 3), is that a bug?
end

@testset "functor(typeof(x), y) for Base types" begin
nt11, re11 = functor(NamedTuple{(:x, :y)}, (x=1, y=2, z=3))
@test nt11 == (x = 1, y = 2)
@test re11((x = 10, y = 20)) == (x = 10, y = 20)
re11((y = 22, x = 11))
re11((11, 22)) # passes right through

nt12, re12 = functor(NamedTuple{(:x, :y)}, (z=33, x=1, y=2))
@test nt12 == (x = 1, y = 2)
@test re12((x = 10, y = 20)) == (x = 10, y = 20)

@test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1))
end

###
### Extras
###

@testset "Walk" begin
model = Foo((0, Bar([1, 2, 3])), [4, 5])

model′ = fmapstructure(identity, model)
@test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5])
end

@testset "Property list" begin
model = Baz(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

@testset "fcollect" begin
m1 = [1, 2, 3]
m2 = 1
Expand All @@ -78,7 +133,7 @@ end

m1 = [1, 2, 3]
m2 = Bar(m1)
m0 = NoChildren(:a, :b)
m0 = NoChildren2(:a, :b)
m3 = Foo(m2, m0)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
Expand All @@ -89,6 +144,79 @@ end
@test all(fcollect(m3) .=== [m3, m1, m2])
end

###
### Vararg forms
###

@testset "fmap(f, x, y)" begin
m1 = (x = [1,2], y = 3)
n1 = (x = [4,5], y = 6)
@test fmap(+, m1, n1) == (x = [5, 7], y = 9)

# Reconstruction type comes from the first argument
foo1 = Foo([7,8], 9)
@test fmap(+, m1, foo1) == (x = [8, 10], y = 12)
@test fmap(+, foo1, n1) isa Foo
@test fmap(+, foo1, n1).x == [11, 13]

# Mismatched trees should be an error
m2 = (x = [1,2], y = (a = [3,4], b = 5))
n2 = (x = [6,7], y = 8)
@test_throws Exception fmap(firsttuple, m2, n2) # ERROR: type Int64 has no field a
@test_throws Exception fmap(firsttuple, m2, n2)

# The cache uses IDs from the first argument
shared = [1,2,3]
m3 = (x = shared, y = [4,5,6], z = shared)
n3 = (x = shared, y = shared, z = [7,8,9])
@test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6])
z3 = fmap(+, m3, n3)
@test z3.x === z3.z

# Pruning of duplicates:
@test fmap(+, m3, n3; prune = nothing) == (x = [2,4,6], y = [5,7,9], z = nothing)

# More than two arguments:
z4 = fmap(+, m3, n3, m3, n3)
@test z4 == fmap(x -> 2x, z3)
@test z4.x === z4.z

@test fmap(+, foo1, m1, n1) isa Foo
@static if VERSION >= v"1.6" # fails on Julia 1.0
@test fmap(.*, m1, foo1, n1) == (x = [4*7, 2*5*8], y = 3*6*9)
end
end

@static if VERSION >= v"1.6" # Julia 1.0: LoadError: error compiling top-level scope: type definition not allowed inside a local scope
@testset "old test update.jl" begin
struct M{F,T,S}
σ::F
W::T
b::S
end

@functor M

(m::M)(x) = m.σ.(m.W * x .+ m.b)

m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
x = ones(Float32, 4, 2)
m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
= Functors.fmap(m, m̄) do x, y
isnothing(x) && return y
isnothing(y) && return x
x .- 0.1f0 .* y
end

@test.W fill(0.8f0, size(m.W))
@test.b fill(-0.2f0, size(m.b))
end
end # VERSION

###
### FlexibleFunctors.jl
###

struct FFoo
x
y
Expand All @@ -102,13 +230,13 @@ struct FBar
end
@flexiblefunctor FBar p

struct FBaz
struct FOneChild4
x
y
z
p
end
@flexiblefunctor FBaz p
@flexiblefunctor FOneChild4 p

@testset "Flexible Nested" begin
model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,))
Expand All @@ -132,7 +260,7 @@ end
end

@testset "Flexible Property list" begin
model = FBaz(1, 2, 3, (:x, :z))
model = FOneChild4(1, 2, 3, (:x, :z))
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (2, 2, 6)
Expand All @@ -147,7 +275,7 @@ end
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])

m0 = NoChildren(:a, :b)
m0 = NoChildren2(:a, :b)
m1 = [1, 2, 3]
m2 = FBar(m1, ())
m3 = FFoo(m2, m0, (:x, :y,))
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using Zygote

include("basics.jl")
include("base.jl")
include("update.jl")

if VERSION < v"1.6" # || VERSION > v"1.7-"
@warn "skipping doctests, on Julia $VERSION"
Expand Down
Loading

2 comments on commit 0cf7942

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/54306

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.8 -m "<description of version>" 0cf7942bd3f3a00bdb3179d74eaec17209b6abcd
git push origin v0.2.8

Please sign in to comment.