Skip to content

Commit

Permalink
#9370: added support for B*n_kv_heads > num_cores case for sdpa decode
Browse files (browse the repository at this point in the history)
  • Loading branch information
caixunshiren committed Oct 1, 2024
1 parent db33f97 commit 36efba7
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 390 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -459,7 +459,7 @@ def run_test_sdpa_decode_single_iter(
# [32, 8, 1, 32768, 128, (8, 6), True, False], # Llama2-70B
# [4, 32, 8, 32768, 128, (8, 8), True, False], # llama 3.1 8b
[4, 32, 8, 32768, 128, (8, 8), True, True], # llama 3.1 8b
# [4, 32, 8, 32768, 128, (8, 8), False, False], # llama 3.1 8b
[32, 32, 8, 8192, 128, (8, 8), True, False], # llama 3.1 8b
# [4, 16, 4, 32768, 128, (8, 8), False, False], # llama 3.1 8b
),
)
Expand Down Expand Up @@ -722,7 +722,9 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc
(
[32, 8, 1, 32768, 128, (8, 6), True], # Llama2-70B
[4, 32, 8, 32768, 128, (8, 8), True], # llama 3.1 8b
[4, 16, 4, 32768, 128, (8, 8), True],
# [4, 16, 4, 32768, 128, (8, 8), True],
# [32, 32, 8, 4096, 128, (8, 8), True], # llama 3.1 8b
[8, 16, 4, 4096, 128, (8, 2), True], # llama 3.1 8b N300
# [1, 8, 1, 32768, 128, (8, 1), True], # Llama2-70B
# [16, 8, 1, 32768, 128, (8, 6), False, False], # Llama2-70B
# [8, 8, 1, 32768, 128, (8, 6), True, False], # Llama2-70B
Expand Down
Loading

0 comments on commit 36efba7

Please sign in to comment.