#10936: Enable llama tg unit test
johanna-rock-tt committed Sep 5, 2024
1 parent dd62884 commit 77a1613
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
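The bulk of the additions follow one pattern: each test now calls torch.manual_seed(1234) before generating its random inputs, so the randn/rand tensors are reproducible across runs. A minimal, standalone illustration of that effect (not part of the diff):

import torch

torch.manual_seed(1234)
a = torch.randn(1, 1, 32, 8192)
torch.manual_seed(1234)
b = torch.randn(1, 1, 32, 8192)
assert torch.equal(a, b)  # same seed and same draw order -> identical tensors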
@@ -31,6 +31,8 @@
indirect=True,
)
def test_galaxy_matmul_1d_fracture(mesh_device):
torch.manual_seed(1234)

act_pt = torch.randn(1, 1, 32, 8192)
weights_pt = torch.randn(1, 1, 8192, 32768)
act = ttnn.from_torch(
@@ -89,14 +91,14 @@ def test_galaxy_matmul_1d_fracture(mesh_device):
pytest.param(128, 52 * 1024, 16 * 1024, ttnn.bfloat8_b, id="Llama3-405B_prefill_seq128_FF2"),
pytest.param(256, 16 * 1024, 52 * 1024, ttnn.bfloat4_b, id="Llama3-405B_prefill_seq256_FF1"),
pytest.param(256, 52 * 1024, 16 * 1024, ttnn.bfloat8_b, id="Llama3-405B_prefill_seq256_FF2"),
# pytest.param(
# 512, 16 * 1024, 52 * 1024, ttnn.bfloat4_b, id="Llama3-405B_prefill_seq512_FF1"
# ), # PCC check failed, PCC: -0.00014127559109112134, see issue 10936
pytest.param(512, 16 * 1024, 52 * 1024, ttnn.bfloat4_b, id="Llama3-405B_prefill_seq512_FF1"),
pytest.param(512, 52 * 1024, 16 * 1024, ttnn.bfloat8_b, id="Llama3-405B_prefill_seq512_FF2"),
],
)
# Llama FF1, FF2, FF3 in MLP with dram interleaved weights
def test_galaxy_matmul_2d_fracture(M, K, N, weights_dtype, mesh_shape, mesh_device):
torch.manual_seed(1234)

act_pt = torch.randn(1, 1, M, K)
weights_pt = torch.randn(1, 1, K, N)

@@ -146,6 +148,16 @@ def test_galaxy_matmul_2d_fracture(M, K, N, weights_dtype, mesh_shape, mesh_devi
act,
weights,
dtype=ttnn.bfloat16,
# program_config=ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
# compute_with_storage_grid_size=(8,8),
# in0_block_w=1,
# out_subblock_h=1,
# out_subblock_w=1,
# per_core_M=2,
# per_core_N=26,
# transpose_mcast=False,
# fused_activation=None,
# ), # if M == 512 and N == 52 * 1024 else None, # use specific ProgramConfig to avoid PCC issue
compute_kernel_config=compute_kernel_lofi,
memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if M == 32 else ttnn.DRAM_MEMORY_CONFIG,
)
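For reference, a sketch (not part of this commit) of how the commented-out ProgramConfig above could be applied only to the previously failing M == 512, N == 52 * 1024 FF1 case, following the inline comment; it assumes M and N come from the test parametrization and that ttnn.matmul falls back to its default when program_config is None:

program_config = (
    ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
        compute_with_storage_grid_size=(8, 8),
        in0_block_w=1,
        out_subblock_h=1,
        out_subblock_w=1,
        per_core_M=2,
        per_core_N=26,
        transpose_mcast=False,
        fused_activation=None,
    )
    if M == 512 and N == 52 * 1024
    else None  # other shapes keep the default matmul path
)
# ...then pass program_config=program_config to the ttnn.matmul call above.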
Expand All @@ -171,6 +183,8 @@ def test_galaxy_matmul_2d_fracture(M, K, N, weights_dtype, mesh_shape, mesh_devi
)
# Llama FF1, FF2, FF3 in MLP with dram sharded weights
def test_galaxy_matmul_2d_fracture_dram_sharded(M, K, N, weights_dtype, mesh_shape, mesh_device):
torch.manual_seed(1234)

act_pt = torch.randn(1, 1, M, K)
weights_pt = torch.randn(1, 1, K, N)

@@ -268,6 +282,8 @@ def test_galaxy_matmul_2d_fracture_dram_sharded(M, K, N, weights_dtype, mesh_sha
)
# Llama FF1 * FF3 in MLP
def test_galaxy_eltwise_mul_2d_fracture(M, N, mesh_shape, mesh_device):
torch.manual_seed(1234)

FF1_pt = torch.randn(1, 1, M, N)
FF3_pt = torch.randn(1, 1, M, N)

@@ -316,6 +332,8 @@ def test_galaxy_eltwise_mul_2d_fracture(M, N, mesh_shape, mesh_device):
)
# Llama residual add
def test_galaxy_eltwise_add(M, N, mesh_device):
torch.manual_seed(1234)

residual_pt = torch.randn(1, 1, M, N)
attn_output_pt = torch.randn(1, 1, M, N)

@@ -396,6 +414,8 @@ def test_galaxy_eltwise_add(M, N, mesh_device):
)
# Llama attention matmuls
def test_galaxy_attn_matmul(M, N, head_dim, num_heads, mesh_shape, mesh_device):
torch.manual_seed(1234)

act_pt = torch.randn(1, 1, M, N)
weights_pt = torch.randn(1, 1, N, head_dim * num_heads)

@@ -491,6 +511,8 @@ def num_to_corerange(total_max_cores):
def test_galaxy_nlp_create_heads_decode(
batch, seq_len, head_dim, n_local_heads, n_local_kv_heads, is_multicore, mesh_device
):
torch.manual_seed(1234)

total_heads = n_local_heads + n_local_kv_heads * 2
qkv_heads_pt = torch.rand(1, seq_len, batch, head_dim * total_heads)
total_max_cores = total_heads * head_dim // 32 if is_multicore else 1 # 40 for llama3-70B; 72 for llama3-405B
@@ -576,6 +598,8 @@ def test_galaxy_nlp_create_heads_decode(
)
# Llama rotary matmul (decode only)
def test_galaxy_rotary_matmul(batch, seq_len, head_dim, n_local_heads, n_local_kv_heads, mesh_device):
torch.manual_seed(1234)

q_heads_pt = torch.rand(
seq_len, batch, max(n_local_heads, 32), head_dim
) # Unpad batch=32 to 8 for each column group
@@ -694,6 +718,8 @@ class TestUpdateCache:
def test_fill_cache(
self, seq_len, head_dim, max_seq_len, num_users, num_heads, input_dtype, mesh_device, use_program_cache
):
torch.manual_seed(1234)

cache_dtype = input_dtype
input_shape = [1, num_heads, seq_len, head_dim]
cache_shape = [num_users, num_heads, max_seq_len, head_dim]
@@ -762,6 +788,8 @@ def test_update_cache_decode(
mesh_device,
use_program_cache,
):
torch.manual_seed(1234)

if num_users > 32 or (num_users + batch_offset) > 32:
pytest.skip("Batch offset is only used when num_users < 32 and batch_offset + num_users <= 32")
input_shape = [num_users, num_heads, 1, head_dim]
@@ -1023,6 +1051,8 @@ def test_sdpa_decode_sharded(mesh_device, b, nh, nkv, s, d, dtype, grid_size, q_
def test_galaxy_nlp_concat_heads_decode(
batch, seq_len, head_dim, n_local_heads, n_local_kv_heads, padded_local_heads, mesh_device
):
torch.manual_seed(1234)

concat_head_input = torch.rand(seq_len, batch, padded_local_heads, head_dim)

mesh_shape = ttnn.CoreRangeSet({num_to_corerange(batch)})
@@ -1083,6 +1113,8 @@ def rmsnorm(x, gamma, beta, eps):
ids=["Llama3-70B-decode", "Llama3-405B-decode"],
)
def test_galaxy_layernorm(M, N, mesh_device):
torch.manual_seed(1234)

layernorm_input = torch.rand(1, 1, M, N) * 2 - 0.95
norm_weights = torch.rand(1, 1, N // 32, 32) * 2 - 1
norm_eps = 1e-05
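The surrounding hunk context references a rmsnorm(x, gamma, beta, eps) helper as the PyTorch reference for test_galaxy_layernorm. As an assumption about its general shape (not code from this file), a standard RMSNorm reference with broadcastable gamma and beta looks like:

import torch

def rmsnorm_reference(x, gamma, beta, eps):
    # Normalize by the root-mean-square over the last dimension, then scale and shift.
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    return x / rms * gamma + beta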
