diff --git a/models/demos/bert/tt/ttnn_bert.py b/models/demos/bert/tt/ttnn_bert.py
index 23c480cd3d43..10b0c7958f4c 100644
--- a/models/demos/bert/tt/ttnn_bert.py
+++ b/models/demos/bert/tt/ttnn_bert.py
@@ -232,7 +232,7 @@ def preprocess_inputs(
     position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
 
     if attention_mask is not None:
-        attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape)
+        attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32)
         attention_mask = attention_mask.expand((batch_size, -1, -1, -1))
         attention_mask = torch.clamp(attention_mask, min=-100000)
         attention_mask = ttnn.from_torch(
diff --git a/models/demos/bert/tt/ttnn_optimized_bert.py b/models/demos/bert/tt/ttnn_optimized_bert.py
index af21f25e7e27..089e755538da 100644
--- a/models/demos/bert/tt/ttnn_optimized_bert.py
+++ b/models/demos/bert/tt/ttnn_optimized_bert.py
@@ -286,7 +286,7 @@ def preprocess_inputs(
     position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
 
     if attention_mask is not None:
-        attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape)
+        attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32)
         attention_mask = attention_mask.expand((batch_size, -1, -1, -1))
         attention_mask = torch.clamp(attention_mask, min=-100000)
         attention_mask = ttnn.from_torch(
diff --git a/models/experimental/functional_roberta/demo/demo.py b/models/experimental/functional_roberta/demo/demo.py
index 7e4dc9a18908..d56337c7513b 100644
--- a/models/experimental/functional_roberta/demo/demo.py
+++ b/models/experimental/functional_roberta/demo/demo.py
@@ -6,7 +6,6 @@
 import pytest
 import torch
 from loguru import logger
-import tt_lib
 import transformers
 import ttnn
 import evaluate
@@ -15,8 +14,7 @@
     disable_persistent_kernel_cache,
     profiler,
 )
-from models.demos.bert.tt import ttnn_bert
-from models.demos.bert.tt import ttnn_optimized_bert
+from models.demos.bert.tt import ttnn_bert, ttnn_optimized_bert
 from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch
 
 from ttnn.model_preprocessing import (
@@ -43,6 +41,12 @@ def load_inputs(input_path, batch):
     return context, question
 
 
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+    return incremental_indices.long() + padding_idx
+
+
 def run_roberta_question_and_answering_inference(
     device,
     use_program_cache,
@@ -106,10 +110,14 @@ def run_roberta_question_and_answering_inference(
 
     profiler.start(f"preprocessing_input")
+    position_ids = create_position_ids_from_input_ids(
+        input_ids=roberta_input.input_ids, padding_idx=config.pad_token_id
+    )
     ttnn_roberta_inputs = bert.preprocess_inputs(
         roberta_input["input_ids"],
         roberta_input["token_type_ids"],
-        torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
+        position_ids,
+        roberta_input["attention_mask"],
         device=device,
     )
     profiler.end(f"preprocessing_input")
@@ -209,10 +217,14 @@ def run_roberta_question_and_answering_inference_squad_v2(
         if i < n_iterations:
             batch_data = batch[0]
             curr_batch_size = batch_data["input_ids"].shape[0]
+            position_ids = create_position_ids_from_input_ids(
+                input_ids=batch_data.input_ids, padding_idx=config.pad_token_id
+            )
             ttnn_roberta_inputs = bert.preprocess_inputs(
                 batch_data["input_ids"],
                 batch_data["token_type_ids"],
-                torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
+                position_ids,
+                batch_data["attention_mask"],
                 device=device,
             )
@@ -255,7 +267,7 @@
 
 @pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
-@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
+@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
 def test_demo(
     input_path,
     model_name,
@@ -278,7 +290,7 @@
 
 @pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
-@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
+@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
 @pytest.mark.parametrize(
     "n_iterations",
     ((3),),
diff --git a/models/experimental/functional_roberta/tests/test_perf_device_roberta.py b/models/experimental/functional_roberta/tests/test_perf_device_roberta.py
new file mode 100644
index 000000000000..6868e16a8563
--- /dev/null
+++ b/models/experimental/functional_roberta/tests/test_perf_device_roberta.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from models.utility_functions import skip_for_wormhole_b0
+from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
+
+
+@skip_for_wormhole_b0()
+@pytest.mark.models_device_performance_bare_metal
+@pytest.mark.parametrize(
+    "batch_size, test, expected_perf",
+    [
+        # [8, "3-models.demos.bert.tt.ttnn_optimized_bert-deepset/roberta-large-squad2", 159],
+        [8, "384-8-deepset/roberta-large-squad2", 159],
+    ],
+)
+def test_perf_device_bare_metal(batch_size, test, expected_perf):
+    subdir = "ttnn_roberta"
+    num_iterations = 3
+    margin = 0.03
+    # command = f"pytest models/experimental/functional_roberta/demo/demo.py::test_demo_squadv2[{test}]"
+    command = f"pytest models/experimental/functional_roberta/tests/test_ttnn_optimized_roberta.py::test_roberta_for_question_answering[{test}]"
+    cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
+
+    inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
+    expected_perf_cols = {inference_time_key: expected_perf}
+
+    post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
+    expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
+    prep_device_perf_report(
+        model_name=f"ttnn_roberta_{batch_size}",
+        batch_size=batch_size,
+        post_processed_results=post_processed_results,
+        expected_results=expected_results,
+        comments=test.replace("/", "_"),
+    )
diff --git a/models/experimental/functional_roberta/tests/test_ttnn_optimized_roberta.py b/models/experimental/functional_roberta/tests/test_ttnn_optimized_roberta.py
new file mode 100644
index 000000000000..b6b1201169be
--- /dev/null
+++ b/models/experimental/functional_roberta/tests/test_ttnn_optimized_roberta.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import torch
+import pytest
+import tt_lib
+import transformers
+
+from models.demos.bert.tt import ttnn_optimized_bert, ttnn_bert
+from ttnn.model_preprocessing import preprocess_model_parameters
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from models.utility_functions import skip_for_wormhole_b0
+
+from transformers import RobertaForQuestionAnswering, RobertaConfig
+
+
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("sequence_size", [384])
+def test_roberta_for_question_answering(device, use_program_cache, reset_seeds, model_name, batch_size, sequence_size):
+    config = RobertaConfig.from_pretrained(model_name)
+    model = RobertaForQuestionAnswering.from_pretrained(model_name)
+
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)
+    torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
+    torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
+    torch_attention_mask = torch.ones(batch_size, sequence_size)
+
+    torch_output = model(
+        input_ids=input_ids,
+        attention_mask=torch_attention_mask,
+        token_type_ids=torch_token_type_ids,
+        position_ids=torch_position_ids,
+    )
+    torch_output_start_logits = torch_output.start_logits
+    torch_output_end_logits = torch_output.end_logits
+
+    tt_model_name = f"ttnn_{model_name}_optimized"
+
+    parameters = preprocess_model_parameters(
+        model_name=tt_model_name,
+        initialize_model=lambda: transformers.RobertaForQuestionAnswering.from_pretrained(
+            model_name, torchscript=False
+        ).eval(),
+        custom_preprocessor=ttnn_optimized_bert.custom_preprocessor,
+        device=device,
+    )
+
+    ttnn_roberta_inputs = ttnn_optimized_bert.preprocess_inputs(
+        input_ids,
+        torch_token_type_ids,
+        torch_position_ids,
+        torch_attention_mask,
+        device=device,
+    )
+
+    tt_output = ttnn_optimized_bert.bert_for_question_answering(
+        config,
+        *ttnn_roberta_inputs,
+        parameters=parameters,
+        name="roberta",
+    )
+    tt_output = ttnn.to_torch(tt_output)
+
+    tt_output_start_logits = tt_output[..., :, 0]
+    tt_output_end_logits = tt_output[..., :, 1]
+
+    assert_with_pcc(torch_output_start_logits, tt_output_start_logits, 0.4505)
+    assert_with_pcc(torch_output_end_logits, tt_output_end_logits, 0.4590)
diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh
index d0acdbe7e180..673f3c12861e 100755
--- a/tests/scripts/run_performance.sh
+++ b/tests/scripts/run_performance.sh
@@ -77,6 +77,8 @@ run_device_perf_models() {
     env pytest "tests/ttnn/integration_tests/resnet/test_performance.py" -m $test_marker
 
     env pytest models/demos/resnet/tests -m $test_marker
+
+    env pytest models/experimental/functional_roberta/tests -m $test_marker
 fi
 
 if [ "$tt_arch" == "wormhole_b0" ]; then
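Note on position ids: the create_position_ids_from_input_ids helper added to demo.py implements RoBERTa's padding-aware numbering (it matches the helper of the same name in Hugging Face's modeling_roberta.py). RoBERTa reserves index padding_idx (pad_token_id, 1 for deepset/roberta-large-squad2) for pad tokens and numbers real tokens from padding_idx + 1 onward, which the previous torch.zeros(1, sequence_size) placeholder did not model. A minimal sketch of the expected behavior; the example token ids are illustrative only:

    import torch

    def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
        # 1 for real tokens, 0 for pad positions
        mask = input_ids.ne(padding_idx).int()
        # running count over real tokens; pad positions are zeroed out again
        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
        return incremental_indices.long() + padding_idx

    input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # <s> ... </s> <pad> <pad>, padding_idx = 1
    print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
    # tensor([[2, 3, 4, 5, 1, 1]]) -- real tokens count up from 2, pad positions stay at padding_idx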
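Note on the attention-mask dtype: the new torch.float32 argument to get_extended_attention_mask pins the dtype of the additive mask before it is clamped to -100000 and moved to device memory. The sketch below is an illustrative reimplementation of what such an extended mask computes; the function name extended_attention_mask_sketch is hypothetical, and that the repo's helper behaves exactly like Hugging Face's ModuleUtilsMixin.get_extended_attention_mask is an assumption:

    import torch

    def extended_attention_mask_sketch(attention_mask, dtype=torch.float32):
        # [batch, seq] -> [batch, 1, 1, seq]: broadcastable over heads and query positions
        extended = attention_mask[:, None, None, :].to(dtype)
        # 1 -> 0.0 (attend), 0 -> dtype minimum (suppressed after softmax)
        return (1.0 - extended) * torch.finfo(dtype).min

    mask = torch.tensor([[1, 1, 1, 0]])
    out = torch.clamp(extended_attention_mask_sketch(mask), min=-100000)
    print(out)  # 0.0 for attended positions, -100000.0 for the masked one

Since torch.finfo(dtype).min differs between float32 and lower-precision dtypes, fixing the dtype makes the subsequent clamp to -100000 deterministic across code paths.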