From a6de2afd733242b2c99b9ad0212407547e108d3a Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 25 Jun 2024 21:14:45 -0700
Subject: [PATCH] Make show round-trippable

---
 Project.toml             |   2 +-
 src/contrib/freeze.jl    |   5 +-
 src/helpers/compact.jl   |  10 ++--
 src/helpers/stateful.jl  |   2 +-
 src/layers/basic.jl      |   8 +--
 src/layers/containers.jl |  47 +++++++++++++++--
 src/layers/display.jl    | 106 ++++++++++++++-------------------------
 src/layers/dropout.jl    |   2 +-
 src/layers/extension.jl  |   7 +--
 src/layers/normalize.jl  |   9 +++-
 src/layers/recurrent.jl  |   8 +--
 test/shared_testsetup.jl |   7 ++-
 12 files changed, 116 insertions(+), 97 deletions(-)

diff --git a/Project.toml b/Project.toml
index 3a788a24c2..771125a6e7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -93,7 +93,7 @@ GPUArraysCore = "0.1.6"
 LinearAlgebra = "1.10"
 Logging = "1.10"
 LossFunctions = "0.11.1"
-LuxCore = "0.1.14"
+LuxCore = "0.1.16"
 LuxDeviceUtils = "0.1.22"
 LuxLib = "0.3.23"
 LuxTestUtils = "0.1.15"
diff --git a/src/contrib/freeze.jl b/src/contrib/freeze.jl
index d017627a52..b715393dd7 100644
--- a/src/contrib/freeze.jl
+++ b/src/contrib/freeze.jl
@@ -95,10 +95,9 @@ end
 
 function Base.show(io::IO, f::FrozenLayer{which_params}) where {which_params}
     if which_params === nothing
-        return print(io, f.layer, " (with all parameters frozen)")
+        return print(io, "FrozenLayer(", f.layer, ")")
     end
-    wp = join(map(x -> "`$(x)`", which_params), ", ", " & ")
-    return print(io, f.layer, " (with ", wp, " frozen)")
+    print(io, "FrozenLayer(", f.layer, ", ", which_params, ")")
 end
 
 """
diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl
index 5f5647acd6..6eb8b84b1d 100644
--- a/src/helpers/compact.jl
+++ b/src/helpers/compact.jl
@@ -516,11 +516,15 @@
 function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing)
     setup_strings = obj.setup_strings
     local_name = obj.name
-    layer, input, block = obj.strings
     if local_name !== nothing && local_name != ""
-        Lux._layer_show(io, obj, indent, name)
+        _str = name === nothing ? "" : "$name = "
+        str = _str * local_name
+        print(io, " "^indent, str, indent == 0 ? "" : ",")
+        _show_parameters_count(io, obj, indent, str)
+        indent == 0 || println(io)
         return
     end
+    layer, input, block = obj.strings
     pre, post = ("(", ")")
     println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
     for (k, v) in pairs(setup_strings)
@@ -528,7 +532,7 @@ function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing
         if val === nothing
             println(io, " "^(indent + 4), "$k = $v,")
         else
-            Lux._big_show(io, val, indent + 4, k)
+            _big_show(io, val, indent + 4, k)
         end
     end
     if indent == 0 # i.e. this is the outermost container
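
What "round-trippable" means in practice, sketched outside the diff (illustrative only; assumes `FrozenLayer` is reachable via `Lux.Experimental`, as in Lux 0.5):

```julia
using Lux

l = Lux.Experimental.FrozenLayer(Dense(2 => 2), (:weight,))

# Old output: Dense(2 => 2) (with `weight` frozen)   -- not valid Julia syntax.
# New output: FrozenLayer(Dense(2 => 2), (:weight,)) -- a parseable constructor call.
sprint(show, l)
```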
diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl
index 0b505613ca..978bf9ab7f 100644
--- a/src/helpers/stateful.jl
+++ b/src/helpers/stateful.jl
@@ -52,7 +52,7 @@ mutable struct StatefulLuxLayer{ST, M <: AbstractExplicitLayer, psType, stType}
     end
 end
 
-function Base.show(io::IO, s::StatefulLuxLayer{ST}) where {ST}
+function Base.show(io::IO, ::MIME"text/plain", s::StatefulLuxLayer{ST}) where {ST}
     _print_wrapper_model(io, "StatefulLuxLayer{$ST}", s.model)
 end
 
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 0bf7356db3..3bba29e996 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -181,8 +181,8 @@ SelectDim(dim, index) = SelectDim{dim, index}()
     return selectdim(x, dim, index), st
 end
 
-function Base.show(io::IO, s::SelectDim{dim, index}) where {dim, index}
-    return print(io, "SelectDim(dim = ", dim, ", index = ", index, ")")
+function Base.show(io::IO, ::SelectDim{dim, index}) where {dim, index}
+    return print(io, "SelectDim(", dim, ", ", index, ")")
 end
 
 """
@@ -267,8 +267,8 @@ end
 @inline __maybe_direct_call(f, x, ps, st, ::Val{false}) = f(x, ps, st)
 @inline __maybe_direct_call(f, x, ps, st, ::Val{true}) = f(x), st
 
-function Base.show(io::IO, w::WrappedFunction)
-    return print(io, "WrappedFunction(", w.func, ")")
+function Base.show(io::IO, w::WrappedFunction{T}) where {T}
+    return print(io, "WrappedFunction{$(Meta.quot(T))}($(w.func))")
 end
 
 """
diff --git a/src/layers/containers.jl b/src/layers/containers.jl
index 2b01c84167..14e8213b16 100644
--- a/src/layers/containers.jl
+++ b/src/layers/containers.jl
@@ -1,5 +1,6 @@
 """
-    SkipConnection(layer, connection; name=nothing)
+    SkipConnection(layers, connection; name=nothing)
+    SkipConnection(; layers, connection, name=nothing)
 
 Create a skip connection which consists of a layer or [`Chain`](@ref) of consecutive layers
 and a shortcut connection linking the block's input to the output through a user-supplied
@@ -50,7 +51,18 @@ See [`Parallel`](@ref) for a more general implementation.
     name = nothing
 end
 
+function Functors.functor(::Type{<:SkipConnection}, sc)
+    recon = let name = sc.name
+        nt -> SkipConnection(nt.layers, nt.connection, name)
+    end
+    return (; sc.layers, sc.connection), recon
+end
+
 function SkipConnection(layers, connection; name::NAME_TYPE=nothing)
+    return SkipConnection(; layers, connection, name)
+end
+
+function SkipConnection(; layers, connection, name::NAME_TYPE=nothing)
     return SkipConnection(layers, connection, name)
 end
 
@@ -81,6 +93,7 @@ end
 """
     Parallel(connection, layers...; name=nothing)
     Parallel(connection; name=nothing, layers...)
+    Parallel(; connection, layers..., name=nothing)
 
 Create a layer which passes an input to each path in `layers`, before reducing the output
 with `connection`.
@@ -149,11 +162,22 @@ julia> size.(first(model((x1, x2), ps, st)))
     name
 end
 
+function Functors.functor(::Type{<:Parallel}, p)
+    recon = let name = p.name
+        nt -> Parallel(nt.connection, nt.layers, name)
+    end
+    return (; p.connection, p.layers), recon
+end
+
 function Parallel(connection, layers...; name::NAME_TYPE=nothing)
    return Parallel(connection, __named_tuple_layers(layers...), name)
 end
 
 function Parallel(connection; name::NAME_TYPE=nothing, kwargs...)
+    return Parallel(; connection, name, kwargs...)
+end
+
+function Parallel(; name::NAME_TYPE=nothing, connection, kwargs...)
     return Parallel(connection, (; kwargs...), name)
 end
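
The `Functors.functor` overloads exist so that `fmap`-style reconstruction keeps the `name` field, which is deliberately not one of the children. A sketch of the behavior they are meant to guarantee (not part of the diff):

```julia
using Lux, Functors

p = Parallel(+; name="branches", d1=Dense(2 => 3), d2=Dense(2 => 3))

# `fmap` destructures `p` into `(; connection, layers)` and rebuilds it through
# the returned closure, so the captured `name` survives the round trip.
p2 = fmap(identity, p)
@assert p2.name == "branches"
```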
@@ -276,6 +300,7 @@ Base.keys(m::BranchLayer) = Base.keys(getfield(m, :layers))
 """
     PairwiseFusion(connection, layers...; name=nothing)
     PairwiseFusion(connection; name=nothing, layers...)
+    PairwiseFusion(; connection, layers..., name=nothing)
 
 ```
 x1 → layer1 → y1 ↘
@@ -342,11 +367,22 @@ end
     name
 end
 
+function Functors.functor(::Type{<:PairwiseFusion}, p)
+    recon = let name = p.name
+        nt -> PairwiseFusion(nt.connection, nt.layers, name)
+    end
+    return (; p.connection, p.layers), recon
+end
+
 function PairwiseFusion(connection, layers...; name::NAME_TYPE=nothing)
     return PairwiseFusion(connection, __named_tuple_layers(layers...), name)
 end
 
 function PairwiseFusion(connection; name::NAME_TYPE=nothing, kwargs...)
+    return PairwiseFusion(; connection, name, kwargs...)
+end
+
+function PairwiseFusion(; name::NAME_TYPE=nothing, connection, kwargs...)
     return PairwiseFusion(connection, (; kwargs...), name)
 end
@@ -658,13 +694,14 @@ struct RepeatedLayer{N, IJ, M <: AbstractExplicitLayer} <:
     model::M
 end
 
-function LuxCore.display_name(::RepeatedLayer{N, IJ}) where {N, IJ}
-    return "RepeatedLayer{repeats = $N, input_injection = $IJ}"
-end
+LuxCore.display_name(::RepeatedLayer{N, IJ}) where {N, IJ} = "RepeatedLayer{$N, $IJ}"
+
+RepeatedLayer{N, IJ}(model) where {N, IJ} = RepeatedLayer{N, IJ, typeof(model)}(model)
+RepeatedLayer{N, IJ}(; model) where {N, IJ} = RepeatedLayer{N, IJ, typeof(model)}(model)
 
 function RepeatedLayer(model::AbstractExplicitLayer; repeats::Val{N}=Val(10),
         input_injection::Val{IJ}=Val(false)) where {N, IJ}
-    return RepeatedLayer{N, IJ, typeof(model)}(model)
+    return RepeatedLayer{N, IJ}(model)
 end
 
 (m::RepeatedLayer)(x, ps, st) = repeatedlayer(m, m.model, x, ps, st)
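
With `display_name` now printing only the type parameters, the added `RepeatedLayer{N, IJ}(model)` methods make the printed form constructible again. Illustrative sketch (the exact printed string depends on the generic container `show`, so treat it as approximate):

```julia
using Lux

l = RepeatedLayer(Dense(2 => 2); repeats=Val(3))

# display_name(l) is now "RepeatedLayer{3, false}", so a printed
# "RepeatedLayer{3, false}(Dense(2 => 2))" evaluates through the new method:
l2 = RepeatedLayer{3, false}(Dense(2 => 2))
```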
"" : "$name = ", display_name(obj), pre) - if obj isa Chain - for k in Base.keys(obj) - _big_show(io, obj.layers[k], indent + 4, k) - end - elseif obj isa Parallel - if obj.connection !== nothing - _big_show(io, obj.connection, indent + 4) - end - for k in Base.keys(obj) - _big_show(io, obj.layers[k], indent + 4, k) - end - elseif obj isa PairwiseFusion - _big_show(io, obj.connection, indent + 4) - for k in Base.keys(obj) - _big_show(io, obj.layers[k], indent + 4, k) - end - elseif obj isa BranchLayer - for k in Base.keys(obj) - _big_show(io, obj.layers[k], indent + 4, k) - end - elseif obj isa Maxout - for k in Base.keys(obj) - _big_show(io, obj.layers[k], indent + 4, k) - end - elseif children isa NamedTuple - for (k, c) in pairs(children) - _big_show(io, c, indent + 4, k) - end - else - for c in children - _big_show(io, c, indent + 4) - end + println(io, " "^indent, isnothing(name) ? "" : "$name = ", display_name(obj), "(") + for (k, c) in pairs(children) + _big_show(io, c, indent + 4, k) end if indent == 0 # i.e. this is the outermost container - print(io, rpad(post, 2)) + print(io, rpad(")", 2)) _big_finale(io, obj) else - println(io, " "^indent, post, ",") + println(io, " "^indent, ")", ",") end end end -_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for: -_show_leaflike(x::AbstractExplicitLayer) = false - -function _get_children(l::AbstractExplicitContainerLayer{names}) where {names} - return NamedTuple{names}(getfield.((l,), names)) -end -function _get_children(p::Parallel) - return p.connection === nothing ? p.layers : (p.connection, p.layers...) +_children(x::AbstractExplicitLayer) = Functors.children(x) +function _children(m::AbstractExplicitContainerLayer{layers}) where {layers} + children = Functors.children(m) + length(layers) ≥ 2 && return children + field = first(layers) + hasfield(typeof(children), field) || return children + nt = getfield(children, field) + nt isa NamedTuple || (nt = NamedTuple{(field,)}((nt,))) + return merge(Base.structdiff(children, NamedTuple{(field,)}), nt) end -_get_children(s::SkipConnection) = (s.layers, s.connection) -_get_children(s::WeightNorm) = (s.layer,) -_get_children(nt::NamedTuple) = nt -_get_children(::Any) = () -function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer) - if !get(io, :compact, false) - _layer_show(io, x) - else - show(io, x) - end -end +_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for: +_show_leaflike(x::AbstractExplicitLayer) = false function _layer_show(io::IO, layer, indent::Int=0, name=nothing) _str = isnothing(name) ? "" : "$name = " str = _str * sprint(show, layer; context=io) print(io, " "^indent, str, indent == 0 ? "" : ",") + _show_parameters_count(io, layer, indent, str) + indent == 0 || println(io) + return +end + +function _show_parameters_count(io::IO, layer, indent, str::String) paramlength = parameterlength(layer) if paramlength > 0 print(io, " "^max(2, (indent == 0 ? 20 : 39) - indent - length(str))) @@ -95,24 +71,16 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing) indent == 0 ? 
" non-trainable" : ""; color=:light_black) end end - indent == 0 || println(io) - return end function _big_finale(io::IO, m, len=8) - paramlength = parameterlength(m) - nonparamlength = statelength(m) - pars = underscorise(paramlength) - nonparam = underscorise(nonparamlength) printstyled(io, " "^len, "# Total: "; color=:light_black) - println(io, pars, " parameters,") + println(io, underscorise(parameterlength(m)), " parameters,") printstyled(io, " "^10, "# plus "; color=:light_black) - print(io, nonparam, " states.") + print(io, underscorise(statelength(m)), " states.") return end -# utility functions - function underscorise(n::Integer) return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') end diff --git a/src/layers/dropout.jl b/src/layers/dropout.jl index 7e00b1f85a..eb84a8a4dd 100644 --- a/src/layers/dropout.jl +++ b/src/layers/dropout.jl @@ -180,7 +180,7 @@ function (d::VariationalHiddenDropout)(x, ps, st::NamedTuple) return y, merge(st, (; mask, rng, update_mask=Val(false))) end -function Base.show(io::IO, d::VariationalHiddenDropout) +function Base.show(io::IO, ::MIME"text/plain", d::VariationalHiddenDropout) print(io, "VariationalHiddenDropout(", d.p) d.dims != Colon() && print(io, ", dims=", d.dims) return print(io, ")") diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 43c1c7d61f..76c59241d6 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -90,7 +90,7 @@ true bumper = Val(false) end -function Base.show(io::IO, l::DynamicExpressionsLayer) +function Base.show(io::IO, ::MIME"text/plain", l::DynamicExpressionsLayer) return print(io, "DynamicExpressionNode($(l.expression))") end @@ -183,7 +183,7 @@ Lux.initialparameters(::AbstractRNG, l::FluxLayer) = (p=l.init_parameters(),) (l::FluxLayer)(x, ps, st) = l.re(ps.p)(x), st -Base.show(io::IO, l::FluxLayer) = print(io, "FluxLayer($(l.layer))") +Base.show(io::IO, ::MIME"text/plain", l::FluxLayer) = print(io, "FluxLayer($(l.layer))") ## SimpleChains.jl @@ -224,7 +224,8 @@ struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractExplicitLayer end end -function Base.show(io::IO, s::SimpleChainsLayer{ToArray}) where {ToArray} +function Base.show( + io::IO, ::MIME"text/plain", s::SimpleChainsLayer{ToArray}) where {ToArray} _print_wrapper_model(io, "SimpleChainsLayer{$ToArray}", s.lux_layer) end diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 629b958bbd..dfa9c67ec5 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -422,9 +422,14 @@ parameters: one specifying the magnitude (e.g. `weight_g`) and one specifying th dims end +function WeightNorm{which_params}(layer::AbstractExplicitLayer; + dims::Union{Tuple, Nothing}=nothing) where {which_params} + return WeightNorm{which_params}(layer, dims) +end + function WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol}, dims::Union{Tuple, Nothing}=nothing) where {N} - return WeightNorm{which_params}(layer, dims) + return WeightNorm{which_params}(layer; dims) end @inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims)) @@ -552,7 +557,7 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. 
   - `affine=false`: Empty `NamedTuple()`
   - `affine=true`
-
+
     + `bias`: Bias of shape `(shape..., 1)`
     + `scale`: Scale of shape `(shape..., 1)`
 """
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 41618acec2..6273848cce 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -297,7 +297,7 @@ end
 function Base.show(io::IO, r::RNNCell{use_bias, TS}) where {use_bias, TS}
     print(io, "RNNCell($(r.in_dims) => $(r.out_dims)")
     (r.activation == identity) || print(io, ", $(r.activation)")
-    use_bias || print(io, ", bias=false")
+    use_bias || print(io, ", use_bias=false")
     TS && print(io, ", train_state=true")
     return print(io, ")")
 end
@@ -480,7 +480,7 @@ end
 function Base.show(io::IO, lstm::LSTMCell{use_bias, TS, train_memory}) where {use_bias, TS, train_memory}
     print(io, "LSTMCell($(lstm.in_dims) => $(lstm.out_dims)")
-    use_bias || print(io, ", bias=false")
+    use_bias || print(io, ", use_bias=false")
     TS && print(io, ", train_state=true")
     train_memory && print(io, ", train_memory=true")
     return print(io, ")")
 end
@@ -525,7 +525,7 @@ Gated Recurrent Unit (GRU) Cell
 updated hidden state is returned.
 
 ## Returns
-
+
   - Tuple containing
 
     + Output ``h_{new}`` of shape `(out_dims, batch_size)`
@@ -631,7 +631,7 @@ end
 function Base.show(io::IO, g::GRUCell{use_bias, TS}) where {use_bias, TS}
     print(io, "GRUCell($(g.in_dims) => $(g.out_dims)")
-    use_bias || print(io, ", bias=false")
+    use_bias || print(io, ", use_bias=false")
     TS && print(io, ", train_state=true")
     return print(io, ")")
 end
 
diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl
index 0103d5605f..a3578342b3 100644
--- a/test/shared_testsetup.jl
+++ b/test/shared_testsetup.jl
@@ -17,8 +17,13 @@ function get_default_rng(mode::String)
     return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
 end
 
+function test_roundtrip_show(l)
+    display(l)
+    @test eval(Meta.parse(sprint(show, l))) == l
+end
+
 export @jet, @test_gradients, check_approx
 export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng,
-    StableRNG
+    StableRNG, test_roundtrip_show
 end
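
Spelled out, the property `test_roundtrip_show` checks (illustrative; assumes `==` on these immutable layer structs reduces to field-wise identity, so equal constructions compare equal):

```julia
using Lux, Test

l = SkipConnection(Dense(2 => 2), +)
str = sprint(show, l)        # e.g. "SkipConnection(Dense(2 => 2), +)"
l2 = eval(Meta.parse(str))   # the printed form is valid constructor syntax
@test l2 == l
```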