implement dot_product_attention #455
Conversation
I think this mostly looks good, and is about the right level -- fairly simple-to-read implementation with no magic.

The thing to get right now seems to be: does this function line up fairly nicely with what CUDA provides, so that an overload `dot_product_attention(::CuArray, ...)` can smoothly provide the same functionality? From a quick look:
1. It seems that it wants a weight array which, if I understand right, would be steps before this function: https://github.com/JuliaGPU/CUDA.jl/blob/8a4cbdee50c716ff642eb3d9268f1a7ea4c29eb0/lib/cudnn/src/multiheadattn.jl#L20
2. I'm not entirely sure what its masking options are; what am I missing, e.g. here: https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnSetAttnDescriptor
3. Would need to read more slowly to see where it does and doesn't include bias, and dropout.

Maybe 1. points towards this function being the core of a larger one, which more closely matches the CUDA one? If so then it probably shouldn't have dropout, as that can trivially be composed by whatever calls it (see the sketch below).
Lots of random comments; probably you should ignore some on internal details, as these can be fixed. The overall interface is the question.
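For instance, a minimal sketch of that composition, assuming NNlib's `dropout` and the `(output, scores)` return convention discussed in this PR:

```julia
using NNlib

q = k = v = rand(Float32, 8, 10, 2)    # (features, seq_len, batch); shapes assumed
y, α = dot_product_attention(q, k, v)  # attention itself, no dropout inside
y = dropout(y, 0.1)                    # dropout composed by the caller instead
```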
```diff
@@ -42,6 +44,16 @@ This will be copied, as doing so is faster than `batched_mul_generic!`.
 Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
 and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
 """
 function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
```
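A usage sketch of the method being added here (shapes are assumptions; the trailing dimensions act as batch dimensions):

```julia
using NNlib

x = rand(Float32, 3, 4, 5, 6)  # 3×4 matrices arranged in a 5×6 batch
y = rand(Float32, 4, 2, 5, 6)  # 4×2 matrices with matching batch dimensions
z = batched_mul(x, y)          # expected size (3, 2, 5, 6): one matmul per batch entry
```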
My vote is to make this an internal `_batched_mul_4` or something for now. Partly because I think explaining what does and doesn't work becomes more complicated with this method. And that doesn't have to be solved to add attention.
It's a pity to not make things available. Maybe I can leave the previous docstring unchanged and add a new one for the new method?
Regarding CUDNN, the descriptor there allows disabling the built-in input/output projections, so we can deactivate the projections to match the API introduced in this PR. Conversely, we can add overloads here in the future if we want to also consider the projections. Dropout is covered by cuDNN's own dropout descriptors.

The support for masking is only in the form of windows, according to the `loWinIdx` and `hiWinIdx` inputs to `cudnnMultiHeadAttnForward`. Bias in the attention logits doesn't seem to be supported.
I think this is good to go.
src/attention.jl
```julia
if mask === :causal
    mask = make_causal_mask(logits)
end
```
I think a cleaner API would be to let the `mask` keyword be a function. The `nothing` case is `mask = identity` and the causal case is `mask = make_causal_mask` (which I feel should be just `causal_mask`, to be succinct).

Is there a reason to construct the mask on the fly? The calling layer in Flux can probably make and store the mask once. Then the other option is to allow `nothing` or an array, and the user passes in `mask = causal_mask(ntoken)`.
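A sketch of that second option, where `causal_mask` is a hypothetical constructor taking the token count (the index convention depends on the logits layout):

```julia
causal_mask(n) = [ki <= qi for ki in 1:n, qi in 1:n]  # keys may attend up to the query position

ntoken = 10
q = k = v = rand(Float32, 8, ntoken, 2)  # (features, tokens, batch); shapes assumed
y, α = dot_product_attention(q, k, v; mask = causal_mask(ntoken))
```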
What is the function which you pass, in this proposal?

- `mask = identity` means this is applied to the array.
- `mask = make_causal_mask` means it constructs a boolean matrix.

Agree that constructing the same matrix every time seems a bit wasteful, although probably not a big cost; there are quite a few larger copies made in this thing.

With `mask = identity`, the usual masking could be `causal_mask!`, which is basically `for i,j in ...; if i<j; x[i,j] = -Inf; end;`, i.e. it just mutates the data array. This should be safe, as the gradient of `batched_mul` does not need the original values.
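Made runnable, that sketch could look like this (a hypothetical helper; whether `i < j` or `i > j` is masked depends on the logits layout):

```julia
function causal_mask!(x::AbstractMatrix)
    neginf = typemin(eltype(x))  # -Inf for float element types
    for j in axes(x, 2), i in axes(x, 1)
        i < j && (x[i, j] = neginf)  # overwrite disallowed entries in place
    end
    return x
end
```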
You're right, it shouldn't be `identity`, it should be `trues_like`, though I'd be okay with `nothing` in order to skip computing a mask at all.

My comment about constructing on the fly was not a performance concern. I just think it is more intuitive to pass in exactly the mask array I want used. It's an easier rule to remember and also scalable to whatever masking scheme is desired.
The downside is that you have to make an array of the right size. If you have several layers and want the same scheme for each, then perhaps it's a pain. Whereas a function like `trues_like` is told the size automatically.

(The implementation can branch on `mask === trues_like` to avoid work in the default case. We can also branch on the type of `const causal_mask = triu ∘ trues_like` if necessary.)

While encoding this as a bool array makes some sense, it's also a little weird in that the implementation doesn't directly consume it. Maybe better than my mutating idea above, we can modify `softmax` to take a mask argument and fuse it into the broadcast there, I think.
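A minimal sketch of that masked-softmax idea (the signature is hypothetical; the actual attempt is in #460):

```julia
using NNlib: softmax

# Send disallowed logits to -Inf before normalizing, so they receive zero weight;
# fusing this into softmax's own broadcast would avoid the temporary array.
masked_softmax(x, mask; dims = 1) = softmax(ifelse.(mask, x, typemin(eltype(x))); dims)
```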
That's true, but generally the size of this matrix, which is (# of tokens) × (# of tokens), is known ahead of time. Even so, I agree that not needing to pass in this info is cleaner. I mostly wanted to avoid "symbol switches" for arguments.
Yes to avoiding symbols. I like this `mask = trues_like` proposal the best so far.

One question I haven't looked at is what format the CUDNN thing is going to want.
Instead of saying `mask` is either an array or callable, could we say it should be either an array or a marker type for which one can override some `init_mask(x, k, v)` function? This would allow us to shift the conditionals out of the attention functions, while still allowing for relatively terse syntax like `mask = CausalMask()` when users don't want to precompute their own. You could imagine nice party tricks like passing `mask = I`.
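A minimal sketch of that proposal; `init_mask` and `CausalMask` are the hypothetical names from this comment, not existing NNlib API, and the exact argument list is a guess:

```julia
struct CausalMask end

init_mask(m::AbstractArray, x, k, v) = m  # a precomputed mask passes straight through
init_mask(::CausalMask, x, k, v) =
    [ki <= qi for ki in axes(k, 2), qi in axes(x, 2)]  # built from the inputs' sizes

# dot_product_attention(q, k, v; mask = CausalMask()) would then call
# init_mask(mask, q, k, v) internally, keeping conditionals out of the hot path.
```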
#460 is a go at this masked softmax idea. With that, the default of no mask can in fact be `mask = Returns(true)` here, instead of `trues_like`. And the terse causal mask can be `const causal_mask = triu ∘ trues_like`, or a function equivalent to this (maybe it can be more efficient; not sure `triu` works on CuArrays). No conditionals required.

Edit: making #460 work on GPU too won't be just a few lines. But even without that, `mask::Function = trues_like` as the interface seems nice, instead of having to independently make something the right size.
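For reference, a sketch of the two helpers named above (hypothetical definitions; neither exists in NNlib):

```julia
using LinearAlgebra: triu

trues_like(x::AbstractArray, sz = size(x)) = trues(sz)  # all-true mask matching x's shape
const causal_mask = triu ∘ trues_like                   # keep only the upper triangle
```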
`triu` only works on `AbstractMatrix`, which is not sufficient for the attention.
For this first implementation, I prefer to keep it more minimalistic and just accept `nothing` or arrays (I will remove `:causal`).
Mostly LGTM and appears to give us enough surface area for cuDNN. Just a couple final questions:
src/attention.jl
```julia
if bias !== nothing
    logits = logits .+ bias
end

if mask !== nothing
    neginf = typemin(eltype(logits))
    logits = ifelse.(mask, logits, neginf)
end
```
WDYT about making these internal methods which dispatch on `nothing`? That way there's zero control flow and Zygote is happy. The main question is whether the additional code + complexity introduced would be worth the compile and runtime reduction.
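A sketch of that suggestion, with hypothetical helper names:

```julia
apply_bias(logits, ::Nothing) = logits
apply_bias(logits, bias) = logits .+ bias

apply_mask(logits, ::Nothing) = logits
apply_mask(logits, mask) = ifelse.(mask, logits, typemin(eltype(logits)))

# logits = apply_mask(apply_bias(logits, bias), mask) then replaces both `if`
# blocks, so Zygote never sees data-dependent control flow.
```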
This looks great; I've long thought that there should be some basic attention mechanisms available here. Before settling on the API for passing masks, I thought I should mention that "non-boolean" masking such as Attention with Linear Biases (ALiBi) has been used for some pretty substantial projects, e.g. BigScience's BLOOM, and it would be useful to have this option. This would only be a slight generalization of what is currently implemented here (see Fig. 3 in the ALiBi paper), and could easily be incorporated by dispatching on the type of the mask:
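For example, a hedged sketch with a hypothetical `apply_mask` helper:

```julia
# Boolean mask: exclude positions by sending their logits to -Inf.
apply_mask(logits, mask::AbstractArray{Bool}) =
    ifelse.(mask, logits, typemin(eltype(logits)))

# Numeric mask: add it to the logits, which covers ALiBi's linear biases.
apply_mask(logits, mask::AbstractArray{<:Real}) = logits .+ mask
```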
This would actually be almost exactly what PyTorch does as well, which is a nice bonus.

Interesting that PyTorch appears to have the meaning of the boolean mask reversed (there `true` marks positions to exclude).
@jondeuce the attention bias in the PR already does what you suggest. Should we collapse `mask` and `bias` into a single argument?
@CarloLucibello Ahh, yes that is a flexible approach too, and clearly covers the ALiBi case. The more I think about it, the more I like the way you have it. They are orthogonal and mutually compatible ways to apply a mask, and I don't think combining them and then doing different operations based on the mask type is conceptually any simpler anyway.
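A sketch of that orthogonal usage under the PR's keywords (the shapes, slope value, and `bias` keyword spelling are all assumptions):

```julia
qlen, klen = 10, 12
q = rand(Float32, 8, qlen, 2); k = v = rand(Float32, 8, klen, 2)

slope = 0.5f0  # hypothetical per-head ALiBi slope
bias = Float32[-slope * abs(ki - qi) for ki in 1:klen, qi in 1:qlen]  # (kv_len, q_len)
mask = [ki <= qi for ki in 1:klen, qi in 1:qlen]                      # boolean causal mask

y, α = dot_product_attention(q, k, v; bias = bias, mask = mask)
```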
Factored out from FluxML/Flux.jl#2146.

Fix #385.

In the process, this also extends `batched_mul` to multiple batch dimensions. Fix #451, fix #391.

We may want to consider hooking up cudnn in a later PR.
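A closing usage sketch of the new function (shapes and keyword names as discussed in the thread; the merged API may differ):

```julia
using NNlib

q = rand(Float32, 8, 10, 2)  # (features, q_len, batch)
k = rand(Float32, 8, 12, 2)  # (features, kv_len, batch)
v = rand(Float32, 8, 12, 2)

y, α = dot_product_attention(q, k, v; nheads = 2)  # features split across 2 heads
size(y)  # (8, 10, 2); α holds the attention scores
```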