[microTVM] Rework evaluate_model_accuracy into a more generic helper function #12539

Merged 2 commits on Aug 23, 2022
2 changes: 1 addition & 1 deletion python/tvm/micro/testing/__init__.py
@@ -16,5 +16,5 @@
 # under the License.
 
 """Allows the tools specified below to be imported directly from tvm.micro.testing"""
-from .evaluation import tune_model, create_aot_session, evaluate_model_accuracy
+from .evaluation import tune_model, create_aot_session, predict_labels_aot
 from .utils import get_supported_boards, get_target
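Since this file just re-exports the helper, the rename is visible to downstream code through the tvm.micro.testing namespace. A quick hypothetical usage line (not part of the diff):

    # Hypothetical import; mirrors the re-export above.
    from tvm.micro.testing import predict_labels_aot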
21 changes: 6 additions & 15 deletions python/tvm/micro/testing/evaluation.py
@@ -133,27 +133,18 @@ def create_aot_session(
     return tvm.micro.Session(project.transport(), timeout_override=timeout_override)
 
 
-# This utility functions was designed ONLY for one input / one output models
-# where the outputs are confidences for different classes.
-def evaluate_model_accuracy(session, aot_executor, input_data, true_labels, runs_per_sample=1):
-    """Evaluates an AOT-compiled model's accuracy and runtime over an RPC session. Works well
-    when used with create_aot_session."""
+def predict_labels_aot(session, aot_executor, input_data, runs_per_sample=1):
+    """Predicts labels for each sample in input_data using host-driven AOT.
+    Returns an iterator of (label, runtime) tuples. This function can only
+    be used with models for which the output is the confidence for each class."""
 
     assert aot_executor.get_num_inputs() == 1
     assert aot_executor.get_num_outputs() == 1
     assert runs_per_sample > 0
 
-    predicted_labels = []
-    aot_runtimes = []
     for sample in input_data:
         aot_executor.get_input(0).copyfrom(sample)
         result = aot_executor.module.time_evaluator("run", session.device, number=runs_per_sample)()
+        predicted_label = aot_executor.get_output(0).numpy().argmax()
         runtime = result.mean
-        output = aot_executor.get_output(0).numpy()
-        predicted_labels.append(output.argmax())
-        aot_runtimes.append(runtime)
-
-    num_correct = sum(u == v for u, v in zip(true_labels, predicted_labels))
-    average_time = sum(aot_runtimes) / len(aot_runtimes)
-    accuracy = num_correct / len(predicted_labels)
-    return average_time, accuracy, predicted_labels
+        yield predicted_label, runtime
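Because the reworked helper yields per-sample (label, runtime) tuples instead of aggregating them, a caller that still wants the old accuracy and average-runtime numbers can rebuild them on top of it. A minimal sketch, assuming session, aot_executor, input_data, and true_labels are set up as before (hypothetical code, not part of the PR):

    import tvm.micro.testing

    # Hypothetical sketch: recover the old evaluate_model_accuracy outputs
    # from the generic per-sample iterator.
    results = list(
        tvm.micro.testing.predict_labels_aot(
            session, aot_executor, input_data, runs_per_sample=1
        )
    )
    predicted_labels = [label for label, _ in results]
    aot_runtimes = [runtime for _, runtime in results]

    accuracy = sum(u == v for u, v in zip(true_labels, predicted_labels)) / len(predicted_labels)
    average_time = sum(aot_runtimes) / len(aot_runtimes)  # on-device seconds per inference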
13 changes: 7 additions & 6 deletions tests/micro/common/test_autotune.py
@@ -76,17 +76,18 @@ def test_kws_autotune_workflow(platform, board, tmp_path):
         np.random.randint(low=-127, high=128, size=(1, 1960), dtype=np.int8) for x in range(3)
     )
 
-    labels = [0, 0, 0]
-
     # Validate performance across random runs
-    time, _, _ = tvm.micro.testing.evaluate_model_accuracy(
-        session, aot_executor, samples, labels, runs_per_sample=20
-    )
+    runtimes = [
+        runtime
+        for _, runtime in tvm.micro.testing.predict_labels_aot(
+            session, aot_executor, samples, runs_per_sample=20
+        )
+    ]
     # Each runtime is the average time taken to execute model inference on the
     # device, measured in seconds. It does not include the time to upload the
     # input data via RPC. On slow boards like the Arduino Due, each runtime is
     # around 0.12 (120 ms), so this gives us plenty of buffer.
-    assert time < 1
+    assert np.median(runtimes) < 1
 
 
 if __name__ == "__main__":
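One design note on the test change: asserting on the median of the per-sample runtimes, rather than a single mean, keeps the check robust when one run is anomalously slow. A hypothetical illustration with made-up values (the 0.12 s figure echoes the Arduino Due ballpark in the comment above):

    import numpy as np

    # Made-up per-sample runtimes in seconds; one slow outlier inflates the mean.
    runtimes = [0.12, 0.11, 2.50]
    print(np.mean(runtimes))    # ~0.91 s, dragged up by the outlier
    print(np.median(runtimes))  # 0.12 s, still reflects typical inference time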