
Release 0.7 (round 2) #142

Merged
merged 13 commits on Apr 8, 2022
2 changes: 1 addition & 1 deletion Project.toml
@@ -15,7 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
BSON = "0.3.2"
Flux = "0.12"
Flux = "0.13"
Functors = "0.2"
MLUtils = "0.1.2, 0.2"
NNlib = "0.7.34, 0.8"
2 changes: 1 addition & 1 deletion docs/tutorials/quickstart.md
@@ -18,7 +18,7 @@ Now, we can use this model with Flux like any other model. Below, we train it on
```julia
using Flux: onehotbatch

batchsize = 4
batchsize = 1
data = [(rand(Float32, 224, 224, 3, batchsize), onehotbatch(rand(1:1000, batchsize), 1:1000))
for _ in 1:3]
opt = ADAM()
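# Not part of the diff — a hedged sketch of how this training snippet might continue.
# The tutorial builds its own `model` earlier; `ResNet(18)` and the loss are illustrative stand-ins.
using Metalhead
using Flux: logitcrossentropy, params

model = ResNet(18)
loss(x, y) = logitcrossentropy(model(x), y)
Flux.train!(loss, params(model), data, opt)   # one pass over the three dummy batches above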
6 changes: 4 additions & 2 deletions src/convnets/mobilenet.jl
@@ -1,7 +1,7 @@
# MobileNetv1

"""
mobilenetv1(imsize, width_mult, config;
mobilenetv1(width_mult, config;
activation = relu,
inchannels = 3,
nclasses = 1000,
@@ -17,6 +17,8 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
- `o`: The number of output feature maps
- `s`: The stride of the convolutional kernel
- `r`: The number of times this configuration block is repeated
- `activation`: The activation function to use throughout the network
- `inchannels`: The number of input feature maps
- `fcsize`: The intermediate fully-connected size between the convolution and final layers
- `nclasses`: The number of output classes
"""
@@ -27,8 +29,8 @@ function mobilenetv1(width_mult, config;
fcsize = 1024)
layers = []
for (dw, outch, stride, repeats) in config
outch = Int(outch * width_mult)
for _ in 1:repeats
outch = outch * width_mult
layer = if dw
depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
else
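For orientation, a minimal sketch of how the updated builder might be called. The config tuples below are illustrative, not the package defaults, and the lowercase builder is assumed to be reachable as `Metalhead.mobilenetv1`.

```julia
using Flux, Metalhead

# Each entry is (dw, outch, stride, repeats): depthwise-separable?, output channels,
# stride, and repeat count — the tuples the builder loop destructures.
config = [(false, 32, 2, 1), (true, 64, 1, 1), (true, 128, 2, 2)]

# width_mult rescales every block's output channels; Int(outch * width_mult) requires
# the product to be whole, so pick multipliers that divide the channel counts evenly.
model = Metalhead.mobilenetv1(0.5, config; activation = relu, inchannels = 3, nclasses = 1000)
```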
2 changes: 1 addition & 1 deletion src/layers/conv.jl
@@ -85,7 +85,7 @@ depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu;
vcat(conv_bn(kernelsize, inplanes, inplanes, activation;
rev = rev, initβ = initβ, initγ = initγ,
ϵ = ϵ, momentum = momentum,
stride = stride, groups = inplanes, kwargs...),
stride = stride, groups = Int(inplanes), kwargs...),
conv_bn((1, 1), inplanes, outplanes, activation;
rev = rev, initβ = initβ, initγ = initγ,
ϵ = ϵ, momentum = momentum))
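A minimal sketch (not from the PR) of the failure mode the `Int(inplanes)` conversion guards against: after width scaling, channel counts can arrive as floats, and `Flux.Conv` expects an integer `groups`.

```julia
using Flux

inplanes = 64 * 0.25                     # width-scaled channel count arrives as 16.0 (Float64)

# Passing a float `groups` to Conv typically fails when the weight array is sized,
# so the builder converts explicitly:
Conv((3, 3), 16 => 16; groups = Int(inplanes), pad = 1)
```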
9 changes: 4 additions & 5 deletions src/layers/normalise.jl
@@ -12,17 +12,16 @@ Note that this is specifically for inputs with 4 dimensions in the format
(H, W, C, N) where H, W are the height and width of the input, C is the number
of channels, and N is the batch size.
"""
struct ChannelLayerNorm{F,D,T}
λ::F
struct ChannelLayerNorm{D,T}
diag::D
ϵ::T
end

@functor ChannelLayerNorm

(m::ChannelLayerNorm)(x) = m.λ.(m.diag(MLUtils.normalise(x, dims = ndims(x) - 1, ϵ = m.ϵ)))
(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x, dims = ndims(x) - 1, ϵ = m.ϵ))

function ChannelLayerNorm(sz::Int, λ = identity; ϵ = 1f-5)
diag = Flux.Diagonal(1, 1, sz)
return ChannelLayerNorm(λ, diag, ϵ)
diag = Flux.Scale(1, 1, sz, λ)
return ChannelLayerNorm(diag, ϵ)
end
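A quick sketch of the reworked layer's behaviour after this change — the activation now lives inside the `Flux.Scale` wrapper rather than in a separate `λ` field. The sizes are illustrative, and the layer is assumed to be reachable as `Metalhead.ChannelLayerNorm`.

```julia
using Flux, Metalhead

cln = Metalhead.ChannelLayerNorm(8, relu)   # 8 channels, relu folded into the Scale layer
x = rand(Float32, 4, 4, 8, 2)               # (H, W, C, N)
size(cln(x)) == (4, 4, 8, 2)                # normalises over the channel dimension, shape preserved
```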
23 changes: 4 additions & 19 deletions src/layers/others.jl
@@ -1,30 +1,15 @@
"""
LayerScale(scale)

Implements LayerScale.
([reference](https://arxiv.org/abs/2103.17239))

# Arguments
- `scale`: Scaling factor, a learnable diagonal matrix which is multiplied to the input.
"""
struct LayerScale{T<:AbstractVector{<:Real}}
scale::T
end

"""
LayerScale(λ, planes::Int)

Implements LayerScale.
([reference](https://arxiv.org/abs/2103.17239))
Creates a `Flux.Scale` layer that performs "`LayerScale`"
([reference](https://arxiv.org/abs/2103.17239)).

# Arguments
- `planes`: Size of channel dimension in the input.
- `λ`: initialisation value for the learnable diagonal matrix.
"""
LayerScale(planes::Int, λ) = λ > 0 ? LayerScale(fill(Float32(λ), planes)) : identity

@functor LayerScale
(m::LayerScale)(x::AbstractArray) = m.scale .* x
LayerScale(planes::Int, λ) =
λ > 0 ? Flux.Scale(fill(Float32(λ), planes), false) : identity

"""
DropPath(p)
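An illustrative check of the new behaviour: `LayerScale` now returns a plain `Flux.Scale` (or `identity` when λ is not positive) instead of a custom struct. The sizes below are arbitrary, and the helper is assumed to be reachable as `Metalhead.LayerScale`.

```julia
using Flux, Metalhead

ls = Metalhead.LayerScale(16, 1e-4)       # Flux.Scale holding 16 learnable values, all 1f-4, no bias
x = rand(Float32, 16, 5)
size(ls(x)) == (16, 5)                    # per-feature scaling along the first dimension

Metalhead.LayerScale(16, 0) == identity   # a non-positive λ disables the layer entirely
```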
4 changes: 2 additions & 2 deletions src/other/mlpmixer.jl
@@ -125,13 +125,13 @@ Creates a block for the ResMixer architecture.
"""
function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block,
dropout = 0., drop_path_rate = 0., activation = gelu, λ = 1e-4)
return Chain(SkipConnection(Chain(Flux.Diagonal(planes),
return Chain(SkipConnection(Chain(Flux.Scale(planes),
swapdims((2, 1, 3)),
Dense(npatches, npatches),
swapdims((2, 1, 3)),
LayerScale(planes, λ),
DropPath(drop_path_rate)), +),
SkipConnection(Chain(Flux.Diagonal(planes),
SkipConnection(Chain(Flux.Scale(planes),
mlp_layer(planes, Int(mlp_ratio * planes); dropout, activation),
LayerScale(planes, λ),
DropPath(drop_path_rate)), +))
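For context, `Flux.Diagonal` was renamed to `Flux.Scale` in Flux 0.13, which is what this hunk (and the ones above) track; a tiny sketch of the replacement layer:

```julia
using Flux

s = Flux.Scale(4)            # learnable elementwise scale and bias over the first dimension
x = rand(Float32, 4, 3)
size(s(x)) == (4, 3)         # shape is preserved; only an affine per-feature transform is applied
```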
9 changes: 7 additions & 2 deletions src/pretrain.jl
@@ -6,7 +6,12 @@ Load the pre-trained weights for `model` using the stored artifacts.
function weights(model)
try
path = joinpath(@artifact_str(model), "$model.bson")
return BSON.load(path, @__MODULE__)[:weights]
artifact = BSON.load(path, @__MODULE__)
if haskey(artifact, :model)
return artifact[:model]
else
throw(ArgumentError("No pre-trained weights available for $model."))
end
catch e
throw(ArgumentError("No pre-trained weights available for $model."))
end
@@ -17,4 +22,4 @@ end

Load the pre-trained weight artifacts matching `<name>.bson` into `model`.
"""
loadpretrain!(model, name) = Flux.loadparams!(model, weights(name))
loadpretrain!(model, name) = Flux.loadmodel!(model, weights(name))
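A sketch of the intended call pattern after this change; the artifact name is illustrative, and `Flux.loadmodel!` (new in Flux 0.13) copies both trainable parameters and non-trainable state (e.g. batchnorm statistics) into the model, whereas `loadparams!` only handled the former.

```julia
using Flux, Metalhead

m = ResNet(50)
Metalhead.loadpretrain!(m, "resnet50")   # illustrative artifact name; assumes a matching BSON artifact exists
```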
87 changes: 46 additions & 41 deletions test/convnets.jl
@@ -6,35 +6,34 @@ PRETRAINED_MODELS = []

@testset "AlexNet" begin
model = AlexNet()
@test size(model(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
@test size(model(rand(Float32, 256, 256, 3, 1))) == (1000, 1)
@test_throws ArgumentError AlexNet(pretrain = true)
@test_skip gradtest(model, rand(Float32, 256, 256, 3, 2))
@test_skip gradtest(model, rand(Float32, 256, 256, 3, 1))
end

@testset "VGG" begin
@testset "$model(BN=$bn)" for model in [VGG11, VGG13, VGG16, VGG19], bn in [true, false]
imsize = (224, 224)
m = model(batchnorm = bn)
@testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false]
m = VGG(sz, batchnorm = bn)

@test size(m(rand(Float32, imsize..., 3, 2))) == (1000, 2)
if (model, bn) in PRETRAINED_MODELS
@test (model(batchnorm = bn, pretrain = true); true)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
if (VGG, sz, bn) in PRETRAINED_MODELS
@test (VGG(sz, batchnorm = bn, pretrain = true); true)
else
@test_throws ArgumentError model(batchnorm = bn, pretrain = true)
@test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true)
end
@test_skip gradtest(m, rand(Float32, imsize..., 3, 2))
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end
end

@testset "ResNet" begin
@testset for model in [ResNet18, ResNet34, ResNet50, ResNet101, ResNet152]
m = model()
@testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
m = ResNet(sz)

@test size(m(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
if model in PRETRAINED_MODELS
@test (model(pretrain = true); true)
@test size(m(rand(Float32, 256, 256, 3, 1))) == (1000, 1)
if (ResNet, sz) in PRETRAINED_MODELS
@test (ResNet(sz, pretrain = true); true)
else
@test_throws ArgumentError model(pretrain = true)
@test_throws ArgumentError ResNet(sz, pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 256, 256, 3, 2))
end
Expand All @@ -44,15 +43,15 @@ end
channel_config = [1, 1],
block_config = [2, 2, 2, 2])

@test size(m(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 256, 256, 3, 1))) == (1000, 1)
end
end

@testset "ResNeXt" begin
@testset for depth in [50, 101, 152]
m = ResNeXt(depth)

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
if ResNeXt in PRETRAINED_MODELS
@test (ResNeXt(depth, pretrain = true); true)
else
@@ -64,77 +63,83 @@ end

@testset "GoogLeNet" begin
m = GoogLeNet()
@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
@test_throws ArgumentError (GoogLeNet(pretrain = true); true)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end

@testset "Inception3" begin
m = Inception3()
@test size(m(rand(Float32, 299, 299, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
@test_throws ArgumentError Inception3(pretrain = true)
@test_skip gradtest(m, rand(Float32, 299, 299, 3, 2))
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
end

@testset "SqueezeNet" begin
m = SqueezeNet()
@test size(m(rand(Float32, 227, 227, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
@test_throws ArgumentError (SqueezeNet(pretrain = true); true)
@test_skip gradtest(m, rand(Float32, 227, 227, 3, 2))
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end

GC.gc()

@testset "DenseNet" begin
@testset for model in [DenseNet121, DenseNet161, DenseNet169, DenseNet201]
m = model()
@testset for sz in [121, 161, 169, 201]
m = DenseNet(sz)

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
if model in PRETRAINED_MODELS
@test (model(pretrain = true); true)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
if (DenseNet, sz) in PRETRAINED_MODELS
@test (DenseNet(sz, pretrain = true); true)
else
@test_throws ArgumentError model(pretrain = true)
@test_throws ArgumentError DenseNet(sz, pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end
end

GC.gc()

@testset "MobileNet" verbose = true begin
@testset "MobileNetv1" begin
m = MobileNetv1()

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
if MobileNetv1 in PRETRAINED_MODELS
@test (MobileNetv1(pretrain = true); true)
else
@test_throws ArgumentError MobileNetv1(pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end

GC.gc()

@testset "MobileNetv2" begin
m = MobileNetv2()

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
if MobileNetv2 in PRETRAINED_MODELS
@test (MobileNetv2(pretrain = true); true)
else
@test_throws ArgumentError MobileNetv2(pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end

GC.gc()

@testset "MobileNetv3" verbose = true begin
@testset for mode in [:small, :large]
m = MobileNetv3(mode)

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
if MobileNetv3 in PRETRAINED_MODELS
@test (MobileNetv3(mode; pretrain = true); true)
else
@test_throws ArgumentError MobileNetv3(mode; pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end
end
end
@@ -146,8 +151,8 @@ GC.gc()
@testset for drop_path_rate in [0.0, 0.5, 0.99]
m = ConvNeXt(mode; drop_path_rate)

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end
GC.gc()
end
@@ -159,7 +164,7 @@ GC.gc()
@testset for mode in [:base, :large, :small]
m = ConvMixer(mode)

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end
end
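The test rewrites above track a constructor change elsewhere in the release: depth is now a positional argument (`VGG(16)`, `ResNet(34)`, …) rather than being encoded in the type name (`VGG16`, `ResNet34`, …). A compact sketch of the new pattern, with arbitrary depths:

```julia
using Metalhead

vgg   = VGG(16; batchnorm = true)
res   = ResNet(34)
dense = DenseNet(121)

size(res(rand(Float32, 256, 256, 3, 1))) == (1000, 1)   # 1000-way ImageNet head, batch of one
```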
16 changes: 8 additions & 8 deletions test/other.jl
@@ -5,31 +5,31 @@ using Flux
@testset for mode in [:small, :base, :large, :huge]
@testset for drop_path_rate in [0.0, 0.5, 0.99]
m = MLPMixer(mode; drop_path_rate)
@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
GC.gc()
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end
GC.gc()
end
end

@testset "ResMLP" begin
@testset for mode in [:small, :base, :large, :huge]
@testset for drop_path_rate in [0.0, 0.5, 0.99]
m = ResMLP(mode; drop_path_rate)
@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
GC.gc()
end
GC.gc()
end
end

@testset "gMLP" begin
@testset for mode in [:small, :base, :large, :huge]
@testset for drop_path_rate in [0.0, 0.5, 0.99]
m = gMLP(mode; drop_path_rate)
@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
GC.gc()
@test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 1))
end
GC.gc()
end
end
6 changes: 3 additions & 3 deletions test/vit-based.jl
@@ -4,8 +4,8 @@ using Flux
@testset "ViT" begin
for mode in [:tiny, :small, :base, :large] #,:huge, :giant, :gigantic]
m = ViT(mode)
@test size(m(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
@test_skip gradtest(m, rand(Float32, 256, 256, 3, 2))
GC.gc()
@test size(m(rand(Float32, 256, 256, 3, 1))) == (1000, 1)
@test_skip gradtest(m, rand(Float32, 256, 256, 3, 1))
end
GC.gc()
end