Fix more cases
avik-pal committed Jun 26, 2024
1 parent a01c7ee commit be0789a
Showing 9 changed files with 33 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/contrib/freeze.jl
@@ -47,7 +47,7 @@ parameter types, users need to implement `Lux._merge(::CustomParamType, ::NamedT
```jldoctest
julia> Lux.Experimental.FrozenLayer(Dense(2 => 2), (:weight,))
- Dense(2 => 2) (with `weight` frozen) # 2 parameters, plus 4 non-trainable
+ FrozenLayer(Dense(2 => 2), (:weight,)) # 2 parameters, plus 4 non-trainable
```
See also [`Lux.Experimental.freeze`](@ref), [`Lux.Experimental.unfreeze`](@ref).
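For readers skimming the diff, a minimal sketch of how the frozen layer shown in this docstring is typically constructed and applied. The `FrozenLayer` constructor and `Dense` come straight from the hunk above; `Lux.setup`, the input array, and the comment about where the frozen values end up are assumptions about the usual Lux workflow, not part of this commit.

```julia
using Lux, Random

rng = Random.default_rng()

# Same layer as in the docstring above: only `weight` is frozen, `bias` stays trainable.
frozen = Lux.Experimental.FrozenLayer(Dense(2 => 2), (:weight,))

ps, st = Lux.setup(rng, frozen)
# The frozen `weight` is expected to be carried as a non-trainable value in `st`,
# leaving only `bias` in `ps`.
x = randn(rng, Float32, 2, 4)
y, st_new = frozen(x, ps, st)
```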
4 changes: 2 additions & 2 deletions src/contrib/map.jl
@@ -14,7 +14,7 @@ julia> c = Parallel(
+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
dense_3=Dense(5 => 1))
Parallel(
- +
+ connection = +,
chain = Chain(
dense_1 = Dense(2 => 3), # 9 parameters
bn = BatchNorm(3, affine=true, track_stats=true), # 6 parameters, plus 7
@@ -90,7 +90,7 @@ julia> c = Parallel(
+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
dense_3=Dense(5 => 1))
Parallel(
- +
+ connection = +,
chain = Chain(
dense_1 = Dense(2 => 3), # 9 parameters
bn = BatchNorm(3, affine=true, track_stats=true), # 6 parameters, plus 7
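Since only the printed form of the `Parallel` model changes here, a hedged sketch of how that model is actually traversed with `Lux.Experimental.layer_map` may help orient the reader. The model definition is copied from the hunk above; the `layer_map(f, layer, ps, st)` call and the `f(layer, ps, st, name)` callback shape are assumptions based on this file's docstring, and `zero_dense_3` is a made-up helper name.

```julia
using Lux, Random

c = Parallel(
    +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
    dense_3=Dense(5 => 1))

rng = Random.default_rng()
ps, st = Lux.setup(rng, c)

# Zero the weights of the layer whose flattened name contains "dense_3",
# leaving every other layer untouched.
function zero_dense_3(layer, ps, st, name)
    occursin("dense_3", name) || return layer, ps, st
    return layer, (; weight=zero(ps.weight), bias=ps.bias), st
end

c_new, ps_new, st_new = Lux.Experimental.layer_map(zero_dense_3, c, ps, st)
```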
5 changes: 1 addition & 4 deletions src/helpers/compact.jl
@@ -194,10 +194,7 @@ printout, which gives a verbatim representation of the code used to construct th
julia> model = @compact(w=rand(3), name="Linear(3 => 1)") do x
@return sum(w .* x)
end
- Linear(3 => 1)() # 3 parameters
- julia> println(model)
- Linear(3 => 1)()
+ Linear(3 => 1) # 3 parameters
```
This can be useful when using `@compact` to hierarchically construct complex models to be
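A short usage sketch of the `name` keyword whose printout this hunk updates. The `@compact`/`@return` calls are taken from the docstring itself; the `Lux.setup` call and the sample input are assumed boilerplate.

```julia
using Lux, Random

model = @compact(w=rand(3), name="Linear(3 => 1)") do x
    @return sum(w .* x)
end

ps, st = Lux.setup(Random.default_rng(), model)
y, _ = model([1.0, 2.0, 3.0], ps, st)
# After this commit the model is expected to display as
# `Linear(3 => 1)         # 3 parameters`, without the trailing `()`.
```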
4 changes: 2 additions & 2 deletions src/layers/basic.jl
@@ -69,7 +69,7 @@ Reverse the specified dimension `dims` of the passed array
```jldoctest
julia> model = ReverseSequence()
- ReverseSequence()
+ ReverseSequence{Nothing}(nothing)
julia> rng = Random.default_rng();
Random.seed!(rng, 0);
@@ -127,7 +127,7 @@ Flattens the passed array into a matrix.
```jldoctest
julia> model = FlattenLayer()
- FlattenLayer()
+ FlattenLayer{Nothing}(nothing)
julia> rng = Random.default_rng();
Random.seed!(rng, 0);
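Only the printed representations change in this file, so as a reminder of what the two layers do, here is a small sketch; the expected outputs in the comments follow the existing docstring examples rather than anything introduced by this commit.

```julia
using Lux, Random

rng = Random.default_rng()

rev = ReverseSequence()
ps, st = Lux.setup(rng, rev)
first(rev([1.0, 2.0, 3.0], ps, st))    # expected: [3.0, 2.0, 1.0]

flat = FlattenLayer()
ps, st = Lux.setup(rng, flat)
x = randn(rng, Float32, 4, 4, 3, 2)
size(first(flat(x, ps, st)))           # expected: (48, 2), the batch dimension is kept
```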
21 changes: 0 additions & 21 deletions src/layers/containers.jl
@@ -51,13 +51,6 @@ See [`Parallel`](@ref) for a more general implementation.
name
end

- function Functors.functor(::Type{<:SkipConnection}, sc)
- recon = let name = sc.name
- nt -> SkipConnection(nt.layers, nt.connection, name)
- end
- return (; sc.layers, sc.connection), recon
- end

function SkipConnection(layers, connection; name::NAME_TYPE=nothing)
return SkipConnection(; layers, connection, name)
end
@@ -162,13 +155,6 @@ julia> size.(first(model((x1, x2), ps, st)))
name
end

- function Functors.functor(::Type{<:Parallel}, p)
- recon = let name = p.name
- nt -> Parallel(nt.connection, nt.layers, name)
- end
- return (; p.connection, p.layers), recon
- end

function Parallel(connection, layers...; name::NAME_TYPE=nothing)
return Parallel(connection, __named_tuple_layers(layers...), name)
end
@@ -367,13 +353,6 @@ end
name
end

- function Functors.functor(::Type{<:PairwiseFusion}, p)
- recon = let name = p.name
- nt -> PairwiseFusion(nt.connection, nt.layers, name)
- end
- return (; p.connection, p.layers), recon
- end

function PairwiseFusion(connection, layers...; name::NAME_TYPE=nothing)
return PairwiseFusion(connection, __named_tuple_layers(layers...), name)
end
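The three deleted methods all did the same thing: hide the cosmetic `name` field from `Functors.children` and re-attach it in the reconstructor. A toy illustration of the default behaviour they were overriding; `ToySkipConnection` is a made-up struct that only mirrors the field layout, not the actual Lux definition.

```julia
using Functors

# Same field layout as SkipConnection: layers, connection, plus a `name`.
struct ToySkipConnection{L, C, N}
    layers::L
    connection::C
    name::N
end
@functor ToySkipConnection

sc = ToySkipConnection((dense=identity,), +, nothing)

# Without a custom `Functors.functor`, every field is exposed as a child,
# `name` included; the deleted overloads had trimmed this to `(; layers, connection)`.
Functors.children(sc)
# => (layers = (dense = identity,), connection = +, name = nothing)
```

With the overloads gone, printing needs its own way to pick which children to show, which is what the `_printable_children` methods added in `src/layers/display.jl` below provide.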
21 changes: 15 additions & 6 deletions src/layers/display.jl
@@ -20,7 +20,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
return
end

- children = _children(obj)
+ children = _printable_children(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
@@ -37,8 +37,8 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end

- _children(x::AbstractExplicitLayer) = Functors.children(x)
- function _children(m::AbstractExplicitContainerLayer{layers}) where {layers}
+ _printable_children(x) = Functors.children(x)
+ function _printable_children(m::AbstractExplicitContainerLayer{layers}) where {layers}
children = Functors.children(m)
length(layers) ≥ 2 && return children
field = first(layers)
@@ -47,6 +47,12 @@ function _children(m::AbstractExplicitContainerLayer{layers}) where {layers}
nt isa NamedTuple || (nt = NamedTuple{(field,)}((nt,)))
return merge(Base.structdiff(children, NamedTuple{(field,)}), nt)
end
+ function _printable_children(l::Union{PairwiseFusion, Parallel})
+ children = Functors.children(l)
+ l.connection === nothing && return children.layers
+ return merge((; l.connection), children.layers)
+ end
+ _printable_children(l::SkipConnection) = (; l.connection, l.layers)

_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
_show_leaflike(x::AbstractExplicitLayer) = false
@@ -89,11 +95,14 @@ function _print_wrapper_model(io::IO, desc::String, model::AbstractExplicitLayer
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
print(io, desc, "(\n")
_big_show(io, model, 4)
- elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
- print(io, desc, "(")
+ print(io, ") ")
+ _big_finale(io, model)
+ return
+ end
+ print(io, desc, "(")
+ if !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
_layer_show(io, model)
else
- print(io, desc, "(")
show(io, model)
end
print(io, ")")
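The `Parallel`/`PairwiseFusion` method added above hinges on a small NamedTuple idiom; here is a standalone illustration in plain Julia (no Lux types involved) of why the connection ends up as the first printed entry.

```julia
connection = +
layers = (layer_1 = identity, layer_2 = abs)

# `(; connection)` is shorthand for `(connection = +,)`; merging it in front of the
# layer children is what makes the printed tree start with `connection = +,`.
printable = merge((; connection), layers)
# => (connection = +, layer_1 = identity, layer_2 = abs)
```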
11 changes: 6 additions & 5 deletions src/layers/extension.jl
@@ -51,10 +51,10 @@ julia> expr_2 = x2 - x1 * x2 + 2.5 - 1.0 * x1
julia> layer = DynamicExpressionsLayer(operators, expr_1, expr_2)
DynamicExpressionsLayer(
layer_1 = Parallel(
- layer_1 = DynamicExpressionNode(x1 * cos(x2 - 3.2)), # 1 parameters
- layer_2 = DynamicExpressionNode(((x2 - (x1 * x2)) + 2.5) - (1.0 * x1)), # 2 parameters
+ layer_1 = DynamicExpressionsLayer(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{typeof(cos)}}((+, -, *), (cos,)), x1 * cos(x2 - 3.2); turbo=Val{false}(), bumper=Val{false}()), # 1 parameters
+ layer_2 = DynamicExpressionsLayer(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{typeof(cos)}}((+, -, *), (cos,)), ((x2 - (x1 * x2)) + 2.5) - (1.0 * x1); turbo=Val{false}(), bumper=Val{false}()), # 2 parameters
),
- layer_2 = WrappedFunction(__stack1),
+ layer_2 = WrappedFunction{:direct_call}(__stack1),
) # Total: 3 parameters,
# plus 0 states.
@@ -90,8 +90,9 @@ true
bumper = Val(false)
end

- function Base.show(io::IO, ::MIME"text/plain", l::DynamicExpressionsLayer)
- return print(io, "DynamicExpressionNode($(l.expression))")
+ function Base.show(io::IO, l::DynamicExpressionsLayer)
+ print(io,
+ "DynamicExpressionsLayer($(l.operator_enum), $(l.expression); turbo=$(l.turbo), bumper=$(l.bumper))")
end

function initialparameters(::AbstractRNG, layer::DynamicExpressionsLayer)
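The docstring hunk above refers to `operators`, `expr_1`, and `expr_2` defined outside the shown context. A hedged reconstruction of that setup, assuming the standard DynamicExpressions.jl `OperatorEnum`/`Node` constructors (those constructors, the input array, and the batch size are assumptions; `DynamicExpressionsLayer` itself appears in the hunk).

```julia
using Lux, DynamicExpressions, Random

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
x1 = Node(; feature=1)
x2 = Node(; feature=2)
expr_1 = x1 * cos(x2 - 3.2)
expr_2 = x2 - x1 * x2 + 2.5 - 1.0 * x1

layer = DynamicExpressionsLayer(operators, expr_1, expr_2)

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)
x = randn(rng, Float32, 2, 8)   # 2 features, batch of 8
y, _ = layer(x, ps, st)
```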
5 changes: 3 additions & 2 deletions src/transform/simplechains.jl
@@ -44,12 +44,13 @@ SimpleChainsLayer{false}(
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
- layer_5 = FlattenLayer(),
+ layer_5 = FlattenLayer{Int64}(3),
layer_6 = Dense(256 => 128, relu), # 32_896 parameters
layer_7 = Dense(128 => 84, relu), # 10_836 parameters
layer_8 = Dense(84 => 10), # 850 parameters
),
- ) # 47_154 parameters
+ ) # Total: 47_154 parameters,
+ # plus 0 states.
julia> ps, st = Lux.setup(Random.default_rng(), simple_chains_model);
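For orientation, the `simple_chains_model` being printed here is presumably produced along these lines. The `ToSimpleChainsAdaptor` call, the `static` input sizes, and the first convolution layer (not visible in the hunk) are assumptions based on the Lux/SimpleChains integration docs, so treat this as a sketch rather than the exact setup.

```julia
using Lux, Random
import SimpleChains: static

# Layers 2 through 8 match the printed output above; layer_1 is assumed to be a 1 => 6 conv.
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10))

adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)))
simple_chains_model = adaptor(lux_model)

ps, st = Lux.setup(Random.default_rng(), simple_chains_model)
x = randn(Float32, 28, 28, 1, 1)
y, _ = simple_chains_model(x, ps, st)
```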
6 changes: 3 additions & 3 deletions test/helpers/compact_tests.jl
@@ -149,7 +149,7 @@
tmp = sum(w(x))
@return tmp + y
end
- expected_string = "Linear(...)() # 1_056 parameters"
+ expected_string = "Linear(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)
end

@@ -242,7 +242,7 @@
@return w2 * w1(x)
end
expected_string = """@compact(
- w1 = Model(32)(), # 1_024 parameters
+ w1 = Model(32), # 1_024 parameters
w2 = 32×32 Matrix{Float64},
w3 = 32-element Vector{Float64},
) do x
@@ -258,7 +258,7 @@
end, w2=randn(32, 32), w3=randn(32), name="Model(32)") do x
@return w2 * w1(x)
end
- expected_string = """Model(32)() # 2_080 parameters"""
+ expected_string = """Model(32) # 2_080 parameters"""
@test similar_strings(get_model_string(model), expected_string)
end

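The updated expectations simply drop the trailing `()` from the printed name. A sketch of the nested pattern these tests exercise, assembled from the same `@compact` calls that appear in the test body (splitting the inner model into its own variable is just for readability and is assumed to behave the same as the inline version).

```julia
using Lux, Random

inner = @compact(w=randn(32, 32), name="Model(32)") do x
    @return w * x
end

model = @compact(w1=inner, w2=randn(32, 32), w3=randn(32), name="Model(32)") do x
    @return w2 * w1(x)
end
# Displaying `model` at the REPL should now end with
# `Model(32)         # 2_080 parameters` rather than `Model(32)() # ...`.
```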
