
automatically convert input matrix to Float32 #272

Merged
4 commits, merged Sep 29, 2024

Conversation

tiemvanderdeure
Contributor

This follows up on JuliaAI/MLJModels.jl#565 and just adds a tiny step that automatically converts input to Float32 before passing it to Flux.

I can't see why anyone would ever not want this: the eltype of any neural net generated through MLJFlux is always Float32, so input of any other type gets converted to Float32 anyway, but at a much higher computational cost. So I didn't build in an option to disable this behaviour.

The reason to have this in MLJFlux in particular is that other machines such as MLJModels.OneHotEncoder output Float64 types.

I didn't look into #267 in detail and it might make this redundant in some cases, but maybe not in all?
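The idea of the conversion step can be sketched as below. This is a minimal illustration, not the actual MLJFlux code: `_to_float32` is a hypothetical helper name, and the real implementation sits inside MLJFlux's data front end.

```julia
# Sketch: coerce the input matrix to Float32 once, before it reaches Flux,
# so Flux doesn't silently promote/demote element-by-element on every pass.
# `_to_float32` is a hypothetical name for illustration only.
_to_float32(x::AbstractArray{Float32}) = x           # already Float32: pass through
_to_float32(x::AbstractArray{<:Real}) = Float32.(x)  # e.g. Float64 from OneHotEncoder

Xmatrix = rand(Float64, 10, 4)       # Float64, as produced by many MLJ transformers
X32 = _to_float32(Xmatrix)
@assert eltype(X32) == Float32
```

A one-time `Float32.(x)` copy is cheap compared with the repeated promotion that happens when Float64 data meets Float32 weights inside Flux layers.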

@ablaom
Collaborator

ablaom commented Sep 8, 2024

Thanks indeed for this PR. Can you please look at the CI failures? For example, we have:

collate: Error During Test at /home/runner/work/MLJFlux.jl/MLJFlux.jl/test/core.jl:25
  Test threw exception
  Expression: MLJFlux.collate(model, X, y) == ([(Xmatrix')[:, 1:3], (Xmatrix')[:, 4:6], (Xmatrix')[:, 7:9], (Xmatrix')[:, 10:10]], rowvec.([y[1:3], y[4:6], y[7:9], y[10:10]]))
  MethodError: no method matching collate(::MLJFlux.NeuralNetworkRegressor{MLJFlux.Linear, Optimisers.Adam, typeof(Flux.Losses.mse)}, ::Tables.MatrixTable{Matrix{Float64}}, ::Vector{Float32})

@ablaom
Collaborator

ablaom commented Sep 9, 2024

@tiemvanderdeure If you can fix the failures here, I can sort out the conflict.

@tiemvanderdeure
Contributor Author

Works now!

@ablaom
Collaborator

ablaom commented Sep 9, 2024

Okay @tiemvanderdeure, thanks.

Turns out conflict resolution is a bit more complicated than anticipated. So I'm waiting on resolution of the following before trying again:

@ablaom ablaom merged commit 1cfcd1e into FluxML:dev Sep 29, 2024
3 checks passed
@ablaom ablaom mentioned this pull request Sep 29, 2024
@ablaom
Collaborator

ablaom commented Sep 29, 2024

@tiemvanderdeure Your PR has now been merged and is part of 0.6, just released.

Thanks for your patience as we sorted out the conflicts.

@tiemvanderdeure
Contributor Author

Awesome, thanks for seeing this one through!
