Skip to content

Commit

Permalink
Fix #2086 re @autosize (#2087)
Browse files Browse the repository at this point in the history
* fix 2086

* Embedding, but not yet
  • Loading branch information
mcabbott authored Oct 16, 2022
1 parent 4c38c8a commit 74bd04b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,13 @@ function _makelazy(ex::Expr)
n == 0 && return ex
n == 1 && error("@autosize doesn't expect an underscore here: $ex")
n == 2 && return :($LazyLayer($(string(ex)), $(_makefun(ex)), nothing))
n > 2 && return Expr(ex.head, ex.args[1], map(_makelazy, ex.args[2:end])...)
n > 2 && return Expr(ex.head, map(_makelazy, ex.args)...)
end
_makelazy(x) = x

function _underscoredepth(ex::Expr)
# Meta.isexpr(ex, :tuple) && :_ in ex.args && return 10
ex.head in (:call, :kw, :(->), :block) || return 0
ex.head in (:call, :kw, :(->), :block, :parameters) || return 0
ex.args[1] === :(=>) && ex.args[2] === :_ && return 1
m = maximum(_underscoredepth, ex.args)
m == 0 ? 0 : m+1
Expand All @@ -279,6 +279,7 @@ is needed to make `@autosize (2,3,4) Dense(_ => 5)` return
"""
autosizefor(::Type, x::AbstractArray) = size(x, max(1, ndims(x)-1))
autosizefor(::Type{<:Dense}, x::AbstractArray) = size(x, 1)
autosizefor(::Type{<:Embedding}, x::AbstractArray) = size(x, 1)
autosizefor(::Type{<:LayerNorm}, x::AbstractArray) = size(x, 1)

_replaceunderscore(e, s) = e === :_ ? s : e
Expand Down
14 changes: 13 additions & 1 deletion test/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,20 @@ end

m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)


@test_broken begin # outputsize fails on Embedding
m = @autosize (2, 3, 4, 5) Embedding(_ => 10) # goes by first dim, not 2nd-last
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)
end

m = @autosize (9,) Dense(_ => div(_,2))
@test randn(9) |> m |> size == (4,)

m = @autosize (3,) Chain(one = Dense(_ => 4), two = softmax) # needs kw
@test randn(3) |> m |> size == (4,)

m = @autosize (3,) Chain(; one = Dense(_ => 4), two = softmax) # needs parameters
@test randn(3) |> m |> size == (4,)

m = @autosize (3, 45) Maxout(() -> Dense(_ => 6, tanh), 2) # needs ->, block
@test randn(3, 45) |> m |> size == (6, 45)
Expand Down Expand Up @@ -222,6 +230,10 @@ end
Dense(_ => 10),
)
@test randn(Float32, img..., 1, 32) |> m |> size == (10, 32)

# https://github.com/FluxML/Flux.jl/issues/2086
m = @autosize (3, 1) Chain(; c = Dense(_ => 2, sigmoid), b = BatchNorm(_, affine=false))
@test randn(Float32, 3, 32) |> m |> size == (2, 32)
end

@testset "LazyLayer" begin
Expand Down

0 comments on commit 74bd04b

Please sign in to comment.