Skip to content

Commit

Permalink
Make show non-trippable
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 26, 2024
1 parent 201bcb4 commit a6de2af
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 97 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ GPUArraysCore = "0.1.6"
LinearAlgebra = "1.10"
Logging = "1.10"
LossFunctions = "0.11.1"
LuxCore = "0.1.14"
LuxCore = "0.1.16"
LuxDeviceUtils = "0.1.22"
LuxLib = "0.3.23"
LuxTestUtils = "0.1.15"
Expand Down
5 changes: 2 additions & 3 deletions src/contrib/freeze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,9 @@ end

function Base.show(io::IO, f::FrozenLayer{which_params}) where {which_params}
if which_params === nothing
return print(io, f.layer, " (with all parameters frozen)")
return print(io, "FrozenLayer(", f.layer, ")")

Check warning on line 98 in src/contrib/freeze.jl

View check run for this annotation

Codecov / codecov/patch

src/contrib/freeze.jl#L98

Added line #L98 was not covered by tests
end
wp = join(map(x -> "`$(x)`", which_params), ", ", " & ")
return print(io, f.layer, " (with ", wp, " frozen)")
print(io, "FrozenLayer(", f.layer, ", ", which_params, ")")

Check warning on line 100 in src/contrib/freeze.jl

View check run for this annotation

Codecov / codecov/patch

src/contrib/freeze.jl#L100

Added line #L100 was not covered by tests
end

"""
Expand Down
10 changes: 7 additions & 3 deletions src/helpers/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,19 +516,23 @@ end
function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing)
setup_strings = obj.setup_strings
local_name = obj.name
layer, input, block = obj.strings
if local_name !== nothing && local_name != ""
Lux._layer_show(io, obj, indent, name)
_str = name === nothing ? "" : "$name = "
str = _str * local_name
print(io, " "^indent, str, indent == 0 ? "" : ",")
_show_parameters_count(io, obj, indent, str)
indent == 0 || println(io)

Check warning on line 524 in src/helpers/compact.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/compact.jl#L520-L524

Added lines #L520 - L524 were not covered by tests
return
end
layer, input, block = obj.strings

Check warning on line 527 in src/helpers/compact.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/compact.jl#L527

Added line #L527 was not covered by tests
pre, post = ("(", ")")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
for (k, v) in pairs(setup_strings)
val = _getproperty(obj.layers, Val(k))
if val === nothing
println(io, " "^(indent + 4), "$k = $v,")
else
Lux._big_show(io, val, indent + 4, k)
_big_show(io, val, indent + 4, k)

Check warning on line 535 in src/helpers/compact.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers/compact.jl#L535

Added line #L535 was not covered by tests
end
end
if indent == 0 # i.e. this is the outermost container
Expand Down
2 changes: 1 addition & 1 deletion src/helpers/stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ mutable struct StatefulLuxLayer{ST, M <: AbstractExplicitLayer, psType, stType}
end
end

function Base.show(io::IO, s::StatefulLuxLayer{ST}) where {ST}
function Base.show(io::IO, ::MIME"text/plain", s::StatefulLuxLayer{ST}) where {ST}
_print_wrapper_model(io, "StatefulLuxLayer{$ST}", s.model)
end

Expand Down
8 changes: 4 additions & 4 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ SelectDim(dim, index) = SelectDim{dim, index}()
return selectdim(x, dim, index), st
end

function Base.show(io::IO, s::SelectDim{dim, index}) where {dim, index}
return print(io, "SelectDim(dim = ", dim, ", index = ", index, ")")
function Base.show(io::IO, ::SelectDim{dim, index}) where {dim, index}
return print(io, "SelectDim(", dim, ", ", index, ")")

Check warning on line 185 in src/layers/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/basic.jl#L184-L185

Added lines #L184 - L185 were not covered by tests
end

"""
Expand Down Expand Up @@ -267,8 +267,8 @@ end
@inline __maybe_direct_call(f, x, ps, st, ::Val{false}) = f(x, ps, st)
@inline __maybe_direct_call(f, x, ps, st, ::Val{true}) = f(x), st

function Base.show(io::IO, w::WrappedFunction)
return print(io, "WrappedFunction(", w.func, ")")
function Base.show(io::IO, w::WrappedFunction{T}) where {T}
return print(io, "WrappedFunction{$(Meta.quot(T))}($(w.func))")
end

"""
Expand Down
47 changes: 42 additions & 5 deletions src/layers/containers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
SkipConnection(layer, connection; name=nothing)
SkipConnection(layers, connection; name=nothing)
SkipConnection(; layers, connection, name=nothing)
Create a skip connection which consists of a layer or [`Chain`](@ref) of consecutive layers
and a shortcut connection linking the block's input to the output through a user-supplied
Expand Down Expand Up @@ -50,7 +51,18 @@ See [`Parallel`](@ref) for a more general implementation.
name = nothing
end

function Functors.functor(::Type{<:SkipConnection}, sc)
recon = let name = sc.name
(l, c) -> SkipConnection(l, c, name)

Check warning on line 56 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L54-L56

Added lines #L54 - L56 were not covered by tests
end
return (; sc.layers, sc.connection), recon

Check warning on line 58 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L58

Added line #L58 was not covered by tests
end

function SkipConnection(layers, connection; name::NAME_TYPE=nothing)
return SkipConnection(; layers, connection, name)

Check warning on line 62 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L62

Added line #L62 was not covered by tests
end

function SkipConnection(; layers, connection, name::NAME_TYPE=nothing)

Check warning on line 65 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L65

Added line #L65 was not covered by tests
return SkipConnection(layers, connection, name)
end

Expand Down Expand Up @@ -81,6 +93,7 @@ end
"""
Parallel(connection, layers...; name=nothing)
Parallel(connection; name=nothing, layers...)
Parallel(; connection, layers..., name=nothing)
Create a layer which passes an input to each path in `layers`, before reducing the output
with `connection`.
Expand Down Expand Up @@ -149,11 +162,22 @@ julia> size.(first(model((x1, x2), ps, st)))
name
end

function Functors.functor(::Type{<:Parallel}, p)
recon = let name = p.name
(c, ls) -> Parallel(c, ls, name)

Check warning on line 167 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L165-L167

Added lines #L165 - L167 were not covered by tests
end
return (; p.connection, p.layers), recon

Check warning on line 169 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L169

Added line #L169 was not covered by tests
end

function Parallel(connection, layers...; name::NAME_TYPE=nothing)
return Parallel(connection, __named_tuple_layers(layers...), name)
end

function Parallel(connection; name::NAME_TYPE=nothing, kwargs...)
return Parallel(; connection, name, kwargs...)

Check warning on line 177 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L177

Added line #L177 was not covered by tests
end

function Parallel(; name::NAME_TYPE=nothing, connection, kwargs...)

Check warning on line 180 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L180

Added line #L180 was not covered by tests
return Parallel(connection, (; kwargs...), name)
end

Expand Down Expand Up @@ -276,6 +300,7 @@ Base.keys(m::BranchLayer) = Base.keys(getfield(m, :layers))
"""
PairwiseFusion(connection, layers...; name=nothing)
PairwiseFusion(connection; name=nothing, layers...)
PairwiseFusion(; connection, layers..., name=nothing)
```
x1 → layer1 → y1 ↘
Expand Down Expand Up @@ -342,11 +367,22 @@ end
name
end

function Functors.functor(::Type{<:PairwiseFusion}, p)
recon = let name = p.name
(c, ls) -> PairwiseFusion(c, ls, name)

Check warning on line 372 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L370-L372

Added lines #L370 - L372 were not covered by tests
end
return (; p.connection, p.layers), recon

Check warning on line 374 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L374

Added line #L374 was not covered by tests
end

function PairwiseFusion(connection, layers...; name::NAME_TYPE=nothing)
return PairwiseFusion(connection, __named_tuple_layers(layers...), name)
end

function PairwiseFusion(connection; name::NAME_TYPE=nothing, kwargs...)
return PairwiseFusion(; connection, name, kwargs...)

Check warning on line 382 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L382

Added line #L382 was not covered by tests
end

function PairwiseFusion(; name::NAME_TYPE=nothing, connection, kwargs...)

Check warning on line 385 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L385

Added line #L385 was not covered by tests
return PairwiseFusion(connection, (; kwargs...), name)
end

Expand Down Expand Up @@ -658,13 +694,14 @@ struct RepeatedLayer{N, IJ, M <: AbstractExplicitLayer} <:
model::M
end

function LuxCore.display_name(::RepeatedLayer{N, IJ}) where {N, IJ}
return "RepeatedLayer{repeats = $N, input_injection = $IJ}"
end
LuxCore.display_name(::RepeatedLayer{N, IJ}) where {N, IJ} = "RepeatedLayer{$N, $IJ}"

Check warning on line 697 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L697

Added line #L697 was not covered by tests

RepeatedLayer{N, IJ}(model) where {N, IJ} = RepeatedLayer{N, IJ, typeof(model)}(model)
RepeatedLayer{N, IJ}(; model) where {N, IJ} = RepeatedLayer{N, IJ, typeof(model)}(model)

Check warning on line 700 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L699-L700

Added lines #L699 - L700 were not covered by tests

function RepeatedLayer(model::AbstractExplicitLayer; repeats::Val{N}=Val(10),
input_injection::Val{IJ}=Val(false)) where {N, IJ}
return RepeatedLayer{N, IJ, typeof(model)}(model)
return RepeatedLayer{N, IJ}(model)

Check warning on line 704 in src/layers/containers.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/containers.jl#L704

Added line #L704 was not covered by tests
end

(m::RepeatedLayer)(x, ps, st) = repeatedlayer(m, m.model, x, ps, st)
Expand Down
106 changes: 37 additions & 69 deletions src/layers/display.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,59 @@ function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitContainerLayer
end
end

function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer)
!get(io, :compact, false) && return _layer_show(io, x)
show(io, x)

Check warning on line 13 in src/layers/display.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/display.jl#L11-L13

Added lines #L11 - L13 were not covered by tests
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
pre, post = "(", ")"
children = _get_children(obj)
if obj isa Function
println(io, " "^indent, isnothing(name) ? "" : "$name = ", obj)
elseif all(_show_leaflike, children)
if obj isa Function || obj isa Nothing
print(io, " "^indent, isnothing(name) ? "" : "$name = ", obj)
indent != 0 && println(io, ",")
return

Check warning on line 20 in src/layers/display.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/display.jl#L18-L20

Added lines #L18 - L20 were not covered by tests
end

children = _children(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", display_name(obj), pre)
if obj isa Chain
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif obj isa Parallel
if obj.connection !== nothing
_big_show(io, obj.connection, indent + 4)
end
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif obj isa PairwiseFusion
_big_show(io, obj.connection, indent + 4)
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif obj isa BranchLayer
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif obj isa Maxout
for k in Base.keys(obj)
_big_show(io, obj.layers[k], indent + 4, k)
end
elseif children isa NamedTuple
for (k, c) in pairs(children)
_big_show(io, c, indent + 4, k)
end
else
for c in children
_big_show(io, c, indent + 4)
end
println(io, " "^indent, isnothing(name) ? "" : "$name = ", display_name(obj), "(")
for (k, c) in pairs(children)
_big_show(io, c, indent + 4, k)
end
if indent == 0 # i.e. this is the outermost container
print(io, rpad(post, 2))
print(io, rpad(")", 2))
_big_finale(io, obj)
else
println(io, " "^indent, post, ",")
println(io, " "^indent, ")", ",")
end
end
end

_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
_show_leaflike(x::AbstractExplicitLayer) = false

function _get_children(l::AbstractExplicitContainerLayer{names}) where {names}
return NamedTuple{names}(getfield.((l,), names))
end
function _get_children(p::Parallel)
return p.connection === nothing ? p.layers : (p.connection, p.layers...)
_children(x::AbstractExplicitLayer) = Functors.children(x)
function _children(m::AbstractExplicitContainerLayer{layers}) where {layers}
children = Functors.children(m)
length(layers) 2 && return children
field = first(layers)
hasfield(typeof(children), field) || return children
nt = getfield(children, field)
nt isa NamedTuple || (nt = NamedTuple{(field,)}((nt,)))
return merge(Base.structdiff(children, NamedTuple{(field,)}), nt)
end
_get_children(s::SkipConnection) = (s.layers, s.connection)
_get_children(s::WeightNorm) = (s.layer,)
_get_children(nt::NamedTuple) = nt
_get_children(::Any) = ()

function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer)
if !get(io, :compact, false)
_layer_show(io, x)
else
show(io, x)
end
end
_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:

Check warning on line 51 in src/layers/display.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/display.jl#L51

Added line #L51 was not covered by tests
_show_leaflike(x::AbstractExplicitLayer) = false

function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
_str = isnothing(name) ? "" : "$name = "
str = _str * sprint(show, layer; context=io)
print(io, " "^indent, str, indent == 0 ? "" : ",")
_show_parameters_count(io, layer, indent, str)
indent == 0 || println(io)
return
end

function _show_parameters_count(io::IO, layer, indent, str::String)
paramlength = parameterlength(layer)
if paramlength > 0
print(io, " "^max(2, (indent == 0 ? 20 : 39) - indent - length(str)))
Expand All @@ -95,24 +71,16 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
indent == 0 ? " non-trainable" : ""; color=:light_black)
end
end
indent == 0 || println(io)
return
end

function _big_finale(io::IO, m, len=8)
paramlength = parameterlength(m)
nonparamlength = statelength(m)
pars = underscorise(paramlength)
nonparam = underscorise(nonparamlength)
printstyled(io, " "^len, "# Total: "; color=:light_black)
println(io, pars, " parameters,")
println(io, underscorise(parameterlength(m)), " parameters,")
printstyled(io, " "^10, "# plus "; color=:light_black)
print(io, nonparam, " states.")
print(io, underscorise(statelength(m)), " states.")
return
end

# utility functions

function underscorise(n::Integer)
return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')
end
Expand Down
2 changes: 1 addition & 1 deletion src/layers/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ function (d::VariationalHiddenDropout)(x, ps, st::NamedTuple)
return y, merge(st, (; mask, rng, update_mask=Val(false)))
end

function Base.show(io::IO, d::VariationalHiddenDropout)
function Base.show(io::IO, ::MIME"text/plain", d::VariationalHiddenDropout)

Check warning on line 183 in src/layers/dropout.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/dropout.jl#L183

Added line #L183 was not covered by tests
print(io, "VariationalHiddenDropout(", d.p)
d.dims != Colon() && print(io, ", dims=", d.dims)
return print(io, ")")
Expand Down
7 changes: 4 additions & 3 deletions src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ true
bumper = Val(false)
end

function Base.show(io::IO, l::DynamicExpressionsLayer)
function Base.show(io::IO, ::MIME"text/plain", l::DynamicExpressionsLayer)

Check warning on line 93 in src/layers/extension.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/extension.jl#L93

Added line #L93 was not covered by tests
return print(io, "DynamicExpressionNode($(l.expression))")
end

Expand Down Expand Up @@ -183,7 +183,7 @@ Lux.initialparameters(::AbstractRNG, l::FluxLayer) = (p=l.init_parameters(),)

(l::FluxLayer)(x, ps, st) = l.re(ps.p)(x), st

Base.show(io::IO, l::FluxLayer) = print(io, "FluxLayer($(l.layer))")
Base.show(io::IO, ::MIME"text/plain", l::FluxLayer) = print(io, "FluxLayer($(l.layer))")

Check warning on line 186 in src/layers/extension.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/extension.jl#L186

Added line #L186 was not covered by tests

## SimpleChains.jl

Expand Down Expand Up @@ -224,7 +224,8 @@ struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractExplicitLayer
end
end

function Base.show(io::IO, s::SimpleChainsLayer{ToArray}) where {ToArray}
function Base.show(

Check warning on line 227 in src/layers/extension.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/extension.jl#L227

Added line #L227 was not covered by tests
io::IO, ::MIME"text/plain", s::SimpleChainsLayer{ToArray}) where {ToArray}
_print_wrapper_model(io, "SimpleChainsLayer{$ToArray}", s.lux_layer)
end

Expand Down
9 changes: 7 additions & 2 deletions src/layers/normalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,14 @@ parameters: one specifying the magnitude (e.g. `weight_g`) and one specifying th
dims
end

function WeightNorm{which_params}(layer::AbstractExplicitLayer;

Check warning on line 425 in src/layers/normalize.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalize.jl#L425

Added line #L425 was not covered by tests
dims::Union{Tuple, Nothing}=nothing) where {which_params}
return WeightNorm{which_params}(layer, dims)

Check warning on line 427 in src/layers/normalize.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalize.jl#L427

Added line #L427 was not covered by tests
end

function WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol},
dims::Union{Tuple, Nothing}=nothing) where {N}
return WeightNorm{which_params}(layer, dims)
return WeightNorm{which_params}(layer; dims)

Check warning on line 432 in src/layers/normalize.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalize.jl#L432

Added line #L432 was not covered by tests
end

@inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims))
Expand Down Expand Up @@ -552,7 +557,7 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`.
- `affine=false`: Empty `NamedTuple()`
- `affine=true`
+ `bias`: Bias of shape `(shape..., 1)`
+ `scale`: Scale of shape `(shape..., 1)`
"""
Expand Down
Loading

0 comments on commit a6de2af

Please sign in to comment.