-
Notifications
You must be signed in to change notification settings - Fork 633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
5/N Ragged Inference - Move triton_v2 ragged inference code to new experimental
directory
#189
5/N Ragged Inference - Move triton_v2 ragged inference code to new experimental
directory
#189
Conversation
thanks for the PR @nottombrown, having a look asap ! @fmassa @dianaml0 I'll try to put up a PR this week end to move the mem efficient attention to triton 2 and move it to /experimental also, and same could be done with the favor specific kernel which was breaking triton 1.0 |
Great! I'll make a separate branch for further changes so as not to collide with this one! |
experimental/ragged_inference_v2/ragged_inference_v2/garbage_pad_ragged_acts.py
Outdated
Show resolved
Hide resolved
see nottombrown#2 for minor changes, I hope that works |
interesting, on a ampere laptop some tests do not pass, and others segfault. How HW ready is triton v2.0 @ptillet ? Note that it could be something else, wrong gcc or Cuda version, but I've tried a few |
Oh this is strange. What are the nature of the failures? It's possible that some tests are not configured to skip configs that require too much shared memory |
Not a shared memory size issue, segfault at JIT time in Triton/compile. (No
errors when installing or importing Triton). Only happens for some tests,
not all of them
…On Sat, Jan 22, 2022 at 11:28 AM Philippe Tillet ***@***.***> wrote:
Oh this is strange. What are the nature of the failures? It's possible
that some tests are not configured to skip configs that require too much
shared memory
—
Reply to this email directly, view it on GitHub
<#189 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAXOGXLVBNFVH2PRZOJIHGTUXMANZANCNFSM5MOYADSA>
.
You are receiving this because your review was requested.Message ID:
***@***.***>
|
Gcc 9 and 10, cuda 11.5 by the way
On Sat, Jan 22, 2022 at 11:39 AM Benjamin Lefaudeux <
***@***.***> wrote:
… Not a shared memory size issue, segfault at JIT time in Triton/compile.
(No errors when installing or importing Triton). Only happens for some
tests, not all of them
On Sat, Jan 22, 2022 at 11:28 AM Philippe Tillet ***@***.***>
wrote:
> Oh this is strange. What are the nature of the failures? It's possible
> that some tests are not configured to skip configs that require too much
> shared memory
>
> —
> Reply to this email directly, view it on GitHub
> <#189 (comment)>,
> or unsubscribe
> <https://github.com/notifications/unsubscribe-auth/AAXOGXLVBNFVH2PRZOJIHGTUXMANZANCNFSM5MOYADSA>
> .
> You are receiving this because your review was requested.Message ID:
> ***@***.***>
>
|
[ragged attention] suggested minor changes (will update the other PR if accepted)
return scores_out.reshape((n_ctx_q, n_ctx_k)) | ||
|
||
|
||
def ragged_qk_dotprod( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious to get some perf numbers on that one, even if probably early
|
||
bytes_in_keys_per_seq = n_key_ctx_per_seq * d_model_per_gpu * 2 # 2 from bf16 | ||
bytes_in_keys_total = bytes_in_keys_per_seq * n_seqs | ||
hbm_bw_bytes_per_gpu = 1555e9 # 1.5TB/s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not for this PR, but I would not do this (compare a number to a theoretical one), (a) it's HW specific -how does this test relate to another accelerator ?- and (b) does some suppositions on what's going on -here the data format for instance-.
For other benchmarks we calcul the user-facing throughput, it's also for instance what Phil does here, in that you consider the implementation as a black box, and you count the bytes going in and out (at best for instance you read the seqs and write the attention matrix, the rest is history). It's mostly what happens in this code already, but I would
-
count the BW with num elem * elem_size() (and not suppose bfloat16, would be nice to compare across types actually, it can give an idea on how the kernels are compute or bandwidth bound, at least it helped me on other tasks)
-
test for a bunch of sizes, from experience there are a lot of possible holes / scheduling, and testing with one size only is like russian roulette (see for instance, there are helpers for this in the repo if that helps)
I can do that later on if this ends up running locally, brain dump here :)
|
||
# Define indices ranges, we follow the triton convention of prefixing | ||
# with "r" to denote a range like "rq" is the range for queries below | ||
rq = in_q_token_offset + tl.arange(0, BLOCK_Q) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice, "all" (well, a lot of) the magic is there, looks like nothing but super well done I think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks very good to me, I especially like the wrap for the ragged attentions and how they play well with the kernel. Great comments also, will be nice for newcomers ! Thanks a bunch @nottombrown
85a07a0
to
450b178
Compare
450b178
to
ffb1d89
Compare
ragged_acts_offset_ptr = ragged_acts_offset_per_seq_ptr + seq_idx | ||
ragged_acts_offset = tl.load(ragged_acts_offset_ptr) | ||
|
||
# Create a mask to guard memory operations against out-of-bounds accesses |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this unavoidable? How much more memory does this consume (if any)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's part of the Triton API (masked load), it should not take any registers really. This is all in SoC space, so nothing visible in the RAM, worst case it's a throwaway mask on the actual gpu chip
# We just use one program per n_ctx position for simplicity | ||
assert d_model >= 128, f"bad {d_model=}" | ||
assert d_model <= 8 * 1024, f"bad {d_model=}" | ||
assert d_model % 32 == 0, f"bad {d_model=}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why 32 and not 64? Where did these requirements come from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it comes from BLOCK_K = 32, a scheduling/tiling constraint for the matmuls. This could be relaxed if we were to mask over that dimension in the kernel
for n_ctx in n_ctx_per_kv_cache: | ||
for idx_into_seq in range(max_n_ctx): | ||
if idx_into_seq < n_ctx: | ||
indices_list.append(ragged_idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering why we need O(n) append calls here...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please submit a PR :D yes this is not optimal indeed
def get_all_configs(): | ||
return [ | ||
# basic configs for compute-bound matmuls | ||
triton.Config( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where do these magic numbers / configs come from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comes from here, empirical good values for Ampere GPUs, architecture dependent but Triton does navigate around some specifics thanks for all these scheduling options https://github.com/openai/triton/blob/v2.0/python/triton/ops/matmul.py#L35
|
||
# In einsum notation, the tl.dot does: qd,dk->qk | ||
# This should use tensorcores, so the inputs might be fp16, but the outputs | ||
# and all the internal accumulators are fp32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do internal accumulators default to fp32 or tf32 on A100s?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tl.dot() always return fp32, it's a very good question, it must be documented somewhere on nvidia's side
Just asking a bunch of questions for my own learnings - feel free to ignore! |
no worries, very good questions I think, I tried to give some insights but @nottombrown could probably add a little. Note that the code changed a tiny bit since this PR, and some updates were planned |
Since this code using
triton_v2
it's currently incompatible with our CI pipeline. This PR moves it to a separate package that can avoid breaking CI while still letting imports work correctly.Once triton v2 is stable then we can upgrade
xformers
core to 2.0 and pull theexperimental
code into the core package