Merge pull request #232 from FluxML/metalhead-update
Bump compat for Metalhead
ablaom authored Aug 24, 2023
2 parents 19dc08b + fa7133b commit 4bf0a3d
Showing 9 changed files with 70 additions and 50 deletions.
5 changes: 3 additions & 2 deletions Project.toml
@@ -21,12 +21,13 @@ ColorTypes = "0.10.3, 0.11"
 ComputationalResources = "0.3.2"
 Flux = "0.13, 0.14"
 MLJModelInterface = "1.1.1"
-Metalhead = "0.7"
+Metalhead = "0.8"
 ProgressMeter = "1.7.1"
 Tables = "1.0"
 julia = "1.6"
 
 [extras]
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -36,4 +37,4 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["LinearAlgebra", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]
+test = ["cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]
38 changes: 22 additions & 16 deletions src/metalhead.jl
@@ -4,10 +4,10 @@ TODO: After https://github.com/FluxML/Metalhead.jl/issues/176:
 - Export and externally document `image_builder` method
-- Delete definition of `ResNetHack` below
+- Delete definition of `VGGHack` below
 - Change default builder in ImageClassifier (see /src/types.jl) from
-  `image_builder(ResNetHack)` to `image_builder(Metalhead.ResNet)`.
+  `image_builder(VGGHack)` to `image_builder(Metalhead.VGG)`.
 =#

@@ -51,7 +51,7 @@ Base.show(io::IO, w::MetalheadBuilder) =
 Return an MLJFlux builder object based on the Metalhead.jl constructor/type
 `metalhead_constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are
-passed to the `MetalheadType` constructor at "build time", along with
+passed as arguments to `metalhead_constructor` at "build time", along with
 the extra keyword specifiers `imsize=...`, `inchannels=...` and
 `nclasses=...`, with values inferred from the data.
@@ -61,14 +61,14 @@ If in Metalhead.jl you would do
 ```julia
 using Metalhead
-model = ResNet(50, pretrain=true, inchannels=1, nclasses=10)
+model = ResNet(50, pretrain=false, inchannels=1, nclasses=10)
 ```
 then in MLJFlux, it suffices to do
 ```julia
 using MLJFlux, Metalhead
-builder = image_builder(ResNet, 50, pretrain=true)
+builder = image_builder(ResNet, 50, pretrain=false)
 ```
 which can be used in `ImageClassifier` as in
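(The docstring's `ImageClassifier` example itself is collapsed in this view. As a minimal sketch of the intended usage — assuming the standard MLJ machine workflow, with illustrative keyword values and placeholder `images`/`labels` data, not the docstring's own text:)

```julia
using MLJFlux, MLJBase, Metalhead

builder = MLJFlux.image_builder(ResNet, 50, pretrain=false)
clf = ImageClassifier(builder=builder, epochs=10, batch_size=32)

# `images`: a vector of images; `labels`: a categorical vector of targets
mach = MLJBase.machine(clf, images, labels)
MLJBase.fit!(mach)
```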
@@ -122,25 +122,31 @@ function VGGHack(
     pretrain=false,
 )
 
-    # Adapted from 
-    # https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165
+    # Adapted from
+    # https://github.com/FluxML/Metalhead.jl/blob/4e5b8f16964468518eeb6eb8d7e5f85af4ecf959/src/convnets/vgg.jl#L161
+    # But we do not ignore `imsize`.
 
     @assert(
-        depth in keys(Metalhead.vgg_config),
-        "depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))"
+        depth in keys(Metalhead.VGG_CONFIGS),
+        "depth must be from one in $(sort(collect(keys(Metalhead.VGG_CONFIGS))))"
     )
     model = Metalhead.VGG(imsize;
-        config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]],
+        config = Metalhead.VGG_CONV_CONFIGS[Metalhead.VGG_CONFIGS[depth]],
         inchannels,
         batchnorm,
        nclasses,
-        fcsize = 4096,
-        dropout = 0.5)
-    if pretrain && !batchnorm
-        Metalhead.loadpretrain!(model, string("VGG", depth))
-    elseif pretrain
-        Metalhead.loadpretrain!(model, "VGG$(depth)-BN)")
+        dropout_prob = 0.5)
+    if pretrain
+        imsize == (224, 224) || @warn "Using `pretrain=true` may not work unless "*
+            "image size is `(224, 224)`, which it is not. "
+        artifact_name = string("vgg", depth)
+        if batchnorm
+            artifact_name *= "_bn"
+        else
+            artifact_name *= "-IMAGENET1K_V1"
+        end
+        loadpretrain!(model, artifact_name)
     end
 
     return model
 end
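The new `pretrain` branch assembles a Metalhead.jl artifact name from the requested depth and the `batchnorm` flag. A sketch of the naming scheme it encodes (a hypothetical standalone helper, not part of MLJFlux):

```julia
# Hypothetical helper mirroring the branch above:
function vgg_artifact_name(depth::Integer, batchnorm::Bool)
    name = string("vgg", depth)  # e.g. "vgg16"
    return batchnorm ? name * "_bn" : name * "-IMAGENET1K_V1"
end

vgg_artifact_name(16, true)   # "vgg16_bn"
vgg_artifact_name(16, false)  # "vgg16-IMAGENET1K_V1"
```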
9 changes: 5 additions & 4 deletions src/mlj_model_interface.jl
@@ -40,7 +40,7 @@ end
 
 # # FIT AND UPDATE
 
-const ERR_BUILDER = 
+const ERR_BUILDER =
     "Builder does not appear to build an architecture compatible with supplied data. "
 
 true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng
@@ -60,17 +60,18 @@ function MLJModelInterface.fit(model::MLJFluxModel,
     catch ex
         @error ERR_BUILDER
     end
-    
+
     penalty = Penalty(model)
     data = move.(collate(model, X, y))
 
-    x = data |> first |> first
+    x = data[1][1]
 
     try
         chain(x)
     catch ex
         @error ERR_BUILDER
-    end 
+        throw(ex)
+    end
 
     optimiser = deepcopy(model.optimiser)
 
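Here `data[1][1]` (equivalent to the old `data |> first |> first`) extracts the first input batch so the freshly built chain can be smoke-tested before training, and the added `throw(ex)` re-raises the failure instead of continuing silently. A minimal sketch of the pattern, with illustrative shapes and names rather than the MLJFlux internals:

```julia
using Flux

# Stand-in for MLJFlux's collated data: input batches and target batches.
data = ([rand(Float32, 3, 8)], [rand(Float32, 1, 8)])
chain = Flux.Chain(Flux.Dense(3 => 4, Flux.relu), Flux.Dense(4 => 1))

x = data[1][1]   # first input batch
try
    chain(x)     # fails fast if the builder produced an incompatible architecture
catch ex
    @error "Builder does not appear to build a compatible architecture."
    throw(ex)    # surface the original error rather than swallowing it
end
```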
2 changes: 1 addition & 1 deletion test/builders.jl
@@ -47,7 +47,7 @@ end
 
     # reproducibility (without dropout):
     chain2 = MLJFlux.build(builder, StableRNGs.StableRNG(123), 5, 3)
-    x = rand(5)
+    x = rand(Float32, 5)
     @test chain(x) ≈ chain2(x)
 end
 
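The switch to `rand(Float32, 5)` matters because Flux layers carry `Float32` parameters by default: `Float64` input promotes the computation, and recent Flux versions warn about the eltype mismatch, which can also upset exact-reproducibility checks like the `≈` comparison above. A quick illustration, assuming default Flux initialisation:

```julia
using Flux

layer = Flux.Dense(5 => 3)       # parameters are Float32 by default
eltype(layer(rand(Float32, 5)))  # Float32 — the intended fast path
eltype(layer(rand(5)))           # Float64 — promoted; Flux may warn about the mismatch
```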
10 changes: 5 additions & 5 deletions test/core.jl
@@ -7,7 +7,7 @@ rowvec(y::Vector) = reshape(y, 1, length(y))
 @test MLJFlux.MLJModelInterface.istransparent(Flux.Adam(0.1))
 
 @testset "nrows" begin
-    Xmatrix = rand(stable_rng, 10, 3)
+    Xmatrix = rand(stable_rng, Float32, 10, 3)
     X = MLJBase.table(Xmatrix)
     @test MLJFlux.nrows(X) == 10
     @test MLJFlux.nrows(Tables.columntable(X)) == 10
@@ -19,7 +19,7 @@ end
     # convert to a column table:
     X = MLJBase.table(Xmatrix)
 
-    y = rand(stable_rng, 10)
+    y = rand(stable_rng, Float32, 10)
     model = MLJFlux.NeuralNetworkRegressor()
     model.batch_size= 3
     @test MLJFlux.collate(model, X, y) ==
@@ -38,7 +38,7 @@ end
                       reshape([1; 0], (2,1))]))
 
     # MultitargetNeuralNetworkRegressor:
-    ymatrix = rand(stable_rng, 10, 2)
+    ymatrix = rand(stable_rng, Float32, 10, 2)
     y = MLJBase.table(ymatrix) # a rowaccess table
     model = MLJFlux.NeuralNetworkRegressor()
     model.batch_size= 3
@@ -54,7 +54,7 @@ end
                      ymatrix'[:,7:9], ymatrix'[:,10:10]]))
 
     # ImageClassifier
-    Xmatrix = coerce(rand(stable_rng, 6, 6, 1, 10), GrayImage)
+    Xmatrix = coerce(rand(stable_rng, Float32, 6, 6, 1, 10), GrayImage)
     y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
     model = MLJFlux.ImageClassifier(batch_size=2)
 
@@ -69,7 +69,7 @@ end
 
 end
 
-Xmatrix = rand(stable_rng, 100, 5)
+Xmatrix = rand(stable_rng, Float32, 100, 5)
 X = MLJBase.table(Xmatrix)
 y = Xmatrix[:, 1] + Xmatrix[:, 2] + Xmatrix[:, 3] +
     Xmatrix[:, 4] + Xmatrix[:, 5]
35 changes: 23 additions & 12 deletions test/image.jl
@@ -8,17 +8,29 @@ mutable struct MyNeuralNetwork <: MLJFlux.Builder
     kernel2
 end
 
-function MLJFlux.build(model::MyNeuralNetwork, rng, ip, op, n_channels)
+# to get a matrix whose last dimension matches that of the array input (the batch size):
+function make2d(x)
+    l = length(x)
+    b = size(x)[end]
+    reshape(x, div(l, b), b)
+end
+
+function MLJFlux.build(builder::MyNeuralNetwork, rng, ip, op, n_channels)
     init = Flux.glorot_uniform(rng)
-    Flux.Chain(
-        Flux.Conv(model.kernel1, n_channels=>2, init=init),
-        Flux.Conv(model.kernel2, 2=>1, init=init),
-        x->reshape(x, :, size(x)[end]),
-        Flux.Dense(16, op, init=init))
+    front = Flux.Chain(
+        Flux.Conv(builder.kernel1, n_channels=>2, init=init),
+        Flux.Conv(builder.kernel2, 2=>1, init=init),
+        make2d,
+    )
+    d = Flux.outputsize(front, (ip..., n_channels, 1))[1]
+    return Flux.Chain(
+        front,
+        Flux.Dense(d, op, init=init)
+    )
 end
 
 builder = MyNeuralNetwork((2,2), (2,2))
-images, labels = MLJFlux.make_images(stable_rng)
+images, labels = MLJFlux.make_images(stable_rng);
 losses = []
 
 @testset_accelerated "ImageClassifier basic tests" accel begin
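Note that the rewritten builder no longer hard-codes the `Dense` input width (formerly `16`): it derives it with `Flux.outputsize`, which propagates shapes through the chain without running real data. A sketch of that computation for a hypothetical 6×6 single-channel input:

```julia
using Flux

# Flatten all but the batch dimension (same role as `make2d` above):
flatten2d(x) = reshape(x, :, size(x)[end])

front = Flux.Chain(Flux.Conv((2, 2), 1 => 2), Flux.Conv((2, 2), 2 => 1), flatten2d)
Flux.outputsize(front, (6, 6, 1, 1))  # (16, 1): two 2×2 convs shrink 6×6 to 4×4 → 16 features
```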
@@ -69,8 +81,6 @@ reference = losses[1]
 
 # # BASIC IMAGE TESTS COLOR
 
-# In this case we use the default ResNet builder
-
 builder = MyNeuralNetwork((2,2), (2,2))
 images, labels = MLJFlux.make_images(stable_rng, color=true)
 losses = []
@@ -112,12 +122,13 @@ reference = losses[1]
 @test all(x->abs(x - reference)/reference < 1e-5, losses[2:end])
 
 
-# # SMOKE TEST FOR DEFAULT BUILDER 
+# # SMOKE TEST FOR DEFAULT BUILDER
 
-images, labels = MLJFlux.make_images(stable_rng, image_size=(32, 32), n_images=12, noise=0.2, color=true);
+images, labels = MLJFlux.make_images(stable_rng, image_size=(32, 32), n_images=12,
+                                     noise=0.2, color=true);
 
 @testset_accelerated "ImageClassifier basic tests" accel begin
-    model = MLJFlux.ImageClassifier(epochs=10,
+    model = MLJFlux.ImageClassifier(epochs=5,
                                     batch_size=4,
                                     acceleration=accel,
                                     rng=stable_rng)
19 changes: 9 additions & 10 deletions test/metalhead.jl
@@ -42,16 +42,15 @@ end
     @test builder.metalhead_constructor == Metalhead.VGG
     @test builder.args == (depth, )
     @test (; builder.kwargs...) == (; batchnorm=true)
-    ref_chain = Metalhead.VGG(
-        imsize;
-        config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]],
-        inchannels,
-        batchnorm=true,
-        nclasses,
-        fcsize = 4096,
-        dropout = 0.5
-    )
-    # needs https://github.com/FluxML/Metalhead.jl/issues/176
+
+    ## needs https://github.com/FluxML/Metalhead.jl/issues/176:
+    # ref_chain = Metalhead.VGG(
+    #     imsize;
+    #     config = Metalhead.VGG_CONV_CONFIGS[Metalhead.VGG_CONFIGS[depth]],
+    #     inchannels,
+    #     batchnorm=true,
+    #     nclasses,
+    # )
     # chain =
     #     MLJFlux.build(builder, StableRNGs.StableRNG(123), imsize, nclasses, inchannels)
     # @test length.(MLJFlux.Flux.params(ref_chain)) ==
1 change: 1 addition & 0 deletions test/mlj_model_interface.jl
@@ -44,6 +44,7 @@ end
 
     # integration test:
     X, y = MLJBase.make_regression(10)
+    X = Float32.(MLJBase.Tables.matrix(X)) |> MLJBase.Tables.table
     mach = MLJBase.machine(model, X, y)
     MLJBase.fit!(mach, verbosity=0)
     losses = MLJBase.training_losses(mach)
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -10,6 +10,7 @@ import Random.seed!
 using Statistics
 import StatsBase
 using StableRNGs
+using cuDNN
 
 using ComputationalResources
 using ComputationalResources: CPU1, CUDALibs
