diff --git a/tests/sweep_framework/runner.py b/tests/sweep_framework/runner.py index 5aeef16c037..b31ebfe2ea9 100644 --- a/tests/sweep_framework/runner.py +++ b/tests/sweep_framework/runner.py @@ -9,12 +9,12 @@ import datetime import os import enlighten +from tt_metal.tools.profiler.process_ops_logs import get_device_data_generate_report, PROFILER_LOGS_DIR from multiprocessing import Process, Queue from queue import Empty import subprocess from statuses import TestStatus, VectorValidity, VectorStatus import tt_smi_util -from device_fixtures import default_device from elasticsearch import Elasticsearch, NotFoundError from elastic_config import * @@ -43,6 +43,21 @@ def get_devices(test_module): return default_device() +def gather_single_test_perf(device): + if not isinstance(device, ttnn.Device): + print("SWEEPS: Multi-device perf is not supported. Failing.") + return None + ttnn.DumpDeviceProfiler(device) + opPerfData = get_device_data_generate_report( + PROFILER_LOGS_DIR, None, None, None, export_csv=False, cleanup_device_log=True + ) + if len(opPerfData) != 1: + print("SWEEPS: Composite op detected in device perf measurement. Composite op perf is not supported. Failing.") + return None + else: + return opPerfData[0] + + def run(test_module, input_queue, output_queue): device_generator = get_devices(test_module) try: @@ -66,7 +81,11 @@ def run(test_module, input_queue, output_queue): except Exception as e: status, message = False, str(e) e2e_perf = None - output_queue.put([status, message, e2e_perf]) + if MEASURE_DEVICE_PERF and status: + perf_result = gather_single_test_perf(device) + output_queue.put([status, message, e2e_perf, perf_result]) + else: + output_queue.put([status, message, e2e_perf, None]) except Empty as e: try: # Run teardown in mesh_device_fixture @@ -123,8 +142,15 @@ def execute_suite(test_module, test_vectors, pbar_manager, suite_name): ) run(test_module, input_queue, output_queue) response = output_queue.get(block=True, timeout=timeout) - status, message, e2e_perf = response[0], response[1], response[2] - if status: + status, message, e2e_perf, device_perf = response[0], response[1], response[2], response[3] + if status and MEASURE_DEVICE_PERF and device_perf is None: + result["status"] = TestStatus.FAIL_UNSUPPORTED_DEVICE_PERF + result["message"] = message + elif status and MEASURE_DEVICE_PERF: + result["status"] = TestStatus.PASS + result["message"] = message + result["device_perf"] = device_perf + elif status: result["status"] = TestStatus.PASS result["message"] = message else: @@ -159,6 +185,7 @@ def execute_suite(test_module, test_vectors, pbar_manager, suite_name): suite_pbar.update() results.append(result) + if p is not None: p.join() @@ -328,6 +355,9 @@ def export_test_results(header_info, results): for i in range(len(results)): result = header_info[i] for elem in results[i].keys(): + if elem == "device_perf": + result[elem] = results[i][elem] + continue result[elem] = serialize(results[i][elem]) client.index(index=results_index, body=result) @@ -346,6 +376,18 @@ def disable_watcher(): os.environ.pop("TT_METAL_WATCHER_APPEND") +def enable_profiler(): + print("SWEEPS: Enabling Device Profiler") + os.environ["TT_METAL_DEVICE_PROFILER"] = "1" + os.environ["ENABLE_TRACY"] = "1" + + +def disable_profiler(): + print("SWEEPS: Disabling Device Profiler") + os.environ.pop("TT_METAL_DEVICE_PROFILER") + os.environ.pop("ENABLE_TRACY") + + if __name__ == "__main__": parser = argparse.ArgumentParser( prog="Sweep Test Runner", @@ -373,6 +415,13 @@ def disable_watcher(): help="Add this flag to measure e2e perf, for op tests with performance markers.", ) + parser.add_argument( + "--device-perf", + required=False, + action="store_true", + help="Measure device perf using device profiler. REQUIRES PROFILER BUILD!", + ) + parser.add_argument( "--dry-run", action="store_true", @@ -398,6 +447,9 @@ def disable_watcher(): global MEASURE_PERF MEASURE_PERF = args.perf + global MEASURE_DEVICE_PERF + MEASURE_DEVICE_PERF = args.device_perf + global DRY_RUN DRY_RUN = args.dry_run @@ -409,10 +461,17 @@ def disable_watcher(): if args.watcher: enable_watcher() + if MEASURE_DEVICE_PERF: + enable_profiler() + from ttnn import * from serialize import * + from device_fixtures import default_device run_sweeps(args.module_name, args.suite_name, args.vector_id) if args.watcher: disable_watcher() + + if MEASURE_DEVICE_PERF: + disable_profiler() diff --git a/tests/sweep_framework/statuses.py b/tests/sweep_framework/statuses.py index 2a29beb2679..cf683e19b23 100644 --- a/tests/sweep_framework/statuses.py +++ b/tests/sweep_framework/statuses.py @@ -12,6 +12,7 @@ class TestStatus(Enum): NOT_RUN = 3 FAIL_L1_OUT_OF_MEM = 4 FAIL_WATCHER = 5 + FAIL_UNSUPPORTED_DEVICE_PERF = 6 class VectorValidity(Enum): diff --git a/tt_metal/tools/profiler/process_ops_logs.py b/tt_metal/tools/profiler/process_ops_logs.py index 82f0c465431..3159fa4b2cb 100755 --- a/tt_metal/tools/profiler/process_ops_logs.py +++ b/tt_metal/tools/profiler/process_ops_logs.py @@ -231,7 +231,9 @@ def append_device_data(ops, deviceLogFolder): return deviceOps -def get_device_data_generate_report(deviceLogFolder, outputFolder, date, nameAppend): +def get_device_data_generate_report( + deviceLogFolder, outputFolder, date, nameAppend, export_csv=True, cleanup_device_log=False +): deviceTimesLog = os.path.join(deviceLogFolder, PROFILER_DEVICE_SIDE_LOG) devicePreOpTime = {} deviceOps = {} @@ -254,11 +256,12 @@ def get_device_data_generate_report(deviceLogFolder, outputFolder, date, nameApp name += f"_{dateStr}" outFolder = os.path.join(outFolder, dateStr) - allOpsCSVPath = os.path.join(outFolder, f"{name}.csv") - logger.info(f"Copying runtime artifacts") - os.system(f"rm -rf {outFolder}; mkdir -p {outFolder}") - if os.path.isfile(f"{PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG}"): - os.system(f"cp {PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG} {outFolder}") + if export_csv: + allOpsCSVPath = os.path.join(outFolder, f"{name}.csv") + logger.info(f"Copying runtime artifacts") + os.system(f"rm -rf {outFolder}; mkdir -p {outFolder}") + if os.path.isfile(f"{PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG}"): + os.system(f"cp {PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG} {outFolder}") if os.path.isfile(deviceTimesLog): logger.info(f"Getting device only ops data") @@ -309,21 +312,25 @@ def get_device_data_generate_report(deviceLogFolder, outputFolder, date, nameApp devicePreOpTime[device] = analysisData[0]["end_cycle"] rowDicts.append(rowDict) - with open(allOpsCSVPath, "w") as allOpsCSV: - allHeaders = [] - for header in OPS_CSV_HEADER: - if header in rowDicts[-1].keys(): - allHeaders.append(header) - writer = csv.DictWriter(allOpsCSV, fieldnames=allHeaders) - writer.writeheader() - for rowDict in rowDicts: - for field, fieldData in rowDict.items(): - rowDict[field] = str(fieldData).replace(",", ";") - writer.writerow(rowDict) - logger.info(f"Device only OPs csv generated at: {allOpsCSVPath}") + if export_csv: + with open(allOpsCSVPath, "w") as allOpsCSV: + allHeaders = [] + for header in OPS_CSV_HEADER: + if header in rowDicts[-1].keys(): + allHeaders.append(header) + writer = csv.DictWriter(allOpsCSV, fieldnames=allHeaders) + writer.writeheader() + for rowDict in rowDicts: + for field, fieldData in rowDict.items(): + rowDict[field] = str(fieldData).replace(",", ";") + writer.writerow(rowDict) + logger.info(f"Device only OPs csv generated at: {allOpsCSVPath}") + + if cleanup_device_log: + os.remove(deviceTimesLog) else: logger.info("No device logs found") - return deviceOps + return rowDicts def generate_reports(ops, deviceOps, signposts, outputFolder, date, nameAppend):