Skip to content

Commit

Permalink
#6344: Update RoBERTa QA demo
Browse files Browse the repository at this point in the history
  • Loading branch information
kkeerthana0573 committed Jul 2, 2024
1 parent 3687a43 commit 224d4b1
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 9 deletions.
2 changes: 1 addition & 1 deletion models/demos/bert/tt/ttnn_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def preprocess_inputs(
position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)

if attention_mask is not None:
attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape)
attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32)
attention_mask = attention_mask.expand((batch_size, -1, -1, -1))
attention_mask = torch.clamp(attention_mask, min=-100000)
attention_mask = ttnn.from_torch(
Expand Down
2 changes: 1 addition & 1 deletion models/demos/bert/tt/ttnn_optimized_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def preprocess_inputs(
position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)

if attention_mask is not None:
attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape)
attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32)
attention_mask = attention_mask.expand((batch_size, -1, -1, -1))
attention_mask = torch.clamp(attention_mask, min=-100000)
attention_mask = ttnn.from_torch(
Expand Down
26 changes: 19 additions & 7 deletions models/experimental/functional_roberta/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
import torch
from loguru import logger
import tt_lib
import transformers
import ttnn
import evaluate
Expand All @@ -15,8 +14,7 @@
disable_persistent_kernel_cache,
profiler,
)
from models.demos.bert.tt import ttnn_bert
from models.demos.bert.tt import ttnn_optimized_bert
from models.demos.bert.tt import ttnn_bert, ttnn_optimized_bert

from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch
from ttnn.model_preprocessing import (
Expand All @@ -43,6 +41,12 @@ def load_inputs(input_path, batch):
return context, question


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx


def run_roberta_question_and_answering_inference(
device,
use_program_cache,
Expand Down Expand Up @@ -106,10 +110,14 @@ def run_roberta_question_and_answering_inference(

profiler.start(f"preprocessing_input")

position_ids = create_position_ids_from_input_ids(
input_ids=roberta_input.input_ids, padding_idx=config.pad_token_id
)
ttnn_roberta_inputs = bert.preprocess_inputs(
roberta_input["input_ids"],
roberta_input["token_type_ids"],
torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
position_ids,
roberta_input["attention_mask"],
device=device,
)
profiler.end(f"preprocessing_input")
Expand Down Expand Up @@ -209,10 +217,14 @@ def run_roberta_question_and_answering_inference_squad_v2(
if i < n_iterations:
batch_data = batch[0]
curr_batch_size = batch_data["input_ids"].shape[0]
position_ids = create_position_ids_from_input_ids(
input_ids=batch_data.input_ids, padding_idx=config.pad_token_id
)
ttnn_roberta_inputs = bert.preprocess_inputs(
batch_data["input_ids"],
batch_data["token_type_ids"],
torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
position_ids,
batch_data["attention_mask"],
device=device,
)

Expand Down Expand Up @@ -255,7 +267,7 @@ def run_roberta_question_and_answering_inference_squad_v2(


@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
def test_demo(
input_path,
model_name,
Expand All @@ -278,7 +290,7 @@ def test_demo(


@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
@pytest.mark.parametrize(
"n_iterations",
((3),),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
from models.utility_functions import skip_for_wormhole_b0
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report


@skip_for_wormhole_b0()
@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
"batch_size, test, expected_perf",
[
# [8, "3-models.demos.bert.tt.ttnn_optimized_bert-deepset/roberta-large-squad2", 159],
[8, "384-8-deepset/roberta-large-squad2", 159],
],
)
def test_perf_device_bare_metal(batch_size, test, expected_perf):
subdir = "ttnn_roberta"
num_iterations = 3
margin = 0.03
# command = f"pytest models/experimental/functional_roberta/demo/demo.py::test_demo_squadv2[{test}]"
command = f"pytest models/experimental/functional_roberta/tests/test_ttnn_optimized_roberta.py::test_roberta_for_question_answering[{test}]"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
expected_perf_cols = {inference_time_key: expected_perf}

post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
prep_device_perf_report(
model_name=f"ttnn_roberta_{batch_size}",
batch_size=batch_size,
post_processed_results=post_processed_results,
expected_results=expected_results,
comments=test.replace("/", "_"),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import torch
import pytest
import tt_lib
import transformers

from models.demos.bert.tt import ttnn_optimized_bert, ttnn_bert
from ttnn.model_preprocessing import preprocess_model_parameters
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_wormhole_b0

from transformers import RobertaForQuestionAnswering, RobertaConfig


@skip_for_wormhole_b0()
@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("sequence_size", [384])
def test_roberta_for_question_answering(device, use_program_cache, reset_seeds, model_name, batch_size, sequence_size):
config = RobertaConfig.from_pretrained(model_name)
model = RobertaForQuestionAnswering.from_pretrained(model_name)

input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)
torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
torch_attention_mask = torch.ones(batch_size, sequence_size)

torch_output = model(
input_ids=input_ids,
attention_mask=torch_attention_mask,
token_type_ids=torch_token_type_ids,
position_ids=torch_position_ids,
)
torch_output_start_logits = torch_output.start_logits
torch_output_end_logits = torch_output.end_logits

tt_model_name = f"ttnn_{model_name}_optimized"

parameters = preprocess_model_parameters(
model_name=tt_model_name,
initialize_model=lambda: transformers.RobertaForQuestionAnswering.from_pretrained(
model_name, torchscript=False
).eval(),
custom_preprocessor=ttnn_optimized_bert.custom_preprocessor,
device=device,
)

ttnn_roberta_inputs = ttnn_optimized_bert.preprocess_inputs(
input_ids,
torch_token_type_ids,
torch_position_ids,
torch_attention_mask,
device=device,
)

tt_output = ttnn_optimized_bert.bert_for_question_answering(
config,
*ttnn_roberta_inputs,
parameters=parameters,
name="roberta",
)
tt_output = ttnn.to_torch(tt_output)

tt_output_start_logits = tt_output[..., :, 0]
tt_output_end_logits = tt_output[..., :, 1]

assert_with_pcc(torch_output_start_logits, tt_output_start_logits, 0.4505)
assert_with_pcc(torch_output_end_logits, tt_output_end_logits, 0.4590)
2 changes: 2 additions & 0 deletions tests/scripts/run_performance.sh
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ run_device_perf_models() {
env pytest "tests/ttnn/integration_tests/resnet/test_performance.py" -m $test_marker

env pytest models/demos/resnet/tests -m $test_marker

env pytest models/experimental/functional_roberta/tests -m $test_marker
fi

if [ "$tt_arch" == "wormhole_b0" ]; then
Expand Down

0 comments on commit 224d4b1

Please sign in to comment.