
permit NNlibCUDA to use Float16 #363

Open
bjarthur wants to merge 1 commit into master
Conversation

@bjarthur (Contributor)

In conjunction with FluxML/NNlibCUDA.jl#32, this adds support for half-precision gemm, for which Nvidia provides a special kernel. See JuliaGPU/CUDA.jl#1080.
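For context, a rough sketch of what the two PRs together should enable (assuming both are merged; the array sizes are arbitrary):

```julia
using NNlib, NNlibCUDA, CUDA

A = CuArray(rand(Float16, 8, 10, 32))   # batch of 32 matrices, each 8×10
B = CuArray(rand(Float16, 10, 6, 32))   # batch of 32 matrices, each 10×6

C = batched_mul(A, B)   # 8×6×32 CuArray{Float16}, via Nvidia's half-precision batched gemm
```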

```diff
@@ -220,7 +220,7 @@ _batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C,
 _batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} =
     _batched_try_gemm!(DT, C, A, B, α, β)

 function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat}
```
@mcabbott (Member) · Nov 19, 2021

My concern with this change (removing the {T<:BlasFloat} restriction, which the diff doesn't highlight well) is that it may send weird number types (like Dual, or BigFloat) down the path towards batched_gemm!, which won't accept them.

Perhaps, to widen safely here, the _batched_gemm!(::Type{<:Array}, ...) method below needs to be restricted to Array{<:BlasFloat}, with a new method offering another path to batched_mul_generic! at that stage?

The dispatch in this file is pretty convoluted! Maybe there's another tidier solution.

Float16 would be good to have, though. Thanks for digging.
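A minimal toy sketch of that guard, just to show the dispatch pattern being suggested (the argument list of NNlib's real _batched_gemm! differs; the returned symbols are placeholders):

```julia
using LinearAlgebra: BlasFloat

# BLAS-supported element types keep going to the batched_gemm! path...
_batched_gemm!(::Type{<:Array{<:BlasFloat}}, C, A, B, α, β) = :gemm_path
# ...while every other Array eltype (Dual, BigFloat, ...) gets a new method
# that would route back to batched_mul_generic! instead.
_batched_gemm!(::Type{<:Array}, C, A, B, α, β) = :generic_path

_batched_gemm!(Array{Float32}, nothing, nothing, nothing, 1, 0)   # => :gemm_path
_batched_gemm!(Array{BigFloat}, nothing, nothing, nothing, 1, 0)  # => :generic_path
```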

@bjarthur (Contributor, Author) · Nov 19, 2021

The only place this method (i.e. _batched_try_gemm!) is currently called is from the method immediately above (i.e. _batched_mul!() where {T<:BlasFloat}). Widening _batched_try_gemm! to types other than BlasFloat permits the proposed new _batched_mul!() where {T<:Float16} in FluxML/NNlibCUDA.jl#32 to call it too. I don't think there's any danger of weird number types getting where they shouldn't.
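Roughly, the companion method proposed in FluxML/NNlibCUDA.jl#32 looks something like the sketch below (see that PR for the real code; this only illustrates why the widened signature is needed):

```julia
using NNlib, CUDA

# Hypothetical sketch: CuArray{Float16} batches re-enter NNlib's gemm chain,
# which only works if _batched_try_gemm! no longer demands T<:BlasFloat.
NNlib._batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:CuArray{Float16}} =
    NNlib._batched_try_gemm!(DT, C, A, B, α, β)
```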

@mcabbott (Member)

Oh, now I see better what you're proposing. There are two jumps to the CUDA package, in order to allow Float16 only for CuArrays, not for Arrays, which is the desired behaviour. The first jump comes back to this package's chain of functions.

It does seem slightly weird to jump twice. Let me think a bit more, I'd be happier if there was exactly one point in the chain where dispatch cared about CuArrays.
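For readers following along, the two jumps look roughly like this (an informal outline, not exact code):

```julia
# _batched_mul!(::Type{<:CuArray{Float16}}, C, A, B, α, β)   # jump 1: method added in NNlibCUDA.jl#32
#  └─ _batched_try_gemm!(...)                                # back in NNlib (widened by this PR)
#      └─ _batched_gemm!(::Type{<:CuArray}, ...)             # jump 2: CuArray method in NNlibCUDA
#          └─ CUBLAS half-precision batched gemm             # from JuliaGPU/CUDA.jl#1080
```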

@bjarthur (Contributor, Author)

ping

@mcabbott (Member)

Sorry I dropped the ball here. I think we should do this, or at least I certainly didn't get around to thinking up a better way.

Could you perhaps add some comments explaining a bit what's going on? Having dispatch at two points, instead of just reading down the page & at some point jumping to CUDA, is one step trickier to read. Maybe the where {DT<:DenseArray{T}} where {T<:BlasFloat} = ... method can explain that there's another path through here for CuArray{Float16}?
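Something along these lines, perhaps (the comment wording is only a suggestion):

```julia
# DenseArrays of BLAS eltypes try to use batched gemm.
# NB: _batched_try_gemm! deliberately does not require T<:BlasFloat, because
# NNlibCUDA re-enters it with CuArray{Float16} (see FluxML/NNlibCUDA.jl#32).
_batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} =
    _batched_try_gemm!(DT, C, A, B, α, β)
```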
