Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flash Decode GQA (and MQA) Improvements (Round 1) #12739

Merged
merged 7 commits into from
Sep 20, 2024
Merged

Conversation

caixunshiren
Copy link
Contributor

@caixunshiren caixunshiren commented Sep 16, 2024

Ticket

This PR contains the round 1 improvements outlined in #12330 :

  • Support transpose_q, which both q of shape [1 x qh x b x h] and [1 x b x qh x h] are supported as input
  • Support GQA on a shared cache, which KV of shape [1 x kh x s x h] is supported
  • Add support for tensor indices for GQA

FYI @sraizada-tt @cglagovichTT

Post commit pipeline: https://github.com/tenstorrent/tt-metal/actions/runs/10956618476/job/30423043917

@caixunshiren
Copy link
Contributor Author

caixunshiren commented Sep 18, 2024

@caixunshiren caixunshiren changed the title Flash Decode GQA (and MQA) Improvements Flash Decode GQA (and MQA) Improvements (Round 1) Sep 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kernels kernels, such as hlks or llks or below llama3 LLM_feature P1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants