diff --git a/python/taichi/aot/utils.py b/python/taichi/aot/utils.py index 28d6982be3716c..c715347cb23851 100644 --- a/python/taichi/aot/utils.py +++ b/python/taichi/aot/utils.py @@ -13,7 +13,9 @@ def check_type_match(lhs, rhs): if isinstance(lhs, MatrixType) and isinstance(rhs, MatrixType): - return lhs.n == rhs.n and lhs.m == rhs.m and (lhs.dtype == rhs.dtype or lhs.dtype is None or rhs.dtype is None) + return lhs.n == rhs.n and lhs.m == rhs.m and (lhs.dtype == rhs.dtype + or lhs.dtype is None + or rhs.dtype is None) elif isinstance(lhs, MatrixType) or isinstance(rhs, MatrixType): return False else: diff --git a/tests/cpp/aot/python_scripts/graph_aot_test.py b/tests/cpp/aot/python_scripts/graph_aot_test.py index 284ecf190a5ab7..ae9afd0380e52a 100644 --- a/tests/cpp/aot/python_scripts/graph_aot_test.py +++ b/tests/cpp/aot/python_scripts/graph_aot_test.py @@ -16,7 +16,9 @@ def run0(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)): arr[i] += base + i @ti.kernel - def run1(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.types.vector(1, ti.i32))): + def run1(base: int, arr: ti.types.ndarray(ndim=1, + dtype=ti.types.vector(1, + ti.i32))): for i in arr: arr[i] += base + i