diff --git a/src/contrib/freeze.jl b/src/contrib/freeze.jl
index b715393dd7..cb24bb8365 100644
--- a/src/contrib/freeze.jl
+++ b/src/contrib/freeze.jl
@@ -47,7 +47,7 @@ parameter types, users need to implement `Lux._merge(::CustomParamType, ::NamedT
 ```jldoctest
 julia> Lux.Experimental.FrozenLayer(Dense(2 => 2), (:weight,))
-Dense(2 => 2) (with `weight` frozen)  # 2 parameters, plus 4 non-trainable
+FrozenLayer(Dense(2 => 2), (:weight,))  # 2 parameters, plus 4 non-trainable
 ```

 See also [`Lux.Experimental.freeze`](@ref), [`Lux.Experimental.unfreeze`](@ref).
diff --git a/src/contrib/map.jl b/src/contrib/map.jl
index 97c844e5c5..95df68e7df 100644
--- a/src/contrib/map.jl
+++ b/src/contrib/map.jl
@@ -14,7 +14,7 @@ julia> c = Parallel(
            +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
            dense_3=Dense(5 => 1))
 Parallel(
-    +
+    connection = +,
     chain = Chain(
         dense_1 = Dense(2 => 3),                    # 9 parameters
         bn = BatchNorm(3, affine=true, track_stats=true),  # 6 parameters, plus 7
diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl
index 6eb8b84b1d..518f6b5064 100644
--- a/src/helpers/compact.jl
+++ b/src/helpers/compact.jl
@@ -194,10 +194,7 @@ printout, which gives a verbatim representation of the code used to construct th
 julia> model = @compact(w=rand(3), name="Linear(3 => 1)") do x
            @return sum(w .* x)
        end
-Linear(3 => 1)()    # 3 parameters
-
-julia> println(model)
-Linear(3 => 1)()
+Linear(3 => 1)      # 3 parameters
 ```

 This can be useful when using `@compact` to hierarchically construct complex models to be
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 3bba29e996..9abfeddada 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -69,7 +69,7 @@ Reverse the specified dimension `dims` of the passed array
 ```jldoctest
 julia> model = ReverseSequence()
-ReverseSequence()
+ReverseSequence{Nothing}(nothing)

 julia> rng = Random.default_rng();
        Random.seed!(rng, 0);
@@ -127,7 +127,7 @@ Flattens the passed array into a matrix.
 ```jldoctest
 julia> model = FlattenLayer()
-FlattenLayer()
+FlattenLayer{Nothing}(nothing)

 julia> rng = Random.default_rng();
        Random.seed!(rng, 0);
diff --git a/src/layers/containers.jl b/src/layers/containers.jl
index 7a31584b98..d4ae6dc16e 100644
--- a/src/layers/containers.jl
+++ b/src/layers/containers.jl
@@ -51,13 +51,6 @@ See [`Parallel`](@ref) for a more general implementation.
     name
 end

-function Functors.functor(::Type{<:SkipConnection}, sc)
-    recon = let name = sc.name
-        nt -> SkipConnection(nt.layers, nt.connection, name)
-    end
-    return (; sc.layers, sc.connection), recon
-end
-
 function SkipConnection(layers, connection; name::NAME_TYPE=nothing)
     return SkipConnection(; layers, connection, name)
 end
@@ -162,13 +155,6 @@ julia> size.(first(model((x1, x2), ps, st)))
     name
 end

-function Functors.functor(::Type{<:Parallel}, p)
-    recon = let name = p.name
-        nt -> Parallel(nt.connection, nt.layers, name)
-    end
-    return (; p.connection, p.layers), recon
-end
-
 function Parallel(connection, layers...; name::NAME_TYPE=nothing)
     return Parallel(connection, __named_tuple_layers(layers...), name)
 end
@@ -367,13 +353,6 @@ end
     name
 end

-function Functors.functor(::Type{<:PairwiseFusion}, p)
-    recon = let name = p.name
-        nt -> PairwiseFusion(nt.connection, nt.layers, name)
-    end
-    return (; p.connection, p.layers), recon
-end
-
 function PairwiseFusion(connection, layers...; name::NAME_TYPE=nothing)
     return PairwiseFusion(connection, __named_tuple_layers(layers...), name)
 end
diff --git a/src/layers/display.jl b/src/layers/display.jl
index df4bd71f34..c3852df857 100644
--- a/src/layers/display.jl
+++ b/src/layers/display.jl
@@ -20,7 +20,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
         return
     end

-    children = _children(obj)
+    children = _printable_children(obj)
     if all(_show_leaflike, children)
         _layer_show(io, obj, indent, name)
     else
@@ -37,8 +37,8 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
     end
 end

-_children(x::AbstractExplicitLayer) = Functors.children(x)
-function _children(m::AbstractExplicitContainerLayer{layers}) where {layers}
+_printable_children(x) = Functors.children(x)
+function _printable_children(m::AbstractExplicitContainerLayer{layers}) where {layers}
     children = Functors.children(m)
     length(layers) ≥ 2 && return children
     field = first(layers)
@@ -47,6 +47,12 @@ function _children(m::AbstractExplicitContainerLayer{layers}) where {layers}
     nt isa NamedTuple || (nt = NamedTuple{(field,)}((nt,)))
     return merge(Base.structdiff(children, NamedTuple{(field,)}), nt)
 end
+function _printable_children(l::Union{PairwiseFusion, Parallel})
+    children = Functors.children(l)
+    l.connection === nothing && return children.layers
+    return merge((; l.connection), children.layers)
+end
+_printable_children(l::SkipConnection) = (; l.connection, l.layers)

 _show_leaflike(x) = Functors.isleaf(x)  # mostly follow Functors, except for:
 _show_leaflike(x::AbstractExplicitLayer) = false
@@ -89,11 +95,14 @@ function _print_wrapper_model(io::IO, desc::String, model::AbstractExplicitLayer
     if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
         print(io, desc, "(\n")
         _big_show(io, model, 4)
-    elseif !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
-        print(io, desc, "(")
+        print(io, ") ")
+        _big_finale(io, model)
+        return
+    end
+    print(io, desc, "(")
+    if !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
         _layer_show(io, model)
     else
-        print(io, desc, "(")
         show(io, model)
     end
     print(io, ")")
diff --git a/src/layers/extension.jl b/src/layers/extension.jl
index 76c59241d6..34dffa08ab 100644
--- a/src/layers/extension.jl
+++ b/src/layers/extension.jl
@@ -51,10 +51,11 @@ julia> expr_2 = x2 - x1 * x2 + 2.5 - 1.0 * x1

 julia> layer = DynamicExpressionsLayer(operators, expr_1, expr_2)
 DynamicExpressionsLayer(
     layer_1 = Parallel(
-        layer_1 = DynamicExpressionNode(x1 * cos(x2 - 3.2)),  # 1 parameters
-        layer_2 = DynamicExpressionNode(((x2 - (x1 * x2)) + 2.5) - (1.0 * x1)),  # 2 parameters
+        connection = nothing,
+        layer_1 = DynamicExpressionsLayer(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{typeof(cos)}}((+, -, *), (cos,)), x1 * cos(x2 - 3.2); turbo=Val{false}(), bumper=Val{false}()),  # 1 parameters
+        layer_2 = DynamicExpressionsLayer(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{typeof(cos)}}((+, -, *), (cos,)), ((x2 - (x1 * x2)) + 2.5) - (1.0 * x1); turbo=Val{false}(), bumper=Val{false}()),  # 2 parameters
     ),
-    layer_2 = WrappedFunction(__stack1),
+    layer_2 = WrappedFunction{:direct_call}(__stack1),
 )         # Total: 3 parameters,
           #        plus 0 states.
@@ -90,8 +91,9 @@ true
     bumper = Val(false)
 end

-function Base.show(io::IO, ::MIME"text/plain", l::DynamicExpressionsLayer)
-    return print(io, "DynamicExpressionNode($(l.expression))")
+function Base.show(io::IO, l::DynamicExpressionsLayer)
+    print(io,
+        "DynamicExpressionsLayer($(l.operator_enum), $(l.expression); turbo=$(l.turbo), bumper=$(l.bumper))")
 end

 function initialparameters(::AbstractRNG, layer::DynamicExpressionsLayer)
diff --git a/src/transform/simplechains.jl b/src/transform/simplechains.jl
index fb7df87e80..a4b3b473f5 100644
--- a/src/transform/simplechains.jl
+++ b/src/transform/simplechains.jl
@@ -44,12 +44,13 @@ SimpleChainsLayer{false}(
         layer_2 = MaxPool((2, 2)),
         layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
         layer_4 = MaxPool((2, 2)),
-        layer_5 = FlattenLayer(),
+        layer_5 = FlattenLayer{Int64}(3),
         layer_6 = Dense(256 => 128, relu),  # 32_896 parameters
         layer_7 = Dense(128 => 84, relu),  # 10_836 parameters
         layer_8 = Dense(84 => 10),  # 850 parameters
     ),
-) # 47_154 parameters
+) # Total: 47_154 parameters,
+  #        plus 0 states.

 julia> ps, st = Lux.setup(Random.default_rng(), simple_chains_model);
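As a quick illustration (not part of the diff above; it assumes the updated `_printable_children` display code from `src/layers/display.jl` is in effect and that `Lux` is loaded), building the `Parallel` container from the `src/contrib/map.jl` doctest should now list the combining function by name:

```julia
using Lux

# Same container as in the src/contrib/map.jl doctest above.
c = Parallel(+;
    chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
    dense_3=Dense(5 => 1))

# With the new `_printable_children` methods, the structured show prints the
# combining function as `connection = +,` instead of a bare `+` line.
display(c)
```

The `connection = nothing,` line in the updated `DynamicExpressionsLayer` doctest comes from the same code path, since `_printable_children` only drops the `connection` entry when it is `nothing` and the layer names are shown directly.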