From 540d11b5e3ae479939fffbf2f94d650449a88e26 Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Wed, 11 Sep 2024 13:56:11 +0000 Subject: [PATCH 1/3] #12328: Move embeddings on device --- models/demos/wormhole/llama31_8b/demo/demo.py | 193 ++++++++++-------- .../wormhole/llama31_8b/tt/llama_common.py | 28 ++- .../wormhole/llama31_8b/tt/llama_embedding.py | 4 +- .../wormhole/llama31_8b/tt/model_config.py | 2 + 4 files changed, 127 insertions(+), 100 deletions(-) diff --git a/models/demos/wormhole/llama31_8b/demo/demo.py b/models/demos/wormhole/llama31_8b/demo/demo.py index 737010517d4..a6ea7463b23 100644 --- a/models/demos/wormhole/llama31_8b/demo/demo.py +++ b/models/demos/wormhole/llama31_8b/demo/demo.py @@ -74,7 +74,7 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env): # This module requires the env paths above for CI runs from models.demos.wormhole.llama31_8b.tt.model_config import TtModelArgs - embed_on_device = False + embed_on_device = True dtype = ttnn.bfloat8_b # Load model args, weights, and tokenizer @@ -105,7 +105,7 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env): embd.load_state_dict({"emb.weight": state_dict["tok_embeddings.weight"]}) generation_start_pos = 0 - max_generated_tokens = 120 + max_generated_tokens = 5 users_decoding = True # Preprocess initial prompt inputs @@ -149,97 +149,114 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env): all_outputs = [[] for _ in range(batch_size)] user_done = [False] * batch_size # Keeps track when a user reaches EoD token - iteration = 0 - # Keep running inference as long as there is a user in the batch still decoding or max tokens per user are decoded - while users_decoding: - iteration_time_start = time() - curr_pos = generation_start_pos + iteration - - # Prepare inputs for decode mode (rotary embeddings, attention mask, padding) - # TODO Move the attn mask to device - decode_input, current_pos = prepare_inputs_ttnn( - tt_decode_input, - curr_pos, - model_args.dim, - model_args.sliding_window, - tt_model.device, - ) + from viztracer import VizTracer + + with VizTracer(output_file="llama_demo_trace.json"): + iteration = 0 + # Keep running inference as long as there is a user in the batch still decoding or max tokens per user are decoded + while users_decoding: + iteration_time_start = time() + curr_pos = generation_start_pos + iteration - # Run ttnn llama model - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) - tt_out = ttnn.untilize( - tt_out, use_multicore=False - ) # multi-core OOMs (https://github.com/tenstorrent/tt-metal/issues/9022) - tt_output_torch = ( - ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1)[: model_args.max_batch_size, :, :] - ) # [batch, seq, hidden_dim] - # Update rotation matrix for next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - # If temperature is 0, does greedy decoding (top-1) - tt_out_tok = sample(tt_output_torch, temperature=0, top_p=0.8) - - # TODO argmax on device - # tt_out = ttnn.to_layout(tt_out, ttnn.ROW_MAJOR_LAYOUT) - # tt_out = ttnn.permute(tt_out, (2, 1, 0, 3)) - # tt_out = ttnn.reshape(tt_out, (tt_out.shape[0], tt_out.shape[2], tt_out.shape[3])) # Squeeze(1) - # tt_out_argmax = ttnn.argmax(tt_out, dim=-1) - # Typecast from bf16 to uint32 for embedding - # tt_out_tok = ttnn.clone(tt_out_argmax, ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.uint32) - # tt_out_tok = ttnn.experimental.tensor.typecast(tt_out_tok, dtype=ttnn.uint32) - - if iteration < 
input_mask.shape[1]: # If prefill - # If token is pad token, start generating new token, otherwise, push the next prompt token to the model - tt_out_tok = torch.where( - input_mask[:, iteration], pt_encoded_input[:, iteration], tt_out_tok[:, 0] - ).unsqueeze(1) - - # Save output token to print out later - for user in range(batch_size): - user_tok = tt_out_tok[user].tolist() - if user_tok[0] != 28803 and user_done[user] == False: # Stop saving the ouput after hitting the EOS token - all_outputs[user].append(user_tok[0]) + if embed_on_device and iteration > 0: + current_pos = curr_pos + decode_input = tt_decode_input else: - user_done[user] = True + # Prepare inputs for decode mode + decode_input, current_pos = prepare_inputs_ttnn( + tt_decode_input, + curr_pos, + model_args.dim, + model_args.sliding_window, + tt_model.device, + ) + + # Run ttnn llama model + tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + tt_out = ttnn.untilize( + tt_out, use_multicore=False + ) # multi-core OOMs (https://github.com/tenstorrent/tt-metal/issues/9022) + tt_output_torch = ( + ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1)[: model_args.max_batch_size, :, :] + ) # [batch, seq, hidden_dim] + # Update rotation matrix for next iteration + current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) + # If temperature is 0, does greedy decoding (top-1) + tt_out_tok = sample(tt_output_torch, temperature=0, top_p=0.8) + + # TODO argmax on device + # tt_out = ttnn.to_layout(tt_out, ttnn.ROW_MAJOR_LAYOUT) + # tt_out = ttnn.permute(tt_out, (2, 1, 0, 3)) + # tt_out = ttnn.reshape(tt_out, (tt_out.shape[0], tt_out.shape[2], tt_out.shape[3])) # Squeeze(1) + # tt_out_argmax = ttnn.argmax(tt_out, dim=-1) + # Typecast from bf16 to uint32 for embedding + # tt_out_tok = ttnn.clone(tt_out_argmax, ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.uint32) + # tt_out_tok = ttnn.experimental.tensor.typecast(tt_out_tok, dtype=ttnn.uint32) + + if iteration < input_mask.shape[1]: # If prefill + # If token is pad token, start generating new token, otherwise, push the next prompt token to the model + tt_out_tok = torch.where( + input_mask[:, iteration], pt_encoded_input[:, iteration], tt_out_tok[:, 0] + ).unsqueeze(1) + + # Save output token to print out later + for user in range(batch_size): + user_tok = tt_out_tok[user].tolist() if ( - iteration < input_mask.shape[1] - ): # Still in prefill, so ignore EOS token and save the generated token - # all_outputs[user].append(user_tok[0]) - pass + user_tok[0] != 28803 and user_done[user] == False + ): # Stop saving the ouput after hitting the EOS token + all_outputs[user].append(user_tok[0]) else: - logger.trace(f"[User {user}] Finished decoding at iteration {iteration}") - if all(user_done): - users_decoding = False - - if embed_on_device: - tt_out_tok = ttnn.from_torch(tt_out_tok, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT) - tt_decode_input = tt_embd(tt_out_tok) - else: - tt_decode_input = embd(tt_out_tok) - - # Print out generated outputs for each user at the end of every iteration - iteration_time = time() - iteration_time_start - tokens_per_second_per_user = 1 / iteration_time - if not is_ci_env: - if len(user_input) == 1: - logger.info("[User 0] {}".format("".join(tokenizer.decode(all_outputs[0])))) + user_done[user] = True + if ( + iteration < input_mask.shape[1] + ): # Still in prefill, so ignore EOS token and save the generated token + # all_outputs[user].append(user_tok[0]) + pass + else: + logger.trace(f"[User {user}] Finished decoding at iteration 
{iteration}") + if all(user_done): + users_decoding = False + + if embed_on_device: + # Pad tt_out_tok to batch size of 32 + padded_tt_out_tok = torch.zeros(1, 32, dtype=tt_out_tok.dtype, device=tt_out_tok.device) + padded_tt_out_tok[: tt_out_tok.shape[1]] = tt_out_tok + tt_out_tok = ttnn.from_torch( + padded_tt_out_tok, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + tt_decode_input = tt_embd(tt_out_tok) else: - for user in range(batch_size): - text = "".join(tokenizer.decode(all_outputs[user])) - if len(text) > 100: - text = "..." + text[-97:] - text = text.replace("\n", " ") - logger.info("[User {}] {}".format(user, text)) - - # Always print perf at every iteration - logger.info( - f"Iteration {iteration}: {1000*iteration_time:.0f}ms @ {tokens_per_second_per_user:.1f} tok/s/user ({batch_size*tokens_per_second_per_user:.1f} tok/s throughput)" - ) - - iteration += 1 - - # Upper limit of generated tokens for each user (to avoid infinite generation in case eos is not seen) - if iteration >= max_generated_tokens: - users_decoding = False + tt_decode_input = embd(tt_out_tok) + + # Print out generated outputs for each user at the end of every iteration + iteration_time = time() - iteration_time_start + tokens_per_second_per_user = 1 / iteration_time + if not is_ci_env: + if len(user_input) == 1: + logger.info("[User 0] {}".format("".join(tokenizer.decode(all_outputs[0])))) + else: + for user in range(batch_size): + text = "".join(tokenizer.decode(all_outputs[user])) + if len(text) > 100: + text = "..." + text[-97:] + text = text.replace("\n", " ") + logger.info("[User {}] {}".format(user, text)) + + # Always print perf at every iteration + logger.info( + f"Iteration {iteration}: {1000*iteration_time:.0f}ms @ {tokens_per_second_per_user:.1f} tok/s/user ({batch_size*tokens_per_second_per_user:.1f} tok/s throughput)" + ) + + iteration += 1 + + # Upper limit of generated tokens for each user (to avoid infinite generation in case eos is not seen) + if iteration >= max_generated_tokens: + users_decoding = False # In CI only print the final generated output to avoid spamming the logs if is_ci_env: diff --git a/models/demos/wormhole/llama31_8b/tt/llama_common.py b/models/demos/wormhole/llama31_8b/tt/llama_common.py index 473e96a7144..616c7c1b3c1 100644 --- a/models/demos/wormhole/llama31_8b/tt/llama_common.py +++ b/models/demos/wormhole/llama31_8b/tt/llama_common.py @@ -170,27 +170,33 @@ def prepare_inputs_ttnn(x, current_pos, hidden_size, sliding_window, device): start_pos: int """ - assert len(x.shape) == 3 - assert x.shape[2] == hidden_size + if len(x.shape) == 3: + batch = x.shape[0] + seq_len = x.shape[1] + assert x.shape[2] == hidden_size + elif len(x.shape) == 4: + seq_len = x.shape[0] + assert x.shape[1] == 1 + batch = x.shape[2] + assert x.shape[3] == hidden_size - batch = x.shape[0] - seq_len = x.shape[1] - hidden_size = x.shape[2] assert seq_len == 1, "Only supporting decode mode" # Support input on device if torch.is_tensor(x): # Input on host -> Use torch x = x.transpose(0, 1).unsqueeze(1) # [seq_len, 1, batch, hidden_dim] - else: # Input on device -> Use ttnn + # Pad small batches to 32 + if batch < 32: + zeros = torch.zeros(1, seq_len, 32, hidden_size) + zeros[:, :, :batch, :] = x + x = zeros + elif len(x.shape) == 3: # Input on device -> Use ttnn x = ttnn.reshape( x, (batch, seq_len, 1, hidden_size) ) # [batch, seqlen, hidden_dim] -> [batch, seqlen, 1, hidden_dim] x = ttnn.permute(x, (1, 2, 0, 3)) # [seq_len, 1, batch, 
hidden_dim] - # Pad small batches to 32 - if batch < 32: - zeros = torch.zeros(1, seq_len, 32, hidden_size) - zeros[:, :, :batch, :] = x - x = zeros + elif len(x.shape) == 4: + pass # already in [seq_len, 1, batch, hidden_dim] current = current_pos % sliding_window diff --git a/models/demos/wormhole/llama31_8b/tt/llama_embedding.py b/models/demos/wormhole/llama31_8b/tt/llama_embedding.py index be14cbcb36b..14a00704431 100644 --- a/models/demos/wormhole/llama31_8b/tt/llama_embedding.py +++ b/models/demos/wormhole/llama31_8b/tt/llama_embedding.py @@ -33,4 +33,6 @@ def __init__( ) def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: - return ttnn.embedding(x, self.weights) + x = ttnn.embedding(x, self.weights, layout=ttnn.TILE_LAYOUT) + x = ttnn.reshape(x, [x.shape[0], 1, x.shape[1], x.shape[2]]) + return x diff --git a/models/demos/wormhole/llama31_8b/tt/model_config.py b/models/demos/wormhole/llama31_8b/tt/model_config.py index 972d9a87337..062532aa075 100644 --- a/models/demos/wormhole/llama31_8b/tt/model_config.py +++ b/models/demos/wormhole/llama31_8b/tt/model_config.py @@ -100,6 +100,8 @@ def __init__(self, device, instruct=False, dummy_weights=False): # Enable workarounds by default until di/dt issues are fixed self.di_dt_workaround = os.getenv("DISABLE_DI_DT_WORKAROUND") != "1" + if not self.di_dt_workaround: + logger.info("Disabling di/dt workaround, re-enable if you see hangs") DRAM_MEMCFG = ttnn.DRAM_MEMORY_CONFIG L1_MEMCFG = ttnn.L1_MEMORY_CONFIG From 8320008ace5b75ee3a235a9d578e0ce6ef7740ba Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Thu, 12 Sep 2024 10:16:24 +0000 Subject: [PATCH 2/3] #12328: On-device embeddings with 17.7 t/s/u e2e no di/dt workaround --- models/demos/wormhole/llama31_8b/demo/demo.py | 197 +++++++++--------- .../llama31_8b/demo/demo_with_prefill.py | 32 ++- .../wormhole/llama31_8b/tt/llama_embedding.py | 4 +- 3 files changed, 121 insertions(+), 112 deletions(-) diff --git a/models/demos/wormhole/llama31_8b/demo/demo.py b/models/demos/wormhole/llama31_8b/demo/demo.py index a6ea7463b23..cd4cac961d8 100644 --- a/models/demos/wormhole/llama31_8b/demo/demo.py +++ b/models/demos/wormhole/llama31_8b/demo/demo.py @@ -105,7 +105,7 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env): embd.load_state_dict({"emb.weight": state_dict["tok_embeddings.weight"]}) generation_start_pos = 0 - max_generated_tokens = 5 + max_generated_tokens = 120 users_decoding = True # Preprocess initial prompt inputs @@ -149,114 +149,109 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env): all_outputs = [[] for _ in range(batch_size)] user_done = [False] * batch_size # Keeps track when a user reaches EoD token - from viztracer import VizTracer + iteration = 0 + # Keep running inference as long as there is a user in the batch still decoding or max tokens per user are decoded + while users_decoding: + iteration_time_start = time() + curr_pos = generation_start_pos + iteration - with VizTracer(output_file="llama_demo_trace.json"): - iteration = 0 - # Keep running inference as long as there is a user in the batch still decoding or max tokens per user are decoded - while users_decoding: - iteration_time_start = time() - curr_pos = generation_start_pos + iteration + if embed_on_device and iteration > 0: + current_pos = curr_pos + decode_input = tt_decode_input + else: + # Prepare inputs for decode mode + decode_input, current_pos = prepare_inputs_ttnn( + tt_decode_input, + curr_pos, + model_args.dim, + model_args.sliding_window, + 
tt_model.device, + ) - if embed_on_device and iteration > 0: - current_pos = curr_pos - decode_input = tt_decode_input + # Run ttnn llama model + tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + tt_out = ttnn.untilize( + tt_out, use_multicore=False + ) # multi-core OOMs (https://github.com/tenstorrent/tt-metal/issues/9022) + tt_output_torch = ( + ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1)[: model_args.max_batch_size, :, :] + ) # [batch, seq, hidden_dim] + # Update rotation matrix for next iteration + current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) + # If temperature is 0, does greedy decoding (top-1) + tt_out_tok = sample(tt_output_torch, temperature=0, top_p=0.8) + + # TODO argmax on device + # tt_out = ttnn.to_layout(tt_out, ttnn.ROW_MAJOR_LAYOUT) + # tt_out = ttnn.permute(tt_out, (2, 1, 0, 3)) + # tt_out = ttnn.reshape(tt_out, (tt_out.shape[0], tt_out.shape[2], tt_out.shape[3])) # Squeeze(1) + # tt_out_argmax = ttnn.argmax(tt_out, dim=-1) + # Typecast from bf16 to uint32 for embedding + # tt_out_tok = ttnn.clone(tt_out_argmax, ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.uint32) + # tt_out_tok = ttnn.experimental.tensor.typecast(tt_out_tok, dtype=ttnn.uint32) + + if iteration < input_mask.shape[1]: # If prefill + # If token is pad token, start generating new token, otherwise, push the next prompt token to the model + tt_out_tok = torch.where( + input_mask[:, iteration], pt_encoded_input[:, iteration], tt_out_tok[:, 0] + ).unsqueeze(1) + + # Save output token to print out later + for user in range(batch_size): + user_tok = tt_out_tok[user].tolist() + if user_tok[0] != 28803 and user_done[user] == False: # Stop saving the ouput after hitting the EOS token + all_outputs[user].append(user_tok[0]) else: - # Prepare inputs for decode mode - decode_input, current_pos = prepare_inputs_ttnn( - tt_decode_input, - curr_pos, - model_args.dim, - model_args.sliding_window, - tt_model.device, - ) - - # Run ttnn llama model - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) - tt_out = ttnn.untilize( - tt_out, use_multicore=False - ) # multi-core OOMs (https://github.com/tenstorrent/tt-metal/issues/9022) - tt_output_torch = ( - ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1)[: model_args.max_batch_size, :, :] - ) # [batch, seq, hidden_dim] - # Update rotation matrix for next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - # If temperature is 0, does greedy decoding (top-1) - tt_out_tok = sample(tt_output_torch, temperature=0, top_p=0.8) - - # TODO argmax on device - # tt_out = ttnn.to_layout(tt_out, ttnn.ROW_MAJOR_LAYOUT) - # tt_out = ttnn.permute(tt_out, (2, 1, 0, 3)) - # tt_out = ttnn.reshape(tt_out, (tt_out.shape[0], tt_out.shape[2], tt_out.shape[3])) # Squeeze(1) - # tt_out_argmax = ttnn.argmax(tt_out, dim=-1) - # Typecast from bf16 to uint32 for embedding - # tt_out_tok = ttnn.clone(tt_out_argmax, ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.uint32) - # tt_out_tok = ttnn.experimental.tensor.typecast(tt_out_tok, dtype=ttnn.uint32) - - if iteration < input_mask.shape[1]: # If prefill - # If token is pad token, start generating new token, otherwise, push the next prompt token to the model - tt_out_tok = torch.where( - input_mask[:, iteration], pt_encoded_input[:, iteration], tt_out_tok[:, 0] - ).unsqueeze(1) - - # Save output token to print out later - for user in range(batch_size): - user_tok = tt_out_tok[user].tolist() + user_done[user] = True if ( - user_tok[0] != 28803 and user_done[user] == False - ): # Stop saving the 
ouput after hitting the EOS token - all_outputs[user].append(user_tok[0]) - else: - user_done[user] = True - if ( - iteration < input_mask.shape[1] - ): # Still in prefill, so ignore EOS token and save the generated token - # all_outputs[user].append(user_tok[0]) - pass - else: - logger.trace(f"[User {user}] Finished decoding at iteration {iteration}") - if all(user_done): - users_decoding = False - - if embed_on_device: - # Pad tt_out_tok to batch size of 32 - padded_tt_out_tok = torch.zeros(1, 32, dtype=tt_out_tok.dtype, device=tt_out_tok.device) - padded_tt_out_tok[: tt_out_tok.shape[1]] = tt_out_tok - tt_out_tok = ttnn.from_torch( - padded_tt_out_tok, - device=device, - dtype=ttnn.uint32, - layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.L1_MEMORY_CONFIG, - ) - tt_decode_input = tt_embd(tt_out_tok) - else: - tt_decode_input = embd(tt_out_tok) - - # Print out generated outputs for each user at the end of every iteration - iteration_time = time() - iteration_time_start - tokens_per_second_per_user = 1 / iteration_time - if not is_ci_env: - if len(user_input) == 1: - logger.info("[User 0] {}".format("".join(tokenizer.decode(all_outputs[0])))) + iteration < input_mask.shape[1] + ): # Still in prefill, so ignore EOS token and save the generated token + # all_outputs[user].append(user_tok[0]) + pass else: - for user in range(batch_size): - text = "".join(tokenizer.decode(all_outputs[user])) - if len(text) > 100: - text = "..." + text[-97:] - text = text.replace("\n", " ") - logger.info("[User {}] {}".format(user, text)) - - # Always print perf at every iteration - logger.info( - f"Iteration {iteration}: {1000*iteration_time:.0f}ms @ {tokens_per_second_per_user:.1f} tok/s/user ({batch_size*tokens_per_second_per_user:.1f} tok/s throughput)" + logger.trace(f"[User {user}] Finished decoding at iteration {iteration}") + if all(user_done): + users_decoding = False + + if embed_on_device: + # Pad tt_out_tok to batch size of 32 + padded_tt_out_tok = torch.zeros(1, 32, dtype=tt_out_tok.dtype, device=tt_out_tok.device) + padded_tt_out_tok[: tt_out_tok.shape[1]] = tt_out_tok + tt_out_tok = ttnn.from_torch( + padded_tt_out_tok, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.L1_MEMORY_CONFIG, ) + tt_decode_input = tt_embd(tt_out_tok) + else: + tt_decode_input = embd(tt_out_tok) + + # Print out generated outputs for each user at the end of every iteration + iteration_time = time() - iteration_time_start + tokens_per_second_per_user = 1 / iteration_time + if not is_ci_env: + if len(user_input) == 1: + logger.info("[User 0] {}".format("".join(tokenizer.decode(all_outputs[0])))) + else: + for user in range(batch_size): + text = "".join(tokenizer.decode(all_outputs[user])) + if len(text) > 100: + text = "..." 
+ text[-97:] + text = text.replace("\n", " ") + logger.info("[User {}] {}".format(user, text)) + + # Always print perf at every iteration + logger.info( + f"Iteration {iteration}: {1000*iteration_time:.0f}ms @ {tokens_per_second_per_user:.1f} tok/s/user ({batch_size*tokens_per_second_per_user:.1f} tok/s throughput)" + ) - iteration += 1 + iteration += 1 - # Upper limit of generated tokens for each user (to avoid infinite generation in case eos is not seen) - if iteration >= max_generated_tokens: - users_decoding = False + # Upper limit of generated tokens for each user (to avoid infinite generation in case eos is not seen) + if iteration >= max_generated_tokens: + users_decoding = False # In CI only print the final generated output to avoid spamming the logs if is_ci_env: diff --git a/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py b/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py index eab5cd6b0ad..98d35d74995 100644 --- a/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py +++ b/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py @@ -125,7 +125,7 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num # This module requires the env paths above for CI runs from models.demos.wormhole.llama31_8b.tt.model_config import TtModelArgs - embed_on_device = False + embed_on_device = True dtype = ttnn.bfloat8_b # We disregard any warmup iteration for profiling, in favour of just measuring compile time on the first iteration @@ -341,15 +341,18 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num curr_pos = generation_start_pos + iteration # Prepare inputs for decode mode (rotary embeddings, attention mask, padding) - # TODO Move the attn mask to device profiler.start(f"prepare_input_decode", iteration=batch_idx) - decode_input, current_pos = prepare_inputs_ttnn( - pt_encoded_input, - curr_pos, - model_args.dim, - model_args.sliding_window, - tt_model.device, - ) + if embed_on_device and iteration > 0: + current_pos = curr_pos + decode_input = pt_encoded_input + else: + decode_input, current_pos = prepare_inputs_ttnn( + pt_encoded_input, + curr_pos, + model_args.dim, + model_args.sliding_window, + tt_model.device, + ) profiler.end(f"prepare_input_decode", iteration=batch_idx) profiler.start(f"decode_and_argmax", iteration=batch_idx) @@ -404,7 +407,16 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num profiler.start(f"decode_embedding", iteration=batch_idx) if embed_on_device: - tt_out_tok = ttnn.from_torch(tt_out_tok, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT) + # Pad tt_out_tok to batch size of 32 + padded_tt_out_tok = torch.zeros(1, 32, dtype=tt_out_tok.dtype, device=tt_out_tok.device) + padded_tt_out_tok[: tt_out_tok.shape[1]] = tt_out_tok + tt_out_tok = ttnn.from_torch( + padded_tt_out_tok, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) pt_encoded_input = tt_embd(tt_out_tok) else: pt_encoded_input = embd(tt_out_tok) diff --git a/models/demos/wormhole/llama31_8b/tt/llama_embedding.py b/models/demos/wormhole/llama31_8b/tt/llama_embedding.py index 14a00704431..b966eebd4e6 100644 --- a/models/demos/wormhole/llama31_8b/tt/llama_embedding.py +++ b/models/demos/wormhole/llama31_8b/tt/llama_embedding.py @@ -33,6 +33,8 @@ def __init__( ) def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: - x = ttnn.embedding(x, self.weights, layout=ttnn.TILE_LAYOUT) + x = ttnn.embedding(x, self.weights, 
layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG) x = ttnn.reshape(x, [x.shape[0], 1, x.shape[1], x.shape[2]]) + # x = ttnn.pad(x, padding=((0, 0), (0, 0), (0, 32-x.shape[2]), (0, 0)), value=0) + # x = ttnn.tilize(x, use_multicore=True) return x From a93ea767d25ce83063339bfab2b8a025580a805a Mon Sep 17 00:00:00 2001 From: mtairum Date: Thu, 12 Sep 2024 12:36:26 +0000 Subject: [PATCH 3/3] #12328: Removed token verification from Llama3.1-8B demo_with_prefill --- .../llama31_8b/demo/demo_with_prefill.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py b/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py index 98d35d74995..75d460b5e0c 100644 --- a/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py +++ b/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py @@ -482,26 +482,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num profiler.end(f"inference_decode", iteration=batch_idx) - # When running in CI, check the output against the expected output to avoid accuracy regressions - # TODO Extend the expected output validation to further batches - if is_ci_env and batch_idx == 0: # Only check output of batch 0 - expected_output = "models/demos/wormhole/llama31_8b/demo/expected_outputs_prefill_128.json" - with open(expected_output, "r") as f: - expected_out = json.load(f) - # assert ( - # len(expected_out) >= batch_size * 2 - # ), f"expected_outputs.json should have {batch_size * 2} outputs: {batch_size} for general weights and {batch_size} for instruct weights!" - - for i in range(batch_size): - user_output = "".join(tokenizer.decode(all_outputs[i])) - if instruct_mode: # The instruct outputs are at the end of the expected outputs file - user_expect = expected_out[i + batch_size]["output_instruct"] - else: - user_expect = expected_out[i]["output_general"] - - assert user_output == user_expect, f"Output for user {i} does not match expected output!" - logger.info("[CI-Only] Output token validation passed!") - # Finish profiling at the end of all batches profiler.end("run") @@ -567,11 +547,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num decode_tsu = 33 targets = {"prefill_t/s": target_prefill_ts, "decode_t/s": target_decode_ts, "decode_t/s/u": decode_tsu} - # TODO move token verification here? - # if expected_greedy_output_path is not None: - # token_check_does_pass, expected_output = check_tokens_match(generated_text, expected_greedy_output_path) - # measurements["token_verification"] = float(token_check_does_pass) - # Save benchmark data for CI dashboard if is_ci_env and is_n300: benchmark_data = create_benchmark_data(profiler, measurements, N_warmup_iter, targets)
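
For reference, the token feedback path that these patches move onto the device reduces to the following minimal sketch. It only uses calls that already appear in the diffs above (torch padding, ttnn.from_torch, the device embedding module from llama_embedding.py, whose forward runs ttnn.embedding and reshapes to [seq_len, 1, batch, hidden_dim]). The helper name, the docstring shapes, and the assumption that sampled tokens arrive as a [batch, 1] host tensor are illustrative, not code from this branch.

import torch
import ttnn


def embed_next_tokens_on_device(tt_out_tok, tt_embd, device):
    """Hedged sketch of the on-device embedding feedback step.

    tt_out_tok: host torch tensor of sampled token ids, shape [batch, 1]
                (as produced by the demo's sample/torch.where step).
    tt_embd:    the device embedding module from llama_embedding.py.
    Returns activations shaped for tt_model, so prepare_inputs_ttnn can be
    skipped on iterations > 0 (the `embed_on_device and iteration > 0` branch).
    """
    batch = tt_out_tok.shape[0]
    # The device-side path expects a fixed batch of 32, so pad the token row.
    padded = torch.zeros(1, 32, dtype=tt_out_tok.dtype)
    padded[0, :batch] = tt_out_tok[:, 0]
    # Token ids go to device as uint32, row-major, resident in L1.
    tokens_tt = ttnn.from_torch(
        padded,
        device=device,
        dtype=ttnn.uint32,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    # Forward pass returns [seq_len, 1, batch, hidden_dim] ready for the model.
    return tt_embd(tokens_tt)

Padding to a batch of 32 mirrors the padding that prepare_inputs_ttnn already performs on the host path; doing it before ttnn.from_torch lets the embedding output stay in L1 in TILE_LAYOUT and feed straight back into tt_model on the next iteration, which is what removes the per-token host round trip.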