Skip to content

Commit

Permalink
#10855: Add single test device perf measurements to sweeps
Browse files Browse the repository at this point in the history
  • Loading branch information
jdesousa-TT committed Sep 6, 2024
1 parent 558f2cc commit 6ab016e
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
66 changes: 63 additions & 3 deletions tests/sweep_framework/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import datetime
import os
import enlighten
from tt_metal.tools.profiler.process_ops_logs import get_device_data_generate_report, PROFILER_LOGS_DIR
from multiprocessing import Process, Queue
from queue import Empty
import subprocess
Expand Down Expand Up @@ -43,7 +44,23 @@ def get_devices(test_module):
return default_device()


def gather_single_test_perf(device):
import tt_lib as ttl

ttl.device.DumpDeviceProfiler(device)
opPerfData = get_device_data_generate_report(
PROFILER_LOGS_DIR, None, None, None, export_csv=False, cleanup_device_log=True
)
if len(opPerfData) != 1:
print("SWEEPS: Composite op detected in device perf measurement. Failing.")
return None
else:
return opPerfData[0]


def run(test_module, input_queue, output_queue):
import tt_lib as ttl

device_generator = get_devices(test_module)
try:
device, device_name = next(device_generator)
Expand All @@ -66,7 +83,11 @@ def run(test_module, input_queue, output_queue):
except Exception as e:
status, message = False, str(e)
e2e_perf = None
output_queue.put([status, message, e2e_perf])
if MEASURE_DEVICE_PERF and status:
perf_result = gather_single_test_perf(device)
output_queue.put([status, message, e2e_perf, perf_result])
else:
output_queue.put([status, message, e2e_perf, None])
except Empty as e:
try:
# Run teardown in mesh_device_fixture
Expand Down Expand Up @@ -123,8 +144,15 @@ def execute_suite(test_module, test_vectors, pbar_manager, suite_name):
)
run(test_module, input_queue, output_queue)
response = output_queue.get(block=True, timeout=timeout)
status, message, e2e_perf = response[0], response[1], response[2]
if status:
status, message, e2e_perf, device_perf = response[0], response[1], response[2], response[3]
if status and MEASURE_DEVICE_PERF and device_perf is None:
result["status"] = TestStatus.FAIL_COMPOSITE_OP_PERF
result["message"] = message
elif status and MEASURE_DEVICE_PERF:
result["status"] = TestStatus.PASS
result["message"] = message
result["device_perf"] = device_perf
elif status:
result["status"] = TestStatus.PASS
result["message"] = message
else:
Expand Down Expand Up @@ -159,6 +187,7 @@ def execute_suite(test_module, test_vectors, pbar_manager, suite_name):

suite_pbar.update()
results.append(result)

if p is not None:
p.join()

Expand Down Expand Up @@ -328,6 +357,9 @@ def export_test_results(header_info, results):
for i in range(len(results)):
result = header_info[i]
for elem in results[i].keys():
if elem == "device_perf":
result[elem] = results[i][elem]
continue
result[elem] = serialize(results[i][elem])
client.index(index=results_index, body=result)

Expand All @@ -346,6 +378,18 @@ def disable_watcher():
os.environ.pop("TT_METAL_WATCHER_APPEND")


def enable_profiler():
print("SWEEPS: Enabling Device Profiler")
os.environ["TT_METAL_DEVICE_PROFILER"] = "1"
os.environ["ENABLE_TRACY"] = "1"


def disable_profiler():
print("SWEEPS: Disabling Device Profiler")
os.environ.pop("TT_METAL_DEVICE_PROFILER")
os.environ.pop("ENABLE_TRACY")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="Sweep Test Runner",
Expand Down Expand Up @@ -373,6 +417,13 @@ def disable_watcher():
help="Add this flag to measure e2e perf, for op tests with performance markers.",
)

parser.add_argument(
"--device-perf",
required=False,
action="store_true",
help="Measure device perf using device profiler. REQUIRES PROFILER BUILD!",
)

parser.add_argument(
"--dry-run",
action="store_true",
Expand All @@ -398,6 +449,9 @@ def disable_watcher():
global MEASURE_PERF
MEASURE_PERF = args.perf

global MEASURE_DEVICE_PERF
MEASURE_DEVICE_PERF = args.device_perf

global DRY_RUN
DRY_RUN = args.dry_run

Expand All @@ -409,10 +463,16 @@ def disable_watcher():
if args.watcher:
enable_watcher()

if MEASURE_DEVICE_PERF:
enable_profiler()

from ttnn import *
from serialize import *

run_sweeps(args.module_name, args.suite_name, args.vector_id)

if args.watcher:
disable_watcher()

if MEASURE_DEVICE_PERF:
disable_profiler()
1 change: 1 addition & 0 deletions tests/sweep_framework/statuses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class TestStatus(Enum):
NOT_RUN = 3
FAIL_L1_OUT_OF_MEM = 4
FAIL_WATCHER = 5
FAIL_COMPOSITE_OP_PERF = 6


class VectorValidity(Enum):
Expand Down

0 comments on commit 6ab016e

Please sign in to comment.