diff --git a/tests/sweep_framework/runner.py b/tests/sweep_framework/runner.py index 5aeef16c037..b31ebfe2ea9 100644 --- a/tests/sweep_framework/runner.py +++ b/tests/sweep_framework/runner.py @@ -9,12 +9,12 @@ import datetime import os import enlighten +from tt_metal.tools.profiler.process_ops_logs import get_device_data_generate_report, PROFILER_LOGS_DIR from multiprocessing import Process, Queue from queue import Empty import subprocess from statuses import TestStatus, VectorValidity, VectorStatus import tt_smi_util -from device_fixtures import default_device from elasticsearch import Elasticsearch, NotFoundError from elastic_config import * @@ -43,6 +43,21 @@ def get_devices(test_module): return default_device() +def gather_single_test_perf(device): + if not isinstance(device, ttnn.Device): + print("SWEEPS: Multi-device perf is not supported. Failing.") + return None + ttnn.DumpDeviceProfiler(device) + opPerfData = get_device_data_generate_report( + PROFILER_LOGS_DIR, None, None, None, export_csv=False, cleanup_device_log=True + ) + if len(opPerfData) != 1: + print("SWEEPS: Composite op detected in device perf measurement. Composite op perf is not supported. Failing.") + return None + else: + return opPerfData[0] + + def run(test_module, input_queue, output_queue): device_generator = get_devices(test_module) try: @@ -66,7 +81,11 @@ def run(test_module, input_queue, output_queue): except Exception as e: status, message = False, str(e) e2e_perf = None - output_queue.put([status, message, e2e_perf]) + if MEASURE_DEVICE_PERF and status: + perf_result = gather_single_test_perf(device) + output_queue.put([status, message, e2e_perf, perf_result]) + else: + output_queue.put([status, message, e2e_perf, None]) except Empty as e: try: # Run teardown in mesh_device_fixture @@ -123,8 +142,15 @@ def execute_suite(test_module, test_vectors, pbar_manager, suite_name): ) run(test_module, input_queue, output_queue) response = output_queue.get(block=True, timeout=timeout) - status, message, e2e_perf = response[0], response[1], response[2] - if status: + status, message, e2e_perf, device_perf = response[0], response[1], response[2], response[3] + if status and MEASURE_DEVICE_PERF and device_perf is None: + result["status"] = TestStatus.FAIL_UNSUPPORTED_DEVICE_PERF + result["message"] = message + elif status and MEASURE_DEVICE_PERF: + result["status"] = TestStatus.PASS + result["message"] = message + result["device_perf"] = device_perf + elif status: result["status"] = TestStatus.PASS result["message"] = message else: @@ -159,6 +185,7 @@ def execute_suite(test_module, test_vectors, pbar_manager, suite_name): suite_pbar.update() results.append(result) + if p is not None: p.join() @@ -328,6 +355,9 @@ def export_test_results(header_info, results): for i in range(len(results)): result = header_info[i] for elem in results[i].keys(): + if elem == "device_perf": + result[elem] = results[i][elem] + continue result[elem] = serialize(results[i][elem]) client.index(index=results_index, body=result) @@ -346,6 +376,18 @@ def disable_watcher(): os.environ.pop("TT_METAL_WATCHER_APPEND") +def enable_profiler(): + print("SWEEPS: Enabling Device Profiler") + os.environ["TT_METAL_DEVICE_PROFILER"] = "1" + os.environ["ENABLE_TRACY"] = "1" + + +def disable_profiler(): + print("SWEEPS: Disabling Device Profiler") + os.environ.pop("TT_METAL_DEVICE_PROFILER") + os.environ.pop("ENABLE_TRACY") + + if __name__ == "__main__": parser = argparse.ArgumentParser( prog="Sweep Test Runner", @@ -373,6 +415,13 @@ def disable_watcher(): help="Add this flag to measure e2e perf, for op tests with performance markers.", ) + parser.add_argument( + "--device-perf", + required=False, + action="store_true", + help="Measure device perf using device profiler. REQUIRES PROFILER BUILD!", + ) + parser.add_argument( "--dry-run", action="store_true", @@ -398,6 +447,9 @@ def disable_watcher(): global MEASURE_PERF MEASURE_PERF = args.perf + global MEASURE_DEVICE_PERF + MEASURE_DEVICE_PERF = args.device_perf + global DRY_RUN DRY_RUN = args.dry_run @@ -409,10 +461,17 @@ def disable_watcher(): if args.watcher: enable_watcher() + if MEASURE_DEVICE_PERF: + enable_profiler() + from ttnn import * from serialize import * + from device_fixtures import default_device run_sweeps(args.module_name, args.suite_name, args.vector_id) if args.watcher: disable_watcher() + + if MEASURE_DEVICE_PERF: + disable_profiler() diff --git a/tests/sweep_framework/statuses.py b/tests/sweep_framework/statuses.py index 2a29beb2679..cf683e19b23 100644 --- a/tests/sweep_framework/statuses.py +++ b/tests/sweep_framework/statuses.py @@ -12,6 +12,7 @@ class TestStatus(Enum): NOT_RUN = 3 FAIL_L1_OUT_OF_MEM = 4 FAIL_WATCHER = 5 + FAIL_UNSUPPORTED_DEVICE_PERF = 6 class VectorValidity(Enum):