
Bump compat for Metalhead #232

Merged (5 commits into dev from metalhead-update, Aug 24, 2023)
Conversation

@ablaom (Collaborator) commented Aug 23, 2023

This PR bumps the [compat] for Metalhead to "0.8" and addresses resulting breakages.

Replaces #226

@ablaom (Collaborator, Author) commented Aug 24, 2023

Failing on GPU only. The complaint is about scalar indexing.

I've spent some time on this today, but it's hard for me to debug because I don't currently have GPU access. I conjecture that the following code fails on a GPU but not on a CPU, and that it contains the issue. It would be good if someone could confirm that it indeed fails and, if so, say where the scalar indexing is.

import Flux
import MLJFlux
import StableRNGs.StableRNG

rng = StableRNG(123)
X, y = MLJFlux.make_images(rng);

typeof(X)
# Vector{Matrix{Gray{Float64}}}

data = MLJFlux.collate(MLJFlux.ImageClassifier(), X, y);

data = Flux.gpu(data)  # a no-op on my CPU-only machine
typeof(data)
# Tuple{Vector{Array{Float32, 4}}, Vector{OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}}

n_channels = 1
n_classes = 3
init = Flux.glorot_uniform(rng)

chain = Flux.Chain(
    Flux.Conv((2, 2), n_channels=>2, init=init),
    Flux.Conv((2, 2), 2=>1, init=init),
    x->reshape(x, :, size(x)[end]),
    Flux.Dense(16, n_classes, init=init))

x = data[1][1]
typeof(x)
# Array{Float32, 4}

size(x)
# (6, 6, 1, 1)

chain(x)
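
For anyone with GPU access, here is a minimal sketch of the check I have in mind (assuming CUDA.jl is installed; disallowing scalar indexing turns any fallback into a hard error rather than a warning):

import CUDA
CUDA.allowscalar(false)       # make scalar-indexing fallbacks error out loudly

chain_gpu = Flux.gpu(chain)   # move the model parameters to the GPU
x_gpu = Flux.gpu(x)           # move the input batch to the GPU
chain_gpu(x_gpu)              # conjectured to fail here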

@ablaom mentioned this pull request on Aug 24, 2023

@ablaom (Collaborator, Author) commented Aug 24, 2023

Okay, I guess the reshape is the issue: JuliaGPU/CUDA.jl#228.

But these docs say that reshape is special-cased by CuArrays. Mmm...

@ablaom (Collaborator, Author) commented Aug 24, 2023

Okay, looks like the colon is not supported in the CuArray-specialized reshape.

@mohamed82008 (Contributor):

@ablaom would you also be open to making Metalhead an optional dependency?

function make2d(x)
    l = length(x)
    b = size(x)[end]
    reshape(x, div(l, b), b)
end

@ablaom (Collaborator, Author) commented on this diff:

So I thought this use of reshape would avoid the "scalar indexing error", but tests are still not passing.
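
For context, a sketch of how make2d would slot into the chain, replacing the colon-based anonymous reshape (same names as in the earlier snippet; illustrative only):

chain = Flux.Chain(
    Flux.Conv((2, 2), n_channels=>2, init=init),
    Flux.Conv((2, 2), 2=>1, init=init),
    make2d,                   # explicit-dims reshape instead of x -> reshape(x, :, size(x)[end])
    Flux.Dense(16, n_classes, init=init))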

@ablaom (Collaborator, Author) commented Aug 24, 2023

> @ablaom would you also be open to making Metalhead an optional dependency?

That's tricky, because the default builder for ImageClassifier is a VGG architecture from Metalhead.jl. We could throw an error if Metalhead.jl is not loaded, but that would violate a general principle that currently holds across 200+ MLJ models (and is assumed in applications of MLJTestIntegration.jl): the empty-argument constructor always works, unless the model is a wrapper like TunedModel.

An alternative, which already looks too complicated to me, is to make the default (see the sketch below):

  • something simple (and essentially useless) when Metalhead.jl is not loaded, while throwing an informative warning suggesting it be loaded;
  • the VGG model otherwise.
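
To make that concrete, here is a purely illustrative sketch of that alternative; the returned values are hypothetical stand-ins, not actual MLJFlux builders:

function default_image_builder()
    if isdefined(Main, :Metalhead)
        return :vgg_builder              # stand-in for the current Metalhead.jl VGG default
    else
        @warn "Metalhead.jl is not loaded; falling back to a minimal (and essentially useless) builder. Load Metalhead.jl to get the VGG default."
        return :simple_fallback_builder  # stand-in for the simple fallback
    end
end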

Related discussion: #162

Maybe you have a better idea?

@ablaom added the "help wanted" (Extra attention is needed) label on Aug 24, 2023
@mohamed82008 (Contributor):
Not a better idea, but moving ImageClassifier itself out to another package is another option. Not a great option, though, if you want all Flux-related wrappers to live in this repo.

@ToucheSir (Member):
> Okay, looks like the colon is not supported in the CuArray-specialized reshape.

That seems strange to me, because it's basically what MLUtils.flatten does and I don't recall having any issues with it. Indeed, making the same reshape calls manually results in a CuArray.

I had a look through the failing CI runs, and the problem is instead this warning: https://buildkite.com/julialang/mljflux-dot-jl/builds/339#018a25d7-e2c7-4c2b-a7af-c1a9d97436c4/425-777. Because we switched to package extensions in Flux 0.14, cuDNN needs to be added separately to an environment to enable the CUDA conv routines in NNlib. I'm guessing MLJFlux doesn't want to take it on as a dependency, so adding it to your test env/extras should be enough.
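
Concretely, a minimal sketch of one way to do that, assuming the test dependencies live in a test/Project.toml (otherwise cuDNN would go into the [extras]/[targets] sections of the main Project.toml):

import Pkg
Pkg.activate("test")   # activate the test environment
Pkg.add("cuDNN")       # loading cuDNN enables the CUDA conv routines in NNlib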

@codecov-commenter:

Codecov Report

Patch coverage: 25.00% and project coverage change: -1.21% ⚠️

Comparison is base (19dc08b) 93.26% compared to head (fa7133b) 92.06%.


Additional details and impacted files
@@            Coverage Diff             @@
##              dev     #232      +/-   ##
==========================================
- Coverage   93.26%   92.06%   -1.21%     
==========================================
  Files          12       12              
  Lines         312      315       +3     
==========================================
- Hits          291      290       -1     
- Misses         21       25       +4     
Files Changed Coverage Δ
src/metalhead.jl 80.64% <14.28%> (-12.22%) ⬇️
src/mlj_model_interface.jl 94.20% <100.00%> (ø)


@ablaom (Collaborator, Author) commented Aug 24, 2023

Thanks indeed for the help, @ToucheSir. I guess that scalar indexing error was a red herring. I've added cuDNN to the test deps and there's no sign of it now.

@ablaom merged commit 4bf0a3d into dev on Aug 24, 2023 (5 checks passed)
@ablaom deleted the metalhead-update branch on August 24, 2023 at 22:04