Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simply show implementations and make them round-trippable #732

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ GPUArraysCore = "0.1.6"
LinearAlgebra = "1.10"
Logging = "1.10"
LossFunctions = "0.11.1"
LuxCore = "0.1.14"
LuxCore = "0.1.16"
LuxDeviceUtils = "0.1.22"
LuxLib = "0.3.23"
LuxTestUtils = "0.1.15"
Expand Down
5 changes: 2 additions & 3 deletions src/contrib/freeze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,9 @@

function Base.show(io::IO, f::FrozenLayer{which_params}) where {which_params}
if which_params === nothing
return print(io, f.layer, " (with all parameters frozen)")
return print(io, "FrozenLayer(", f.layer, ")")
end
wp = join(map(x -> "`$(x)`", which_params), ", ", " & ")
return print(io, f.layer, " (with ", wp, " frozen)")
print(io, "FrozenLayer(", f.layer, ", ", which_params, ")")

Check warning on line 100 in src/contrib/freeze.jl

View check run for this annotation

Codecov / codecov/patch

src/contrib/freeze.jl#L100

Added line #L100 was not covered by tests
end

"""
Expand Down
10 changes: 7 additions & 3 deletions src/helpers/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,19 +516,23 @@ end
function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing)
setup_strings = obj.setup_strings
local_name = obj.name
layer, input, block = obj.strings
if local_name !== nothing && local_name != ""
Lux._layer_show(io, obj, indent, name)
_str = name === nothing ? "" : "$name = "
str = _str * local_name
print(io, " "^indent, str, indent == 0 ? "" : ",")
_show_parameters_count(io, obj, indent, str)
indent == 0 || println(io)
return
end
layer, input, block = obj.strings
pre, post = ("(", ")")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
for (k, v) in pairs(setup_strings)
val = _getproperty(obj.layers, Val(k))
if val === nothing
println(io, " "^(indent + 4), "$k = $v,")
else
Lux._big_show(io, val, indent + 4, k)
_big_show(io, val, indent + 4, k)
end
end
if indent == 0 # i.e. this is the outermost container
Expand Down
2 changes: 1 addition & 1 deletion src/helpers/stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ mutable struct StatefulLuxLayer{ST, M <: AbstractExplicitLayer, psType, stType}
end
end

function Base.show(io::IO, s::StatefulLuxLayer{ST}) where {ST}
function Base.show(io::IO, ::MIME"text/plain", s::StatefulLuxLayer{ST}) where {ST}
_print_wrapper_model(io, "StatefulLuxLayer{$ST}", s.model)
end

Expand Down
8 changes: 4 additions & 4 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ SelectDim(dim, index) = SelectDim{dim, index}()
return selectdim(x, dim, index), st
end

function Base.show(io::IO, s::SelectDim{dim, index}) where {dim, index}
return print(io, "SelectDim(dim = ", dim, ", index = ", index, ")")
function Base.show(io::IO, ::SelectDim{dim, index}) where {dim, index}
return print(io, "SelectDim(", dim, ", ", index, ")")
end

"""
Expand Down Expand Up @@ -267,8 +267,8 @@ end
@inline __maybe_direct_call(f, x, ps, st, ::Val{false}) = f(x, ps, st)
@inline __maybe_direct_call(f, x, ps, st, ::Val{true}) = f(x), st

function Base.show(io::IO, w::WrappedFunction)
return print(io, "WrappedFunction(", w.func, ")")
function Base.show(io::IO, w::WrappedFunction{T}) where {T}
return print(io, "WrappedFunction{$(Meta.quot(T))}($(w.func))")
end

"""
Expand Down
51 changes: 44 additions & 7 deletions src/layers/containers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
SkipConnection(layer, connection; name=nothing)
SkipConnection(layers, connection; name=nothing)
SkipConnection(; layers, connection, name=nothing)

Create a skip connection which consists of a layer or [`Chain`](@ref) of consecutive layers
and a shortcut connection linking the block's input to the output through a user-supplied
Expand Down Expand Up @@ -44,13 +45,24 @@

See [`Parallel`](@ref) for a more general implementation.
"""
@kwdef @concrete struct SkipConnection <: AbstractExplicitContainerLayer{(:layers,)}
@concrete struct SkipConnection <: AbstractExplicitContainerLayer{(:layers,)}
layers
connection
name = nothing
name
end

function Functors.functor(::Type{<:SkipConnection}, sc)
recon = let name = sc.name
nt -> SkipConnection(nt.layers, nt.connection, name)
end
return (; sc.layers, sc.connection), recon
end

function SkipConnection(layers, connection; name::NAME_TYPE=nothing)
return SkipConnection(; layers, connection, name)
end

function SkipConnection(; layers, connection, name::NAME_TYPE=nothing)
return SkipConnection(layers, connection, name)
end

Expand Down Expand Up @@ -81,6 +93,7 @@
"""
Parallel(connection, layers...; name=nothing)
Parallel(connection; name=nothing, layers...)
Parallel(; connection, layers..., name=nothing)

Create a layer which passes an input to each path in `layers`, before reducing the output
with `connection`.
Expand Down Expand Up @@ -149,11 +162,22 @@
name
end

function Functors.functor(::Type{<:Parallel}, p)
recon = let name = p.name
nt -> Parallel(nt.connection, nt.layers, name)
end
return (; p.connection, p.layers), recon
end

function Parallel(connection, layers...; name::NAME_TYPE=nothing)
return Parallel(connection, __named_tuple_layers(layers...), name)
end

function Parallel(connection; name::NAME_TYPE=nothing, kwargs...)
return Parallel(; connection, name, kwargs...)
end

function Parallel(; name::NAME_TYPE=nothing, connection, kwargs...)
return Parallel(connection, (; kwargs...), name)
end

Expand Down Expand Up @@ -276,6 +300,7 @@
"""
PairwiseFusion(connection, layers...; name=nothing)
PairwiseFusion(connection; name=nothing, layers...)
PairwiseFusion(; connection, layers..., name=nothing)

```
x1 → layer1 → y1 ↘
Expand Down Expand Up @@ -342,11 +367,22 @@
name
end

function Functors.functor(::Type{<:PairwiseFusion}, p)
recon = let name = p.name
nt -> PairwiseFusion(nt.connection, nt.layers, name)
end
return (; p.connection, p.layers), recon
end

function PairwiseFusion(connection, layers...; name::NAME_TYPE=nothing)
return PairwiseFusion(connection, __named_tuple_layers(layers...), name)
end

function PairwiseFusion(connection; name::NAME_TYPE=nothing, kwargs...)
return PairwiseFusion(; connection, name, kwargs...)
end

function PairwiseFusion(; name::NAME_TYPE=nothing, connection, kwargs...)
return PairwiseFusion(connection, (; kwargs...), name)
end

Expand Down Expand Up @@ -658,13 +694,14 @@
model::M
end

function LuxCore.display_name(::RepeatedLayer{N, IJ}) where {N, IJ}
return "RepeatedLayer{repeats = $N, input_injection = $IJ}"
end
LuxCore.display_name(::RepeatedLayer{N, IJ}) where {N, IJ} = "RepeatedLayer{$N, $IJ}"

RepeatedLayer{N, IJ}(model) where {N, IJ} = RepeatedLayer{N, IJ, typeof(model)}(model)
RepeatedLayer{N, IJ}(; model) where {N, IJ} = RepeatedLayer{N, IJ, typeof(model)}(model)

Check warning on line 700 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L700

Added line #L700 was not covered by tests

function RepeatedLayer(model::AbstractExplicitLayer; repeats::Val{N}=Val(10),
input_injection::Val{IJ}=Val(false)) where {N, IJ}
return RepeatedLayer{N, IJ, typeof(model)}(model)
return RepeatedLayer{N, IJ}(model)
end

(m::RepeatedLayer)(x, ps, st) = repeatedlayer(m, m.model, x, ps, st)
Expand Down
106 changes: 37 additions & 69 deletions src/layers/display.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,59 @@
end
end

function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer)
!get(io, :compact, false) && return _layer_show(io, x)
show(io, x)

Check warning on line 13 in src/layers/display.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/display.jl#L13

Added line #L13 was not covered by tests
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
pre, post = "(", ")"
children = _get_children(obj)
if obj isa Function
println(io, " "^indent, isnothing(name) ? "" : "$name = ", obj)
elseif all(_show_leaflike, children)
if obj isa Function || obj isa Nothing
print(io, " "^indent, isnothing(name) ? "" : "$name = ", obj)
indent != 0 && println(io, ",")
return
end

children = _children(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", display_name(obj), pre)
if obj isa Chain
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif obj isa Parallel
if obj.connection !== nothing
_big_show(io, obj.connection, indent + 4)
end
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif obj isa PairwiseFusion
_big_show(io, obj.connection, indent + 4)
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif obj isa BranchLayer
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif obj isa Maxout
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif children isa NamedTuple
for (k, c) in pairs(children)
_big_show(io, c, indent + 4, k)
end
else
for c in children
_big_show(io, c, indent + 4)
end
println(io, " "^indent, isnothing(name) ? "" : "$name = ", display_name(obj), "(")
for (k, c) in pairs(children)
_big_show(io, c, indent + 4, k)
end
if indent == 0 # i.e. this is the outermost container
print(io, rpad(post, 2))
print(io, rpad(")", 2))
_big_finale(io, obj)
else
println(io, " "^indent, post, ",")
println(io, " "^indent, ")", ",")
end
end
end

_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
_show_leaflike(x::AbstractExplicitLayer) = false

function _get_children(l::AbstractExplicitContainerLayer{names}) where {names}
return NamedTuple{names}(getfield.((l,), names))
end
function _get_children(p::Parallel)
return p.connection === nothing ? p.layers : (p.connection, p.layers...)
_children(x::AbstractExplicitLayer) = Functors.children(x)
function _children(m::AbstractExplicitContainerLayer{layers}) where {layers}
children = Functors.children(m)
length(layers) ≥ 2 && return children
field = first(layers)
hasfield(typeof(children), field) || return children
nt = getfield(children, field)
nt isa NamedTuple || (nt = NamedTuple{(field,)}((nt,)))
return merge(Base.structdiff(children, NamedTuple{(field,)}), nt)
end
_get_children(s::SkipConnection) = (s.layers, s.connection)
_get_children(s::WeightNorm) = (s.layer,)
_get_children(nt::NamedTuple) = nt
_get_children(::Any) = ()

function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer)
if !get(io, :compact, false)
_layer_show(io, x)
else
show(io, x)
end
end
_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
_show_leaflike(x::AbstractExplicitLayer) = false

function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
_str = isnothing(name) ? "" : "$name = "
str = _str * sprint(show, layer; context=io)
print(io, " "^indent, str, indent == 0 ? "" : ",")
_show_parameters_count(io, layer, indent, str)
indent == 0 || println(io)
return
end

function _show_parameters_count(io::IO, layer, indent, str::String)
paramlength = parameterlength(layer)
if paramlength > 0
print(io, " "^max(2, (indent == 0 ? 20 : 39) - indent - length(str)))
Expand All @@ -95,24 +71,16 @@
indent == 0 ? " non-trainable" : ""; color=:light_black)
end
end
indent == 0 || println(io)
return
end

function _big_finale(io::IO, m, len=8)
paramlength = parameterlength(m)
nonparamlength = statelength(m)
pars = underscorise(paramlength)
nonparam = underscorise(nonparamlength)
printstyled(io, " "^len, "# Total: "; color=:light_black)
println(io, pars, " parameters,")
println(io, underscorise(parameterlength(m)), " parameters,")
printstyled(io, " "^10, "# plus "; color=:light_black)
print(io, nonparam, " states.")
print(io, underscorise(statelength(m)), " states.")
return
end

# utility functions

function underscorise(n::Integer)
return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')
end
Expand Down
2 changes: 1 addition & 1 deletion src/layers/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ function (d::VariationalHiddenDropout)(x, ps, st::NamedTuple)
return y, merge(st, (; mask, rng, update_mask=Val(false)))
end

function Base.show(io::IO, d::VariationalHiddenDropout)
function Base.show(io::IO, ::MIME"text/plain", d::VariationalHiddenDropout)
print(io, "VariationalHiddenDropout(", d.p)
d.dims != Colon() && print(io, ", dims=", d.dims)
return print(io, ")")
Expand Down
7 changes: 4 additions & 3 deletions src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
bumper = Val(false)
end

function Base.show(io::IO, l::DynamicExpressionsLayer)
function Base.show(io::IO, ::MIME"text/plain", l::DynamicExpressionsLayer)

Check warning on line 93 in src/layers/extension.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/extension.jl#L93

Added line #L93 was not covered by tests
return print(io, "DynamicExpressionNode($(l.expression))")
end

Expand Down Expand Up @@ -183,7 +183,7 @@

(l::FluxLayer)(x, ps, st) = l.re(ps.p)(x), st

Base.show(io::IO, l::FluxLayer) = print(io, "FluxLayer($(l.layer))")
Base.show(io::IO, ::MIME"text/plain", l::FluxLayer) = print(io, "FluxLayer($(l.layer))")

## SimpleChains.jl

Expand Down Expand Up @@ -224,7 +224,8 @@
end
end

function Base.show(io::IO, s::SimpleChainsLayer{ToArray}) where {ToArray}
function Base.show(
io::IO, ::MIME"text/plain", s::SimpleChainsLayer{ToArray}) where {ToArray}
_print_wrapper_model(io, "SimpleChainsLayer{$ToArray}", s.lux_layer)
end

Expand Down
Loading
Loading