Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5/N Ragged Inference - Move triton_v2 ragged inference code to new experimental directory #189

Merged
merged 11 commits into from
Jan 24, 2022

Conversation

nottombrown
Copy link
Contributor

Since this code using triton_v2 it's currently incompatible with our CI pipeline. This PR moves it to a separate package that can avoid breaking CI while still letting imports work correctly.

Once triton v2 is stable then we can upgrade xformers core to 2.0 and pull the experimental code into the core package

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 21, 2022
@blefaudeux
Copy link
Contributor

thanks for the PR @nottombrown, having a look asap ! @fmassa @dianaml0 I'll try to put up a PR this week end to move the mem efficient attention to triton 2 and move it to /experimental also, and same could be done with the favor specific kernel which was breaking triton 1.0

@nottombrown
Copy link
Contributor Author

Great! I'll make a separate branch for further changes so as not to collide with this one!

@blefaudeux
Copy link
Contributor

see nottombrown#2 for minor changes, I hope that works

@blefaudeux
Copy link
Contributor

blefaudeux commented Jan 22, 2022

interesting, on a ampere laptop some tests do not pass, and others segfault. How HW ready is triton v2.0 @ptillet ? Note that it could be something else, wrong gcc or Cuda version, but I've tried a few

@ptillet
Copy link

ptillet commented Jan 22, 2022

Oh this is strange. What are the nature of the failures? It's possible that some tests are not configured to skip configs that require too much shared memory

@blefaudeux
Copy link
Contributor

blefaudeux commented Jan 22, 2022 via email

@blefaudeux
Copy link
Contributor

blefaudeux commented Jan 22, 2022 via email

return scores_out.reshape((n_ctx_q, n_ctx_k))


def ragged_qk_dotprod(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious to get some perf numbers on that one, even if probably early


bytes_in_keys_per_seq = n_key_ctx_per_seq * d_model_per_gpu * 2 # 2 from bf16
bytes_in_keys_total = bytes_in_keys_per_seq * n_seqs
hbm_bw_bytes_per_gpu = 1555e9 # 1.5TB/s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not for this PR, but I would not do this (compare a number to a theoretical one), (a) it's HW specific -how does this test relate to another accelerator ?- and (b) does some suppositions on what's going on -here the data format for instance-.

For other benchmarks we calcul the user-facing throughput, it's also for instance what Phil does here, in that you consider the implementation as a black box, and you count the bytes going in and out (at best for instance you read the seqs and write the attention matrix, the rest is history). It's mostly what happens in this code already, but I would

  • count the BW with num elem * elem_size() (and not suppose bfloat16, would be nice to compare across types actually, it can give an idea on how the kernels are compute or bandwidth bound, at least it helped me on other tasks)

  • test for a bunch of sizes, from experience there are a lot of possible holes / scheduling, and testing with one size only is like russian roulette (see for instance, there are helpers for this in the repo if that helps)

I can do that later on if this ends up running locally, brain dump here :)


# Define indices ranges, we follow the triton convention of prefixing
# with "r" to denote a range like "rq" is the range for queries below
rq = in_q_token_offset + tl.arange(0, BLOCK_Q)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, "all" (well, a lot of) the magic is there, looks like nothing but super well done I think

Copy link
Contributor

@blefaudeux blefaudeux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks very good to me, I especially like the wrap for the ragged attentions and how they play well with the kernel. Great comments also, will be nice for newcomers ! Thanks a bunch @nottombrown

@blefaudeux blefaudeux force-pushed the tom/experimental branch 2 times, most recently from 85a07a0 to 450b178 Compare January 24, 2022 07:29
@blefaudeux blefaudeux merged commit 02e5abc into facebookresearch:main Jan 24, 2022
@blefaudeux blefaudeux mentioned this pull request Jan 24, 2022
10 tasks
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022
ragged_acts_offset_ptr = ragged_acts_offset_per_seq_ptr + seq_idx
ragged_acts_offset = tl.load(ragged_acts_offset_ptr)

# Create a mask to guard memory operations against out-of-bounds accesses

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this unavoidable? How much more memory does this consume (if any)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's part of the Triton API (masked load), it should not take any registers really. This is all in SoC space, so nothing visible in the RAM, worst case it's a throwaway mask on the actual gpu chip

# We just use one program per n_ctx position for simplicity
assert d_model >= 128, f"bad {d_model=}"
assert d_model <= 8 * 1024, f"bad {d_model=}"
assert d_model % 32 == 0, f"bad {d_model=}"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 32 and not 64? Where did these requirements come from?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it comes from BLOCK_K = 32, a scheduling/tiling constraint for the matmuls. This could be relaxed if we were to mask over that dimension in the kernel

for n_ctx in n_ctx_per_kv_cache:
for idx_into_seq in range(max_n_ctx):
if idx_into_seq < n_ctx:
indices_list.append(ragged_idx)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering why we need O(n) append calls here...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please submit a PR :D yes this is not optimal indeed

def get_all_configs():
return [
# basic configs for compute-bound matmuls
triton.Config(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do these magic numbers / configs come from?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comes from here, empirical good values for Ampere GPUs, architecture dependent but Triton does navigate around some specifics thanks for all these scheduling options https://github.com/openai/triton/blob/v2.0/python/triton/ops/matmul.py#L35


# In einsum notation, the tl.dot does: qd,dk->qk
# This should use tensorcores, so the inputs might be fp16, but the outputs
# and all the internal accumulators are fp32

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do internal accumulators default to fp32 or tf32 on A100s?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tl.dot() always return fp32, it's a very good question, it must be documented somewhere on nvidia's side

@suchenzang
Copy link

Just asking a bunch of questions for my own learnings - feel free to ignore!

@blefaudeux
Copy link
Contributor

Just asking a bunch of questions for my own learnings - feel free to ignore!

no worries, very good questions I think, I tried to give some insights but @nottombrown could probably add a little. Note that the code changed a tiny bit since this PR, and some updates were planned

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants