diff --git a/tests/run_tests.py b/tests/run_tests.py index 6321f69d341c3..f53a3773f83e2 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -1,8 +1,10 @@ import argparse import atexit +import copy import os import pdb import platform +import re import shutil import subprocess import sys @@ -61,7 +63,7 @@ def _test_cpp_aot(test_filename, build_dir, test_info): return exclude_tests_cmd -def _test_cpp(): +def _test_cpp(test_keys=None): curr_dir = os.path.dirname(os.path.abspath(__file__)) build_dir = os.path.join(curr_dir, '../build') cpp_test_filename = 'taichi_cpp_tests' @@ -73,6 +75,37 @@ def _test_cpp(): capi_tests_exe_path = os.path.join(build_dir, capi_test_filename) cpp_tests_exe_path = os.path.join(build_dir, cpp_test_filename) + # Run manually specified C++ tests only, for example: + # "python3 tests/run_tests.py --cpp -k Scalarize.*" + if test_keys: + # Search AOT tests + aot_test_cases = copy.copy(__aot_test_cases) + for cpp_test_name, (_, _) in __aot_test_cases.items(): + name_match = re.match(test_keys, cpp_test_name, re.I) + if name_match is None: + aot_test_cases.pop(cpp_test_name, None) + if aot_test_cases: + _test_cpp_aot(cpp_test_filename, build_dir, aot_test_cases) + + # Search CAPI tests + capi_aot_test_cases = copy.copy(__capi_aot_test_cases) + for cpp_test_name, (_, _) in __capi_aot_test_cases.items(): + name_match = re.match(test_keys, cpp_test_name, re.I) + if name_match is None: + capi_aot_test_cases.pop(cpp_test_name, None) + if capi_aot_test_cases: + _test_cpp_aot(capi_test_filename, build_dir, capi_aot_test_cases) + + # Search Cpp tests + _run_cpp_test(cpp_test_filename, build_dir, + f"--gtest_filter={test_keys}") + + _run_cpp_test(capi_test_filename, build_dir, + f"--gtest_filter={test_keys}") + + return + + # Regular C++ tests if os.path.exists(capi_tests_exe_path): # Run C-API test cases exclude_tests_cmd = _test_cpp_aot(capi_test_filename, build_dir, @@ -336,7 +369,7 @@ def size_of_dir(dir): os.environ['TI_OFFLINE_CACHE'] = '0' if args.cpp: - _test_cpp() + _test_cpp(args.keys) return for _ in range(run_count):