implement dot_product_attention #455

Merged · 15 commits · Feb 3, 2023
Conversation

@CarloLucibello (Member) commented on Jan 3, 2023:

Factored out from FluxML/Flux.jl#2146

Fix #385

In the process, this also extends batched_mul to multiple batch dimensions. Fix #451, fix #391.

We may want to consider hooking into cuDNN in a later PR.
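For orientation, here is a minimal single-head sketch of the computation this PR adds, written against NNlib's public batched_mul and softmax. It is only an illustration of the idea, not the PR's actual code, and the helper name simple_attention is made up for this sketch.

    using NNlib: batched_mul, batched_transpose, softmax

    # Minimal single-head sketch (illustrative only, not the PR's implementation).
    # q, k, v have size (feature_dim, seq_len, batch); mask, if given, must be
    # broadcastable over the (kv_len, q_len, batch) logits.
    function simple_attention(q, k, v; mask=nothing)
        scale = inv(sqrt(eltype(q)(size(q, 1))))
        logits = batched_mul(batched_transpose(k), q) .* scale   # (kv_len, q_len, batch)
        if mask !== nothing
            neginf = typemin(eltype(logits))
            logits = ifelse.(mask, logits, neginf)
        end
        α = softmax(logits; dims=1)           # attention weights over the keys
        return batched_mul(v, α)              # (feature_dim, q_len, batch)
    end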

@mcabbott (Member) left a comment:

I think this mostly looks good, and is about the right level -- fairly simple-to-read implementation with no magic.

The thing to get right now seems to be: Does this function line up fairly nicely with what CUDA provides, so that an overload dot_product_attention(::CuArray, ...) can smoothly provide the same functionality? From a quick look

  1. It seems that it wants a weight array which, if I understand right, would correspond to steps before this function: https://github.com/JuliaGPU/CUDA.jl/blob/8a4cbdee50c716ff642eb3d9268f1a7ea4c29eb0/lib/cudnn/src/multiheadattn.jl#L20

  2. I'm not entirely sure what its masking options are; what am I missing, e.g. here: https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetAttnDescriptor

  3. Would need to read more slowly to see where it does and doesn't include bias, and dropout.

Maybe 1. points towards this function being the core of a larger one, which more closely matches the CUDA one? If so then it probably shouldn't have dropout, as that can trivially be composed by whatever calls it.

Lots of scattered comments below; you should probably ignore some of those on internal details, as they can be fixed later, but the overall interface is the question.

(Several inline review comments on src/attention.jl, since resolved.)
@@ -42,6 +44,16 @@ This will be copied, as doing so is faster than `batched_mul_generic!`.
Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
"""

function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
Member:
My vote is to make this an internal _batched_mul_4 or something for now. Partly because I think explaining what does and doesn't work becomes more complicated with this method. And that doesn't have to be solved to add attention.

Member Author:
It's a pity to not make things available. Maybe I can leave the previous docstring unchanged and add a new one for the new method?
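For reference, one way an N-dimensional batched_mul can reduce to the existing 3-dimensional method is by collapsing the trailing batch dimensions. This is a sketch of that idea under the assumption of matching batch sizes, not necessarily how the PR implements it, and the name batched_mul_nd is hypothetical.

    using NNlib: batched_mul

    # Sketch: collapse trailing batch dimensions, multiply, then reshape back.
    # Assumes x and y share the same trailing batch sizes (no broadcasting).
    function batched_mul_nd(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
        batch = size(x)[3:end]
        @assert batch == size(y)[3:end]
        z3 = batched_mul(reshape(x, size(x, 1), size(x, 2), :),
                         reshape(y, size(y, 1), size(y, 2), :))
        return reshape(z3, size(z3, 1), size(z3, 2), batch...)
    end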

@CarloLucibello (Member Author) commented on Jan 5, 2023:

Regarding cuDNN, the attention descriptor documentation says:

Weight matrices W_Q,i, W_K,i, W_V,i and W_O,i play similar roles, adjusting vector lengths in the q, K, V inputs and in the multi-head attention final output. The user can disable any or all projections by setting the qProjSize, kProjSize, vProjSize or oProjSize arguments to zero.

so we can deactivate the projections to match the API introduced in this PR. Conversely, we can add overloads here in the future if we also want to support the projections.

For dropout:

The attnDropoutDesc and postDropoutDesc arguments are descriptors that define two dropout layers active in the training mode. The first dropout operation, defined by attnDropoutDesc, is applied directly to the softmax output. The second dropout operation, specified by postDropoutDesc, alters the multi-head attention output, just before the point where residual connections are added.

Support for masking is only available in the form of attention windows, according to the following inputs to cudnnMultiHeadAttnForward:

loWinIdx[], hiWinIdx[]

Input. Two host integer arrays specifying the start and end indices of the attention window for each Q time-step. 
The start index in K, V sets is inclusive, and the end index is exclusive.

Bias in the attention logits doesn't seem to be supported.

@CarloLucibello (Member Author):

I think this is good to go

src/attention.jl (outdated), comment on lines 105 to 107:
if mask === :causal
    mask = make_causal_mask(logits)
end
Member:

I think a cleaner API would be to let the mask keyword be a function. The nothing case is mask = identity and the causal case is mask = make_causal_mask (which I feel should be just causal_mask to be succinct).

Is there a reason to construct the mask on the fly? The calling layer in Flux can probably make and store the mask once. Then the other option is to allow nothing or an array. Then the user passes in mask = causal_mask(ntoken).

@mcabbott (Member), Jan 9, 2023:

What is the function which you pass, in this proposal?

  • mask = identity means this is applied to the array.

  • mask = make_causal_mask means it constructs a boolean matrix.

Agree that constructing the same matrix every time seems a bit wasteful, although it's probably not a big cost; there are quite a few larger copies made in this thing.

With mask = identity, the usual masking could be causal_mask!, which is basically for i,j in ...; if i<j; x[i,j] = -Inf end; i.e. it just mutates the data array. This should be safe, as the gradient of batched_mul does not need the original values.
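A runnable version of this mutating idea might look as follows; the name causal_mask! and the (kv_len, q_len, batch) logits layout are assumptions for this sketch, not part of the PR.

    # Hypothetical mutating mask (sketch): with logits laid out as
    # (kv_len, q_len, batch), query j may only attend to keys i ≤ j,
    # so entries with i > j are set to -Inf before the softmax.
    function causal_mask!(logits::AbstractArray{T,3}) where T
        for b in axes(logits, 3), j in axes(logits, 2), i in axes(logits, 1)
            if i > j
                logits[i, j, b] = typemin(T)
            end
        end
        return logits
    end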

Member:

You're right, it shouldn't be identity; it should be trues_like, though I'd be okay with nothing in order to skip computing a mask at all.

My comment about constructing on the fly was not a performance concern. I just think it is more intuitive to pass in exactly the mask array I want used. It's an easier rule to remember and also scalable to whatever masking scheme is desired.

@mcabbott (Member), Jan 9, 2023:

The downside is that you have to make an array of the right size. If you have several layers and want the same scheme for each, then perhaps it's a pain, whereas a function like trues_like is told the size automatically.

(The implementation can branch on mask === trues_like to avoid work in the default case. We can also branch on the type of const causal_mask = triu ∘ trues_like if necessary.)

While encoding this as a Bool array makes some sense, it's also a little weird in that the implementation doesn't directly consume it. Perhaps better than my mutating idea above, we could modify softmax to take a mask argument and fuse it into the broadcast there, I think.

Member:

That's true, but generally the size of this matrix, which is (number of tokens) × (number of tokens), is known ahead of time. Even so, I agree that not needing to pass in this info is cleaner.

I mostly wanted to avoid "symbol switches" for arguments.

@mcabbott (Member), Jan 9, 2023:

Yes to avoiding symbols. I like this mask = trues_like proposal the best so far.

One question I haven't looked at is what format the cuDNN side is going to want.

Member:

Instead of saying mask is either an array or a callable, could we say it should be either an array or a marker type for which one can override some init_mask(x, k, v) function? This would allow us to shift the conditionals out of the attention functions, while still allowing relatively terse syntax like mask = CausalMask() when users don't want to precompute their own. You could imagine nice party tricks like passing mask = I.

@mcabbott (Member), Jan 9, 2023:

#460 is a go at this masked softmax idea.

With that, the default of no mask can in fact be mask = Returns(true) here, instead of trues_like. And the terse causal mask can be const causal_mask = triu ∘ trues_like, or a function equivalent to this (maybe it can be more efficient, not sure triu works on CuArrays). No conditionals required.

Edit: making #460 work on GPU too won't be just a few lines. But even without that, mask::Function = trues_like as the interface seems nice, instead of having to independently make something of the right size.
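To make the proposal concrete, here is a sketch of what such mask functions could look like. The names trues_like and causal_mask follow the discussion above, but these definitions are illustrative, not the merged API.

    using LinearAlgebra: triu

    # Bool array of trues with the same shape as x (or the given size).
    trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)

    # Causal mask for a (kv_len, q_len) slice of the logits:
    # mask[i, j] is true iff key i is visible to query j (i ≤ j).
    causal_mask(x::AbstractArray) = triu(trues_like(x, (size(x, 1), size(x, 2))))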

Member:

triu only works on AbstractMatrix, which is not sufficient for the attention.

Member Author:

For this first implementation, I prefer to keep it more minimal and just accept nothing or arrays (I will remove :causal).

@CarloLucibello (Member Author) commented on Jan 22, 2023:

:causal removed; now we only accept array masks or nothing. Good to go?

@ToucheSir (Member) left a comment:

Mostly LGTM and appears to give us enough surface area for cuDNN. Just a couple final questions:

src/attention.jl (outdated), comment on lines 101 to 108:
if bias !== nothing
    logits = logits .+ bias
end

if mask !== nothing
    neginf = typemin(eltype(logits))
    logits = ifelse.(mask, logits, neginf)
end
@ToucheSir (Member), Jan 24, 2023:

WDYT about making these internal methods that dispatch on nothing? That way there's zero control flow and Zygote is happy. The main question is whether the additional code and complexity introduced would be worth the compile-time and runtime reduction.
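For illustration, the dispatch-on-nothing version could look like this; the helper names are hypothetical and this is a sketch rather than the PR's code.

    # No branches: the nothing case is resolved by dispatch, so Zygote
    # never differentiates through a runtime conditional.
    apply_bias(logits, ::Nothing) = logits
    apply_bias(logits, bias) = logits .+ bias

    apply_mask(logits, ::Nothing) = logits
    apply_mask(logits, mask) = ifelse.(mask, logits, typemin(eltype(logits)))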

@jondeuce (Contributor) commented on Feb 3, 2023:

This looks great; I've long thought that there should be some basic attention mechanisms available here.

Before settling on the API for passing masks, I thought I should mention that "non-boolean" masking such as Attention with Linear Biases (ALiBi) has been used in some pretty substantial projects, e.g. BigScience's BLOOM, and it would be useful to have this option.

This would only be a slight generalization from what is currently implemented here (see Fig. 3 in the ALiBi paper), and could easily be incorporated by dispatching on the type of the mask:

  1. mask::AbstractArray{Bool} acts like ifelse.(mask, logits, -Inf)
  2. mask::AbstractArray{<:Real} acts like logits .+ mask

This would actually be almost exactly what PyTorch does as well, which is a nice bonus. From the torch.nn.MultiheadAttention docs:

Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.

Interesting that PyTorch appears to have the meaning of true and false reversed compared to what is implemented in this PR, but Keras has the same convention as this PR (see the docs and the code). Not sure which meaning is more natural 🤷‍♂️.
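A sketch of the type-based dispatch described in the comment above; the helper name mask_logits is hypothetical, and this is not what the PR ultimately adopts, since the PR keeps a separate additive bias keyword.

    # Boolean mask: disallowed positions get -Inf; real-valued mask: added to the logits.
    mask_logits(logits, mask::AbstractArray{Bool}) = ifelse.(mask, logits, typemin(eltype(logits)))
    mask_logits(logits, mask::AbstractArray{<:Real}) = logits .+ mask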

@CarloLucibello (Member Author):

@jondeuce the attention bias in the PR already does what you suggest. Should we collapse bias into mask or keep the two separate? I was inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html?highlight=attention for the interface.

@jondeuce (Contributor) commented on Feb 3, 2023:

@CarloLucibello Ahh, yes, that is a flexible approach too, and it clearly covers ALiBi-style masks. It's funny: I did notice the bias being added, but my brain did not register that bias would be an additive mask; instead I thought of learnable biases like in Dense layers.

The more I think about it, the more I like the way you have it. They are orthogonal and mutually compatible ways to apply a mask, and I don't think combining them and then doing different operations based on the mask type is conceptually any simpler anyway.
