Yieldthought/llama31 8b/ttembed #12560

Merged: 3 commits, Sep 12, 2024
34 changes: 23 additions & 11 deletions models/demos/wormhole/llama31_8b/demo/demo.py
@@ -74,7 +74,7 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
     # This module requires the env paths above for CI runs
     from models.demos.wormhole.llama31_8b.tt.model_config import TtModelArgs
 
-    embed_on_device = False
+    embed_on_device = True
     dtype = ttnn.bfloat8_b
 
     # Load model args, weights, and tokenizer
@@ -155,15 +155,18 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
         iteration_time_start = time()
         curr_pos = generation_start_pos + iteration
 
-        # Prepare inputs for decode mode (rotary embeddings, attention mask, padding)
-        # TODO Move the attn mask to device
-        decode_input, current_pos = prepare_inputs_ttnn(
-            tt_decode_input,
-            curr_pos,
-            model_args.dim,
-            model_args.sliding_window,
-            tt_model.device,
-        )
+        if embed_on_device and iteration > 0:
+            current_pos = curr_pos
+            decode_input = tt_decode_input
+        else:
+            # Prepare inputs for decode mode
+            decode_input, current_pos = prepare_inputs_ttnn(
+                tt_decode_input,
+                curr_pos,
+                model_args.dim,
+                model_args.sliding_window,
+                tt_model.device,
+            )
 
         # Run ttnn llama model
         tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
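This hunk is the core of the change: once the previous step's token has been embedded on device (by `tt_embd`, further down), the result is already in the layout the model consumes, so host-side `prepare_inputs_ttnn` only runs on the first iteration. A minimal, self-contained sketch of that control flow; the `fake_*` helpers and `prepare_inputs_host` are illustrative stand-ins, not the demo's real functions:

```python
import torch

BATCH, HIDDEN, VOCAB = 2, 8, 16

def prepare_inputs_host(x):
    # [batch, seq_len=1, hidden] -> [seq_len=1, 1, batch, hidden], as prepare_inputs_ttnn does
    return x.transpose(0, 1).unsqueeze(1)

def fake_model(x):
    # stand-in for tt_model: per-user logits from a [1, 1, batch, hidden] input
    return torch.randn(x.shape[2], VOCAB)

def fake_embed(tokens):
    # stand-in for tt_embd: [1, batch] token ids -> [1, 1, batch, hidden]
    return torch.randn(1, 1, tokens.shape[1], HIDDEN)

embed_on_device = True
pt_input = torch.randn(BATCH, 1, HIDDEN)  # iteration-0 input lives on the host
decode_input = None
for iteration in range(4):
    if embed_on_device and iteration > 0:
        pass  # decode_input already has the model's layout; skip host prep
    else:
        decode_input = prepare_inputs_host(pt_input)
    logits = fake_model(decode_input)
    next_tok = logits.argmax(dim=-1).unsqueeze(0)  # [1, batch] greedy tokens
    decode_input = fake_embed(next_tok)  # stays on device in the real demo
```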
@@ -211,7 +214,16 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
                 users_decoding = False
 
         if embed_on_device:
-            tt_out_tok = ttnn.from_torch(tt_out_tok, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
+            # Pad tt_out_tok to batch size of 32
+            padded_tt_out_tok = torch.zeros(1, 32, dtype=tt_out_tok.dtype, device=tt_out_tok.device)
+            padded_tt_out_tok[:, : tt_out_tok.shape[1]] = tt_out_tok
+            tt_out_tok = ttnn.from_torch(
+                padded_tt_out_tok,
+                device=device,
+                dtype=ttnn.uint32,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
+                memory_config=ttnn.L1_MEMORY_CONFIG,
+            )
             tt_decode_input = tt_embd(tt_out_tok)
         else:
             tt_decode_input = embd(tt_out_tok)
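The padding itself is easiest to see with plain torch. The live batch may be smaller than 32, but the on-device embedding path works on a fixed row of 32 token ids (32 matching ttnn's tile width), so the slots beyond the live batch are zero-filled. A runnable host-side illustration; the same pattern recurs in demo_with_prefill.py below:

```python
import torch

# 3 live users, 29 zero slots
tt_out_tok = torch.tensor([[17, 4, 99]])  # shape [1, 3], example token ids

padded = torch.zeros(1, 32, dtype=tt_out_tok.dtype)
padded[:, : tt_out_tok.shape[1]] = tt_out_tok

print(padded.shape)   # torch.Size([1, 32])
print(padded[0, :5])  # tensor([17,  4, 99,  0,  0])
```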
57 changes: 22 additions & 35 deletions models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py
@@ -125,7 +125,7 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
     # This module requires the env paths above for CI runs
     from models.demos.wormhole.llama31_8b.tt.model_config import TtModelArgs
 
-    embed_on_device = False
+    embed_on_device = True
    dtype = ttnn.bfloat8_b
 
     # We disregard any warmup iteration for profiling, in favour of just measuring compile time on the first iteration
@@ -341,15 +341,18 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
         curr_pos = generation_start_pos + iteration
 
         # Prepare inputs for decode mode (rotary embeddings, attention mask, padding)
-        # TODO Move the attn mask to device
         profiler.start(f"prepare_input_decode", iteration=batch_idx)
-        decode_input, current_pos = prepare_inputs_ttnn(
-            pt_encoded_input,
-            curr_pos,
-            model_args.dim,
-            model_args.sliding_window,
-            tt_model.device,
-        )
+        if embed_on_device and iteration > 0:
+            current_pos = curr_pos
+            decode_input = pt_encoded_input
+        else:
+            decode_input, current_pos = prepare_inputs_ttnn(
+                pt_encoded_input,
+                curr_pos,
+                model_args.dim,
+                model_args.sliding_window,
+                tt_model.device,
+            )
         profiler.end(f"prepare_input_decode", iteration=batch_idx)
 
         profiler.start(f"decode_and_argmax", iteration=batch_idx)
@@ -404,7 +407,16 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num

         profiler.start(f"decode_embedding", iteration=batch_idx)
         if embed_on_device:
-            tt_out_tok = ttnn.from_torch(tt_out_tok, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
+            # Pad tt_out_tok to batch size of 32
+            padded_tt_out_tok = torch.zeros(1, 32, dtype=tt_out_tok.dtype, device=tt_out_tok.device)
+            padded_tt_out_tok[:, : tt_out_tok.shape[1]] = tt_out_tok
+            tt_out_tok = ttnn.from_torch(
+                padded_tt_out_tok,
+                device=device,
+                dtype=ttnn.uint32,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
+                memory_config=ttnn.L1_MEMORY_CONFIG,
+            )
             pt_encoded_input = tt_embd(tt_out_tok)
         else:
             pt_encoded_input = embd(tt_out_tok)
@@ -470,26 +482,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num

         profiler.end(f"inference_decode", iteration=batch_idx)
 
-        # When running in CI, check the output against the expected output to avoid accuracy regressions
-        # TODO Extend the expected output validation to further batches
-        if is_ci_env and batch_idx == 0:  # Only check output of batch 0
-            expected_output = "models/demos/wormhole/llama31_8b/demo/expected_outputs_prefill_128.json"
-            with open(expected_output, "r") as f:
-                expected_out = json.load(f)
-            # assert (
-            #     len(expected_out) >= batch_size * 2
-            # ), f"expected_outputs.json should have {batch_size * 2} outputs: {batch_size} for general weights and {batch_size} for instruct weights!"
-
-            for i in range(batch_size):
-                user_output = "".join(tokenizer.decode(all_outputs[i]))
-                if instruct_mode:  # The instruct outputs are at the end of the expected outputs file
-                    user_expect = expected_out[i + batch_size]["output_instruct"]
-                else:
-                    user_expect = expected_out[i]["output_general"]
-
-                assert user_output == user_expect, f"Output for user {i} does not match expected output!"
-            logger.info("[CI-Only] Output token validation passed!")
-
     # Finish profiling at the end of all batches
     profiler.end("run")
@@ -555,11 +547,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
     decode_tsu = 33
     targets = {"prefill_t/s": target_prefill_ts, "decode_t/s": target_decode_ts, "decode_t/s/u": decode_tsu}
 
-    # TODO move token verification here?
-    # if expected_greedy_output_path is not None:
-    #     token_check_does_pass, expected_output = check_tokens_match(generated_text, expected_greedy_output_path)
-    #     measurements["token_verification"] = float(token_check_does_pass)
-
     # Save benchmark data for CI dashboard
     if is_ci_env and is_n300:
         benchmark_data = create_benchmark_data(profiler, measurements, N_warmup_iter, targets)
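For context on how such a targets dict gates CI, a minimal sketch: each measured rate must meet or beat its target. The measured numbers and the prefill/decode targets below are invented for illustration (only the 33 t/s/u figure appears above); `create_benchmark_data`'s real behaviour may differ:

```python
# Illustrative only: measurements are made-up, targets mirror the dict above.
measurements = {"prefill_t/s": 915.0, "decode_t/s": 270.1, "decode_t/s/u": 33.4}
targets = {"prefill_t/s": 900.0, "decode_t/s": 265.0, "decode_t/s/u": 33.0}

for name, target in targets.items():
    assert measurements[name] >= target, (
        f"{name}: measured {measurements[name]:.1f} is below target {target:.1f}"
    )
print("All perf targets met")
```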
28 changes: 17 additions & 11 deletions models/demos/wormhole/llama31_8b/tt/llama_common.py
@@ -170,27 +170,33 @@ def prepare_inputs_ttnn(x, current_pos, hidden_size, sliding_window, device):
         start_pos: int
     """
 
-    assert len(x.shape) == 3
-    assert x.shape[2] == hidden_size
+    if len(x.shape) == 3:
+        batch = x.shape[0]
+        seq_len = x.shape[1]
+        assert x.shape[2] == hidden_size
+    elif len(x.shape) == 4:
+        seq_len = x.shape[0]
+        assert x.shape[1] == 1
+        batch = x.shape[2]
+        assert x.shape[3] == hidden_size
 
-    batch = x.shape[0]
-    seq_len = x.shape[1]
-    hidden_size = x.shape[2]
     assert seq_len == 1, "Only supporting decode mode"
 
     # Support input on device
     if torch.is_tensor(x):  # Input on host -> Use torch
         x = x.transpose(0, 1).unsqueeze(1)  # [seq_len, 1, batch, hidden_dim]
-    else:  # Input on device -> Use ttnn
+        # Pad small batches to 32
+        if batch < 32:
+            zeros = torch.zeros(1, seq_len, 32, hidden_size)
+            zeros[:, :, :batch, :] = x
+            x = zeros
+    elif len(x.shape) == 3:  # Input on device -> Use ttnn
         x = ttnn.reshape(
             x, (batch, seq_len, 1, hidden_size)
         )  # [batch, seqlen, hidden_dim] -> [batch, seqlen, 1, hidden_dim]
         x = ttnn.permute(x, (1, 2, 0, 3))  # [seq_len, 1, batch, hidden_dim]
-    # Pad small batches to 32
-    if batch < 32:
-        zeros = torch.zeros(1, seq_len, 32, hidden_size)
-        zeros[:, :, :batch, :] = x
-        x = zeros
+    elif len(x.shape) == 4:
+        pass  # already in [seq_len, 1, batch, hidden_dim]
 
     current = current_pos % sliding_window
 
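To make the two accepted shapes concrete, here is a runnable, torch-only walk-through of the 3-dim host branch; the new 4-dim branch simply receives this same transformation already materialized by the on-device embedding and passes it through unchanged:

```python
import torch

batch, seq_len, hidden_size = 8, 1, 4096
x = torch.randn(batch, seq_len, hidden_size)  # host-side decode input

x = x.transpose(0, 1).unsqueeze(1)  # [seq_len, 1, batch, hidden]
if batch < 32:
    # Pad the batch axis up to 32, as in the branch above
    zeros = torch.zeros(1, seq_len, 32, hidden_size)
    zeros[:, :, :batch, :] = x
    x = zeros

print(x.shape)  # torch.Size([1, 1, 32, 4096])
```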
6 changes: 5 additions & 1 deletion models/demos/wormhole/llama31_8b/tt/llama_embedding.py
@@ -33,4 +33,8 @@ def __init__(
         )
 
     def forward(self, x: ttnn.Tensor) -> ttnn.Tensor:
-        return ttnn.embedding(x, self.weights)
+        x = ttnn.embedding(x, self.weights, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)
+        x = ttnn.reshape(x, [x.shape[0], 1, x.shape[1], x.shape[2]])
+        # x = ttnn.pad(x, padding=((0, 0), (0, 0), (0, 32-x.shape[2]), (0, 0)), value=0)
+        # x = ttnn.tilize(x, use_multicore=True)
+        return x
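The reshape here is what lines up with the new 4-dim branch in `prepare_inputs_ttnn`: embedding a [1, 32] row of token ids yields [1, 32, hidden], and inserting a singleton axis gives [seq_len=1, 1, batch=32, hidden]. A torch-only equivalent with toy sizes (an illustration, not ttnn semantics):

```python
import torch

vocab, hidden = 1000, 64  # toy sizes for illustration
weights = torch.randn(vocab, hidden)
tokens = torch.zeros(1, 32, dtype=torch.long)  # a padded row of token ids

x = torch.nn.functional.embedding(tokens, weights)    # [1, 32, hidden]
x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])  # [1, 1, 32, hidden]
print(x.shape)  # torch.Size([1, 1, 32, 64])
```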
2 changes: 2 additions & 0 deletions models/demos/wormhole/llama31_8b/tt/model_config.py
@@ -100,6 +100,8 @@ def __init__(self, device, instruct=False, dummy_weights=False):

         # Enable workarounds by default until di/dt issues are fixed
         self.di_dt_workaround = os.getenv("DISABLE_DI_DT_WORKAROUND") != "1"
+        if not self.di_dt_workaround:
+            logger.info("Disabling di/dt workaround, re-enable if you see hangs")
 
         DRAM_MEMCFG = ttnn.DRAM_MEMORY_CONFIG
         L1_MEMCFG = ttnn.L1_MEMORY_CONFIG
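The new log line only fires when someone opts out of the workaround through the environment, leaving a breadcrumb if hangs reappear. A self-contained sketch of the toggle, with stdlib logging standing in for the module's own logger:

```python
import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("model_config")  # stand-in for the module's logger

os.environ["DISABLE_DI_DT_WORKAROUND"] = "1"  # simulate opting out

di_dt_workaround = os.getenv("DISABLE_DI_DT_WORKAROUND") != "1"
if not di_dt_workaround:
    logger.info("Disabling di/dt workaround, re-enable if you see hangs")
```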