From 85fa7e8bd9725197347b0b6f447ea7395646a9ef Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 11 Mar 2022 12:26:40 -0600 Subject: [PATCH 01/15] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 600053ea8..423c94e2f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Metalhead" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.7.0-DEV" +version = "0.7.0" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" From d034260a019ad7829411295ffbb789603fcffe32 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 11 Mar 2022 12:35:46 -0600 Subject: [PATCH 02/15] Allow custom activation in MobileNet --- src/convnets/mobilenet.jl | 84 +++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index d201e39bb..de2bc9d4f 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -14,6 +14,7 @@ Create a MobileNetv2 model. - `c`: The number of output feature maps - `n`: The number of times a block is repeated - `s`: The stride of the convolutional kernel + - `a`: The activation function used in the bottleneck layer - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: The number of output classes """ @@ -24,10 +25,10 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000) append!(layers, conv_bn((3, 3), 3, inplanes, stride = 2)) # building inverted residual blocks - for (t, c, n, s) in configs + for (t, c, n, s, a) in configs outplanes = _round_channels(c * width_mult, width_mult == 0.1 ? 4 : 8) for i in 1:n - push!(layers, invertedresidual(3, inplanes, inplanes * t, outplanes, relu6; + push!(layers, invertedresidual(3, inplanes, inplanes * t, outplanes, a; stride = i == 1 ? s : 1)) inplanes = outplanes end @@ -44,14 +45,14 @@ end # Layer configurations for MobileNetv2 const mobilenetv2_configs = [ - # t, c, n, s - (1, 16, 1, 1), - (6, 24, 2, 2), - (6, 32, 3, 2), - (6, 64, 4, 2), - (6, 96, 3, 1), - (6, 160, 3, 2), - (6, 320, 1, 1) +# t, c, n, s, a + (1, 16, 1, 1, relu6), + (6, 24, 2, 2, relu6), + (6, 32, 3, 2, relu6), + (6, 64, 4, 2, relu6), + (6, 96, 3, 1, relu6), + (6, 160, 3, 2, relu6), + (6, 320, 1, 1, relu6) ] # Model definition for MobileNetv2 @@ -106,8 +107,8 @@ Create a MobileNetv3 model. - `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer - `t::Int` - The number of output feature maps for a given block - `r::Int` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers - - `use_hs::Bool` - Whether to use Hard-Swish activation function - `s::Int` - The stride of the convolutional kernel + - `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes """ @@ -118,12 +119,11 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000) append!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2)) explanes = 0 # building inverted residual blocks - for (k, t, c, r, use_hs, s) in configs + for (k, t, c, r, a, s) in configs # inverted residual layers outplanes = _round_channels(c * width_mult, 8) explanes = _round_channels(inplanes * t, 8) - activation = use_hs ? 
hardswish : relu - push!(layers, invertedresidual(k, inplanes, explanes, outplanes, activation; + push!(layers, invertedresidual(k, inplanes, explanes, outplanes, a; stride = s, reduction = r)) inplanes = outplanes end @@ -143,36 +143,36 @@ end # Configurations for small and large mode for MobileNetv3 mobilenetv3_configs = Dict( :small => [ - # k, t, c, SE, HS, s - (3, 1, 16, 4, false, 2), - (3, 4.5, 24, nothing, false, 2), - (3, 3.67, 24, nothing, false, 1), - (5, 4, 40, 4, true, 2), - (5, 6, 40, 4, true, 1), - (5, 6, 40, 4, true, 1), - (5, 3, 48, 4, true, 1), - (5, 3, 48, 4, true, 1), - (5, 6, 96, 4, true, 2), - (5, 6, 96, 4, true, 1), - (5, 6, 96, 4, true, 1), + # k, t, c, SE, a, s + (3, 1, 16, 4, relu, 2), + (3, 4.5, 24, nothing, relu, 2), + (3, 3.67, 24, nothing, relu, 1), + (5, 4, 40, 4, hardswish, 2), + (5, 6, 40, 4, hardswish, 1), + (5, 6, 40, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 2), + (5, 6, 96, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 1), ], :large => [ - # k, t, c, SE, HS, s - (3, 1, 16, nothing, false, 1), - (3, 4, 24, nothing, false, 2), - (3, 3, 24, nothing, false, 1), - (5, 3, 40, 4, false, 2), - (5, 3, 40, 4, false, 1), - (5, 3, 40, 4, false, 1), - (3, 6, 80, nothing, true, 2), - (3, 2.5, 80, nothing, true, 1), - (3, 2.3, 80, nothing, true, 1), - (3, 2.3, 80, nothing, true, 1), - (3, 6, 112, 4, true, 1), - (3, 6, 112, 4, true, 1), - (5, 6, 160, 4, true, 2), - (5, 6, 160, 4, true, 1), - (5, 6, 160, 4, true, 1) + # k, t, c, SE, a, s + (3, 1, 16, nothing, relu, 1), + (3, 4, 24, nothing, relu, 2), + (3, 3, 24, nothing, relu, 1), + (5, 3, 40, 4, relu, 2), + (5, 3, 40, 4, relu, 1), + (5, 3, 40, 4, relu, 1), + (3, 6, 80, nothing, hardswish, 2), + (3, 2.5, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 2), + (5, 6, 160, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 1) ] ) From 02ecea7bca48fc2c5c4db6a1d9075fc9c0c4c72c Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 11 Mar 2022 12:36:26 -0600 Subject: [PATCH 03/15] Disable ViTs --- src/Metalhead.jl | 3 ++- test/runtests.jl | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 64fc85c56..b23dd49a2 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -34,7 +34,8 @@ include("convnets/convmixer.jl") include("other/mlpmixer.jl") # ViT-based models -include("vit-based/vit.jl") +# TODO: add this back in when we are ready +# include("vit-based/vit.jl") include("pretrain.jl") diff --git a/test/runtests.jl b/test/runtests.jl index 66ddbbaa8..e8f6e439d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,6 @@ end GC.gc() # ViT tests -@testset verbose = true "ViTs" begin - include("vit-based.jl") -end +# @testset verbose = true "ViTs" begin +# include("vit-based.jl") +# end From 7aa92ac403b6da1f3017e2bb9504dd7a70286b6a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 27 Mar 2022 22:11:57 -0500 Subject: [PATCH 04/15] Fix docs issues on CI --- .github/workflows/Publish.yml | 2 +- .github/workflows/PublishStable.yml | 32 +++++++++++++++++++++++++++++ .gitignore | 3 +++ Project.toml | 2 ++ docs/make.jl | 16 ++++++--------- docs/serve.jl | 7 +++---- src/Metalhead.jl | 3 +-- 7 files changed, 48 insertions(+), 17 deletions(-) create mode 100644 .github/workflows/PublishStable.yml diff --git a/.github/workflows/Publish.yml 
b/.github/workflows/Publish.yml index 48c28812b..5d59ed275 100644 --- a/.github/workflows/Publish.yml +++ b/.github/workflows/Publish.yml @@ -25,7 +25,7 @@ jobs: run: | cd gh-pages julia --color=yes --project=../docs -e "using Pkg; Pkg.instantiate()" - julia --color=yes --project=../docs ../docs/make.jl + julia --color=yes --project=../docs -e 'include("../docs/make.jl"); build_and_deploy("dev")' git config user.name github-actions git config user.email github-actions@github.com git add . diff --git a/.github/workflows/PublishStable.yml b/.github/workflows/PublishStable.yml new file mode 100644 index 000000000..d18b27482 --- /dev/null +++ b/.github/workflows/PublishStable.yml @@ -0,0 +1,32 @@ +name: PublishStable + +on: + push: + tags: + - '*' + +jobs: + publish: + name: "Publish Documentation (stable)" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/checkout@v2 + with: + ref: gh-pages + path: gh-pages + - uses: julia-actions/setup-julia@latest + with: + version: '1.6' + - name: Install dependencies + run: julia --color=yes --project=. -e 'using Pkg; Pkg.instantiate()' + - name: Build and deploy + run: | + cd gh-pages + julia --color=yes --project=../docs -e "using Pkg; Pkg.instantiate()" + julia --color=yes --project=../docs -e 'include("../docs/make.jl"); build_and_deploy("stable")' + git config user.name github-actions + git config user.email github-actions@github.com + git add . + git commit -m "(stable) Built by Publish.jl." + git push diff --git a/.gitignore b/.gitignore index b04f8b796..4b473d3fd 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,6 @@ # manifests docs/Manifest.toml Manifest.toml + +# docs theme +_flux-theme diff --git a/Project.toml b/Project.toml index 423c94e2f..3176da8d2 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [publish] title = "Metalhead.jl" +theme = "_flux-theme" +ignore = ["^(gh-pages|juliamnt|julia.dmg)$"] [targets] test = ["Test"] diff --git a/docs/make.jl b/docs/make.jl index f2d86d295..db03f1d76 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,21 +1,17 @@ using Pkg Pkg.develop(path = "..") -# this is needed since Publish v0.9 breaks our theming hack -Pkg.pin(name = "Publish", version = "0.8") using Publish -using Pkg.Artifacts +using Artifacts, LazyArtifacts using Metalhead # override default theme -Publish.Themes.default() = artifact"flux-theme" +cp(artifact"flux-theme", "../_flux-theme"; force = true) p = Publish.Project(Metalhead) -# needed to prevent error when overwriting -rm("dev", recursive = true, force = true) -rm(p.env["version"], recursive = true, force = true) - -# build documentation -deploy(Metalhead; root = "/Metalhead.jl", force = true, label = "dev") +function build_and_deploy(label) + rm(label; recursive = true, force = true) + deploy(Metalhead; root = "/Metalhead.jl", label = label) +end diff --git a/docs/serve.jl b/docs/serve.jl index 75ec54a74..763e77e93 100644 --- a/docs/serve.jl +++ b/docs/serve.jl @@ -1,16 +1,15 @@ using Pkg Pkg.develop(path = "..") -# this is needed since Publish v0.9 breaks our theming hack -Pkg.pin(name = "Publish", version = "0.8") +using Revise using Publish -using Pkg.Artifacts +using Artifacts, LazyArtifacts using Metalhead # override default theme -Publish.Themes.default() = artifact"flux-theme" +cp(artifact"flux-theme", "../_flux-theme"; force = true) p = Publish.Project(Metalhead) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index b23dd49a2..64fc85c56 100644 --- a/src/Metalhead.jl +++ 
b/src/Metalhead.jl @@ -34,8 +34,7 @@ include("convnets/convmixer.jl") include("other/mlpmixer.jl") # ViT-based models -# TODO: add this back in when we are ready -# include("vit-based/vit.jl") +include("vit-based/vit.jl") include("pretrain.jl") From 6d0d71c8d4f612c943a5d8c29df671f534a4273b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 31 Mar 2022 17:02:57 -0500 Subject: [PATCH 05/15] Delete datasets folder --- datasets/README.md | 27 --------------------------- 1 file changed, 27 deletions(-) delete mode 100644 datasets/README.md diff --git a/datasets/README.md b/datasets/README.md deleted file mode 100644 index b92f0900a..000000000 --- a/datasets/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Image datasets - -Drop image datasets in this folder for auto-detection support -using `Metalhead.datasets()`. That function will attempt to -auto-detect any standard datasets present in this folder and -make them easily available. This file, lists, for each data -set, instructions for how to obtain them in a format that -is suitable for auto-detection. - -# ImageNet - -As of 2020, the ImageNet dataset is not openly available. -ImageNet website http://image-net.org makes the images available -only for non-commercial research and/or educational purposes. - -To get the dataset go to http://image-net.org/download-images -and sign up for an account. Then you may submit a request. After your request -gets approved then you will be able to download the dataset. - -After downloading, unpack the file `imagenet_object_localization.tar.gz` and move the resulting `ILSVRC` folder here. - -For more information visit http://image-net.org/download-faq - -# CIFAR-10 - -We support the CIFAR-10 dataset in binary format from https://www.cs.toronto.edu/~kriz/cifar.html. Simply download the archive from https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz and move the resulting `cifar-10-batches-bin` folder here -for autodetection to support it. 
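For local iteration, the `build_and_deploy` helper introduced in the docs changes above can be driven the same way the workflows do. A minimal sketch, assuming the CI layout (run from a sibling `gh-pages` checkout with the `docs` environment instantiated first):

```julia
# Mirrors the CI step from Publish.yml:
#   julia --color=yes --project=../docs -e 'include("../docs/make.jl"); build_and_deploy("dev")'
include("../docs/make.jl")   # defines `build_and_deploy(label)`
build_and_deploy("dev")      # the PublishStable workflow passes "stable" on tags
```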
From d51e3a4914ea80c2c54e521ec8a5cd2d88108c39 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 31 Mar 2022 20:44:18 -0500 Subject: [PATCH 06/15] Enable some ViT tests --- test/runtests.jl | 6 +++--- test/vit-based.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index e8f6e439d..66ddbbaa8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,6 @@ end GC.gc() # ViT tests -# @testset verbose = true "ViTs" begin -# include("vit-based.jl") -# end +@testset verbose = true "ViTs" begin + include("vit-based.jl") +end diff --git a/test/vit-based.jl b/test/vit-based.jl index 592268be6..1ab105954 100644 --- a/test/vit-based.jl +++ b/test/vit-based.jl @@ -2,7 +2,7 @@ using Metalhead, Test using Flux @testset "ViT" begin - for mode in [:tiny, :small, :base, :large, :huge, :giant, :gigantic] + for mode in [:tiny, :small, :base, :large] #,:huge, :giant, :gigantic] m = ViT(mode) @test size(m(rand(Float32, 256, 256, 3, 2))) == (1000, 2) @test_skip gradtest(m, rand(Float32, 256, 256, 3, 2)) From c1f7dd86fc95b54d0d892a4c642fbc8c3a69d386 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Apr 2022 12:19:51 -0500 Subject: [PATCH 07/15] Add MobileNetv1 --- src/convnets/mobilenet.jl | 102 ++++++++++++++++++++++++++++++++++++++ src/layers/conv.jl | 43 +++++++++++++++- 2 files changed, 144 insertions(+), 1 deletion(-) diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index de2bc9d4f..37530f660 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -1,3 +1,105 @@ +# MobileNetv1 + +""" + mobilenetv1(imsize, width_mult, config; + activation = relu, + inchannels = 3, + nclasses = 1000, + fcsize = 1024) + +Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). 
+
+# Arguments
+- `imsize`: A 2-tuple indicating the input spatial dimensions
+- `width_mult`: Controls the number of output feature maps in each block
+  (with 1.0 being the default in the paper)
+- `configs`: A "list of tuples" configuration for each layer that details:
+  - `dw`: Set true to use a depthwise separable convolution or false for regular convolution
+  - `o`: The number of output feature maps
+  - `s`: The stride of the convolutional kernel
+  - `r`: The number of times this configuration block is repeated
+- `fcsize`: The intermediate fully-connected size between the convolution and final layers
+- `nclasses`: The number of output classes
+"""
+function mobilenetv1(imsize, width_mult, config;
+                     activation = relu,
+                     inchannels = 3,
+                     nclasses = 1000,
+                     fcsize = 1024)
+    layers = []
+    inch = inchannels
+    for (dw, outch, stride, repeats) in config
+        for _ in 1:repeats
+            outch = outch * width_mult
+            layer = if dw
+                depthwise_sep_conv_bn((3, 3), inch, outch, activation; stride = stride)
+            else
+                conv_bn((3, 3), inch, outch, activation; stride = stride)
+            end
+            append!(layers, layer)
+            inch = outch
+        end
+    end
+    push!(layers, MeanPool((7, 7)))
+    convsize = prod(Flux.outputsize(layers, (imsize..., inchannels, 1)))
+
+    return Chain(Chain(layers...),
+                 Chain(MLUtils.flatten,
+                       Dense(convsize, fcsize, activation),
+                       Dense(fcsize, nclasses)))
+end
+
+const mobilenetv1_configs = [
+#   dw,     c, s, r
+    (false,   32, 2, 1),
+    ( true,   64, 1, 1),
+    ( true,  128, 2, 1),
+    ( true,  128, 1, 1),
+    ( true,  256, 2, 1),
+    ( true,  256, 1, 1),
+    ( true,  512, 2, 1),
+    ( true,  512, 1, 5),
+    ( true, 1024, 2, 1),
+    ( true, 1024, 1, 1)
+]
+
+"""
+    MobileNetv1(imsize::NTuple{2, Int} = (224, 224), width_mult = 1;
+                pretrain = false, nclasses = 1000)
+
+Create a MobileNetv1 model with the baseline configuration
+([reference](https://arxiv.org/abs/1704.04861v1)).
+Set `pretrain` to `true` to load the pretrained weights for ImageNet.
+
+# Arguments
+- `imsize`: A 2-tuple indicating the input spatial dimensions
+- `width_mult`: Controls the number of output feature maps in each block
+  (with 1.0 being the default in the paper;
+  this is usually a value between 0.1 and 1.4)
+- `pretrain`: Whether to load the pre-trained weights for ImageNet
+- `nclasses`: The number of output classes
+
+See also [`Metalhead.mobilenetv1`](#).
+"""
+struct MobileNetv1
+    layers
+end
+
+function MobileNetv1(imsize::NTuple{2, Int} = (224, 224), width_mult = 1;
+                     pretrain = false, nclasses = 1000)
+    layers = mobilenetv1(imsize, width_mult, mobilenetv1_configs; nclasses = nclasses)
+    pretrain && loadpretrain!(layers, string("MobileNetv1"))
+
+    return MobileNetv1(layers)
+end
+
+@functor MobileNetv1
+
+(m::MobileNetv1)(x) = m.layers(x)
+
+backbone(m::MobileNetv1) = m.layers[1]
+classifier(m::MobileNetv1) = m.layers[2]
+
 # MobileNetv2
 
 """
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 5a7008eb3..b13d8d424 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -4,7 +4,7 @@
             stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
             initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)
 
-Create a convolution + batch normalization pair with ReLU activation.
+Create a convolution + batch normalization pair with activation.
 
 # Arguments
 - `kernelsize`: size of the convolution kernel (tuple)
@@ -48,6 +48,47 @@ function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
     return rev ? reverse(layers) : layers
 end
 
+"""
+    depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu;
+                          rev = false,
+                          stride = 1, pad = 0, dilation = 1, [bias, weight, init],
+                          initβ = Flux.zeros32, initγ = Flux.ones32,
+                          ϵ = 1f-5, momentum = 1f-1)
+
+Create a depthwise separable convolution chain as used in MobileNet v1.
+This is a sequence of layers:
+- a `kernelsize` depthwise convolution from `inplanes => inplanes`
+- a batch norm layer + `activation`
+- a `(1, 1)` convolution from `inplanes => outplanes`
+- a batch norm layer + `activation`
+
+See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
+
+# Arguments
+- `kernelsize`: size of the convolution kernel (tuple)
+- `inplanes`: number of input feature maps
+- `outplanes`: number of output feature maps
+- `activation`: the activation function applied after each batch norm
+- `rev`: set to `true` to place the batch norm before the convolution
+- `stride`: stride of the first convolution kernel
+- `pad`: padding of the first convolution kernel
+- `dilation`: dilation of the first convolution kernel
+- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
+- `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#))
+- `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
+"""
+depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu;
+                      rev = false,
+                      initβ = Flux.zeros32, initγ = Flux.ones32,
+                      ϵ = 1f-5, momentum = 1f-1,
+                      stride = 1, kwargs...) =
+    vcat(conv_bn(kernelsize, inplanes, inplanes, activation;
+                 rev = rev, initβ = initβ, initγ = initγ,
+                 ϵ = ϵ, momentum = momentum,
+                 stride = stride, groups = inplanes, kwargs...),
+         conv_bn((1, 1), inplanes, outplanes, activation;
+                 rev = rev, initβ = initβ, initγ = initγ,
+                 ϵ = ϵ, momentum = momentum))
 
 """
     skip_projection(inplanes, outplanes, downsample = false)
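A minimal sketch of how the two builders above compose; the `Metalhead.Layers` import path assumes the export added later in this series, and the sizes are illustrative, not part of the patch:

```julia
using Flux
using Metalhead.Layers: conv_bn, depthwise_sep_conv_bn

# Both builders return a *vector* of layers (Conv + BatchNorm), which is why
# the MobileNetv1 code above uses `append!(layers, ...)`; here they are
# splatted straight into a Chain instead.
block = Chain(conv_bn((3, 3), 3, 32, relu; stride = 2, pad = 1)...,
              depthwise_sep_conv_bn((3, 3), 32, 64, relu; pad = 1)...)

x = rand(Float32, 224, 224, 3, 1)
size(block(x))  # (112, 112, 64, 1): the stride-2 stem halves the spatial dims
```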
From 191043df5f4da4f6c8c7e153360d2a88211ca43c Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sat, 2 Apr 2022 12:20:55 -0500
Subject: [PATCH 08/15] Add tests

---
 test/convnets.jl | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/test/convnets.jl b/test/convnets.jl
index 5976ff40f..fde6a4004 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -100,8 +100,19 @@ GC.gc()
 end
 
 @testset "MobileNet" verbose = true begin
-    @testset "MobileNetv2" begin
+    @testset "MobileNetv1" begin
+        m = MobileNetv1()
+
+        @test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
+        if MobileNetv1 in PRETRAINED_MODELS
+            @test (MobileNetv1(pretrain = true); true)
+        else
+            @test_throws ArgumentError MobileNetv1(pretrain = true)
+        end
+        @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
+    end
 
+    @testset "MobileNetv2" begin
         m = MobileNetv2()
 
         @test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)

From 9a5e902e0e7204ac0e01110d62b345557c139790 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sat, 2 Apr 2022 12:32:35 -0500
Subject: [PATCH 09/15] Actually export `MobileNetv1`

---
 src/Metalhead.jl | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/Metalhead.jl b/src/Metalhead.jl
index 64fc85c56..8fbfa9eb1 100644
--- a/src/Metalhead.jl
+++ b/src/Metalhead.jl
@@ -44,14 +44,15 @@ export AlexNet,
        GoogLeNet, Inception3, SqueezeNet,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        ResNeXt,
-       MobileNetv2, MobileNetv3,
+       MobileNetv1, MobileNetv2, MobileNetv3,
        MLPMixer, ResMLP, gMLP,
        ViT,
        ConvNeXt, ConvMixer
 
 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet,
:GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt, - :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvNeXt, :ConvMixer) +for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt, + :MobileNetv1, :MobileNetv2, :MobileNetv3, + :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvNeXt, :ConvMixer) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end From d899b56f2e95469f239eda18996a14c26f8efa5b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Apr 2022 12:52:41 -0500 Subject: [PATCH 10/15] Actually export `Layers.depthwise_sep_conv_bn` --- src/layers/Layers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 9a88caaa3..33658307a 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -22,6 +22,6 @@ export MHAttention, LayerScale, DropPath, ChannelLayerNorm, prenorm, skip_identity, skip_projection, - conv_bn, + conv_bn, depthwise_sep_conv_bn, invertedresidual, squeeze_excite end From ab8fd4984dbb2eae344e85ad8c3ad5fcaa3716c7 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Apr 2022 12:58:02 -0500 Subject: [PATCH 11/15] `NTuple{2, Int}` -> `Dims{2}` --- src/convnets/mobilenet.jl | 4 ++-- src/convnets/vgg.jl | 4 ++-- src/layers/embeddings.jl | 12 ++++++------ src/other/mlpmixer.jl | 30 +++++++++++++++--------------- src/vit-based/vit.jl | 12 ++++++------ 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 37530f660..c88158033 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -64,7 +64,7 @@ const mobilenetv1_configs = [ ] """ - MobileNetv1(imsize::NTuple{2, Int} = (224, 224), width_mult = 1; + MobileNetv1(imsize::Dims{2} = (224, 224), width_mult = 1; pretrain = false, nclasses = 1000) Create a MobileNetv1 model with the baseline configuration @@ -85,7 +85,7 @@ struct MobileNetv1 layers end -function MobileNetv1(imsize::NTuple{2, Int} = (224, 224), width_mult = 1; +function MobileNetv1(imsize::Dims{2} = (224, 224), width_mult = 1; pretrain = false, nclasses = 1000) layers = mobilenetv1(imsize, width_mult, mobilenetv1_configs; nclasses = nclasses) pretrain && loadpretrain!(layers, string("MobileNetv1")) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 2ee7675a4..a0e63d689 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -112,7 +112,7 @@ struct VGG end """ - VGG(imsize::NTuple{2,Int}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`. @@ -125,7 +125,7 @@ Construct a VGG model with the specified input image size. 
Typically, the image (see [`Metalhead.vgg_classifier_layers`](#)) - `dropout`: dropout level between fully connected layers """ -function VGG(imsize::NTuple{2, <:Integer}; +function VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) layers = vgg(imsize; config = config, inchannels = inchannels, diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index e59fdb1c7..afc6d868d 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -1,8 +1,8 @@ _flatten_spatial(x) = permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1, 3)) """ - PatchEmbedding(imsize::NTuple{2, Int} = (224, 224); inchannels = 3, - patch_size::NTuple{2, Int} = (16, 16), embedplanes = 768, + PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels = 3, + patch_size::Dims{2} = (16, 16), embedplanes = 768, norm_layer = planes -> identity, flatten = true) Patch embedding layer used by many vision transformer-like models to split the input image into @@ -17,8 +17,8 @@ patches. single argument constructor for a normalization layer like LayerNorm or BatchNorm - `flatten`: set true to flatten the input spatial dimensions after the embedding """ -function PatchEmbedding(imsize::NTuple{2, Int} = (224, 224); inchannels = 3, - patch_size::NTuple{2, Int} = (16, 16), embedplanes = 768, +function PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels = 3, + patch_size::Dims{2} = (16, 16), embedplanes = 768, norm_layer = planes -> identity, flatten = true) im_height, im_width = imsize @@ -33,7 +33,7 @@ function PatchEmbedding(imsize::NTuple{2, Int} = (224, 224); inchannels = 3, end """ - ViPosEmbedding(embedsize, npatches; init = (dims) -> rand(Float32, dims)) + ViPosEmbedding(embedsize, npatches; init = (dims::Dims{2}) -> rand(Float32, dims)) Positional embedding layer used by many vision transformer-like models. """ @@ -41,7 +41,7 @@ struct ViPosEmbedding{T} vectors::T end -ViPosEmbedding(embedsize, npatches; init = (dims::NTuple{2, Int}) -> rand(Float32, dims)) = +ViPosEmbedding(embedsize, npatches; init = (dims::Dims{2}) -> rand(Float32, dims)) = ViPosEmbedding(init((embedsize, npatches))) (p::ViPosEmbedding)(x) = x .+ p.vectors diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index a36f7466f..1e18d306e 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -29,8 +29,8 @@ function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_bl end """ - mlpmixer(block, imsize::NTuple{2, Int} = (224, 224); inchannels = 3, norm_layer = LayerNorm, - patch_size::NTuple{2, Int} = (16, 16), embedplanes = 512, drop_path_rate = 0., + mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, + patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., depth = 12, nclasses = 1000, kwargs...) Creates a model with the MLPMixer architecture. @@ -50,8 +50,8 @@ Creates a model with the MLPMixer architecture. - `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if not specified. """ -function mlpmixer(block, imsize::NTuple{2, Int} = (224, 224); inchannels = 3, norm_layer = LayerNorm, - patch_size::NTuple{2, Int} = (16, 16), embedplanes = 512, drop_path_rate = 0., +function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, + patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., depth = 12, nclasses = 1000, kwargs...) 
npatches = prod(imsize .÷ patch_size) dp_rates = LinRange{Float32}(0., drop_path_rate, depth) @@ -74,8 +74,8 @@ struct MLPMixer end """ - MLPMixer(size::Symbol = :base; patch_size::NTuple{2, Int} = (16, 16), - imsize::NTuple{2, Int} = (224, 224), drop_path_rate = 0., nclasses = 1000) + MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) Creates a model with the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)). @@ -89,8 +89,8 @@ Creates a model with the MLPMixer architecture. See also [`Metalhead.mlpmixer`](#). """ -function MLPMixer(size::Symbol = :base; patch_size::NTuple{2, Int} = (16, 16), - imsize::NTuple{2, Int} = (224, 224), drop_path_rate = 0., nclasses = 1000) +function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" depth = mixer_configs[size][:depth] embedplanes = mixer_configs[size][:planes] @@ -142,7 +142,7 @@ struct ResMLP end """ - ResMLP(size::Symbol = :base; patch_size::NTuple{2, Int} = (16, 16), imsize::NTuple{2, Int} = (224, 224), + ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) Creates a model with the ResMLP architecture. @@ -157,8 +157,8 @@ Creates a model with the ResMLP architecture. See also [`Metalhead.mlpmixer`](#). """ -function ResMLP(size::Symbol = :base; patch_size::NTuple{2, Int} = (16, 16), - imsize::NTuple{2, Int} = (224, 224), drop_path_rate = 0., nclasses = 1000) +function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" depth = mixer_configs[size][:depth] embedplanes = mixer_configs[size][:planes] @@ -249,8 +249,8 @@ struct gMLP end """ - gMLP(size::Symbol = :base; patch_size::NTuple{2, Int} = (16, 16), - imsize::NTuple{2, Int} = (224, 224), drop_path_rate = 0., nclasses = 1000) + gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) Creates a model with the gMLP architecture. ([reference](https://arxiv.org/abs/2105.08050)). @@ -264,8 +264,8 @@ Creates a model with the gMLP architecture. See also [`Metalhead.mlpmixer`](#). 
""" -function gMLP(size::Symbol = :base; patch_size::NTuple{2, Int} = (16, 16), - imsize::NTuple{2, Int} = (224, 224), drop_path_rate = 0., nclasses = 1000) +function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), + imsize::Dims{2} = (224, 224), drop_path_rate = 0., nclasses = 1000) @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" depth = mixer_configs[size][:depth] embedplanes = mixer_configs[size][:planes] diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 26d2cc332..f49b42be2 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -21,7 +21,7 @@ function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0 end """ - vit(imsize::NTuple{2, Int} = (256, 256); inchannels = 3, patch_size::NTuple{2, Int} = (16, 16), + vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1, emb_dropout = 0.1, pool = :class, nclasses = 1000) @@ -41,7 +41,7 @@ Creates a Vision Transformer (ViT) model. - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output """ -function vit(imsize::NTuple{2, Int} = (256, 256); inchannels = 3, patch_size::NTuple{2, Int} = (16, 16), +function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16), embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1, emb_dropout = 0.1, pool = :class, nclasses = 1000) @@ -66,8 +66,8 @@ vit_configs = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3), :gigantic => (depth = 48, embedplanes = 1664, nheads = 16, mlp_ratio = 64/13)) """ - ViT(mode::Symbol = base; imsize::NTuple{2, Int} = (256, 256), inchannels = 3, - patch_size::NTuple{2, Int} = (16, 16), pool = :class, nclasses = 1000) + ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels = 3, + patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000) Creates a Vision Transformer (ViT) model. ([reference](https://arxiv.org/abs/2010.11929)). @@ -86,8 +86,8 @@ struct ViT layers end -function ViT(mode::Symbol = :base; imsize::NTuple{2, Int} = (256, 256), inchannels = 3, - patch_size::NTuple{2, Int} = (16, 16), pool = :class, nclasses = 1000) +function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3, + patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000) @assert mode in keys(vit_configs) "`mode` must be one of $(keys(vit_configs))" kwargs = vit_configs[mode] layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) From 9392aaab35dc7c5f2cc06a3aa600fe7482a97f39 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Apr 2022 13:02:43 -0500 Subject: [PATCH 12/15] Add `MobileNetv1` to README and clean up README table --- README.md | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 80a86692a..cabaa8ccd 100644 --- a/README.md +++ b/README.md @@ -16,21 +16,22 @@ | Model Name | Function | Pre-trained? 
From 9392aaab35dc7c5f2cc06a3aa600fe7482a97f39 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sat, 2 Apr 2022 13:02:43 -0500
Subject: [PATCH 12/15] Add `MobileNetv1` to README and clean up README table

---
 README.md | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 80a86692a..cabaa8ccd 100644
--- a/README.md
+++ b/README.md
@@ -16,21 +16,22 @@
 
 | Model Name | Function | Pre-trained? |
 |:-------------------------------------------------|:------------------------------------------------------------------------------------------|:------------:|
-| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.VGG.html) | N |
-| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNet.html) | N |
+| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.VGG.html) | N |
+| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNet.html) | N |
 | [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.GoogLeNet.html) | N |
 | [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inception3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inception3.html) | N |
 | [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.SqueezeNet.html) | N |
-| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.DenseNet.html) | N |
-| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNeXt.html) | N |
-| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv2.html) | N |
-| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv3.html) | N |
-| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MLPMixer.html) | N |
-| [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResMLP.html) | N |
-| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.gMLP.html) | N |
-| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ViT.html) | N |
-| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ConvNeXt.html) | N |
-| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ConvMixer.html) | N |
+| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.DenseNet.html) | N |
+| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNeXt.html) | N |
+| [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv1.html) | N |
+| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv2.html) | N |
+| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv3.html) | N |
+| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MLPMixer.html) | N |
+| [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResMLP.html) | N |
+| [gMLP](https://arxiv.org/abs/2105.08050) |
[`gMLP`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.gMLP.html) | N | +| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ViT.html) | N | +| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ConvNeXt.html) | N | +| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ConvMixer.html) | N | ## Getting Started From a8dd9250ce021b531b60c3559937aaae7e9da9e3 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Apr 2022 13:21:12 -0500 Subject: [PATCH 13/15] Add padding = 1 to `MobileNetv1` --- src/convnets/mobilenet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index c88158033..6bc20e8f3 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -32,9 +32,9 @@ function mobilenetv1(imsize, width_mult, config; for _ in 1:repeats outch = outch * width_mult layer = if dw - depthwise_sep_conv_bn((3, 3), inch, outch, activation; stride = stride) + depthwise_sep_conv_bn((3, 3), inch, outch, activation; stride = stride, pad = 1) else - conv_bn((3, 3), inch, outch, activation; stride = stride) + conv_bn((3, 3), inch, outch, activation; stride = stride, pad = 1) end append!(layers, layer) inch = outch From b0044525785c43f280b27c9dd65805c769af771a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Apr 2022 14:06:26 -0500 Subject: [PATCH 14/15] Garbage collector is no longer my friend. Manual `GC.gc()` is my new best friend. --- test/convnets.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/convnets.jl b/test/convnets.jl index fde6a4004..526949c5b 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -149,6 +149,7 @@ GC.gc() @test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2) @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2)) end + GC.gc() end end From b278e290dea7538314ff9aed9697ef3055ac5eb9 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Apr 2022 14:10:38 -0500 Subject: [PATCH 15/15] Use global pooling before FC layer instead --- src/convnets/mobilenet.jl | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 6bc20e8f3..568fc3971 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -10,7 +10,6 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). # Arguments -- `imsize`: A 2-tuple indicating the input spatial dimensions - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper) - `configs`: A "list of tuples" configuration for each layer that details: @@ -21,31 +20,29 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). 
- `fcsize`: The intermediate fully-connected size between the convolution and final layers - `nclasses`: The number of output classes """ -function mobilenetv1(imsize, width_mult, config; +function mobilenetv1(width_mult, config; activation = relu, inchannels = 3, nclasses = 1000, fcsize = 1024) layers = [] - inch = inchannels for (dw, outch, stride, repeats) in config for _ in 1:repeats outch = outch * width_mult layer = if dw - depthwise_sep_conv_bn((3, 3), inch, outch, activation; stride = stride, pad = 1) + depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1) else - conv_bn((3, 3), inch, outch, activation; stride = stride, pad = 1) + conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1) end append!(layers, layer) - inch = outch + inchannels = outch end end - push!(layers, MeanPool((7, 7))) - convsize = prod(Flux.outputsize(layers, (imsize..., inchannels, 1))) return Chain(Chain(layers...), - Chain(MLUtils.flatten, - Dense(convsize, fcsize, activation), + Chain(GlobalMeanPool(), + MLUtils.flatten, + Dense(inchannels, fcsize, activation), Dense(fcsize, nclasses))) end @@ -64,15 +61,13 @@ const mobilenetv1_configs = [ ] """ - MobileNetv1(imsize::Dims{2} = (224, 224), width_mult = 1; - pretrain = false, nclasses = 1000) + MobileNetv1(width_mult = 1; pretrain = false, nclasses = 1000) Create a MobileNetv1 model with the baseline configuration ([reference](https://arxiv.org/abs/1704.04861v1)). Set `pretrain` to `true` to load the pretrained weights for ImageNet. # Arguments -- `imsize`: A 2-tuple indicating the input spatial dimensions - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4) @@ -85,9 +80,8 @@ struct MobileNetv1 layers end -function MobileNetv1(imsize::Dims{2} = (224, 224), width_mult = 1; - pretrain = false, nclasses = 1000) - layers = mobilenetv1(imsize, width_mult, mobilenetv1_configs; nclasses = nclasses) +function MobileNetv1(width_mult::Number = 1; pretrain = false, nclasses = 1000) + layers = mobilenetv1(width_mult, mobilenetv1_configs; nclasses = nclasses) pretrain && loadpretrain!(layers, string("MobileNetv1")) return MobileNetv1(layers)
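With the full series applied, the exported model can be exercised the way test/convnets.jl does. A minimal usage sketch (note that `pretrain = true` throws an `ArgumentError` until pretrained weights are published):

```julia
using Flux, Metalhead

m = MobileNetv1()                    # width_mult = 1, nclasses = 1000
x = rand(Float32, 224, 224, 3, 2)
size(m(x))                           # (1000, 2), as in test/convnets.jl

# Hypothetical variant with a smaller head; an integer width_mult keeps
# the channel counts integral with the code above.
m10 = MobileNetv1(1; nclasses = 10)

# Convolutional trunk via the backbone/classifier split added in this series.
feats = Metalhead.backbone(m)(x)
```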