diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index a25d819a2d..1d4a7eb2e5 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -956,22 +956,6 @@ def _candidates( continue in_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset)))) - # Check interstate edges for candidates - for e in nsdfg.sdfg.edges(): - for m in e.data.get_read_memlets(nsdfg.sdfg.arrays): - # If more than one unique element detected, remove from candidates - if m.data in in_candidates: - memlet, ns, indices = in_candidates[m.data] - # Try to find dimensions in which there is a mismatch and remove them from list - for i, (s1, s2) in enumerate(zip(m.subset, memlet.subset)): - if s1 != s2 and i in indices: - indices.remove(i) - if len(indices) == 0: - ignore.add(m.data) - in_candidates[m.data] = (memlet, ns, indices) - continue - in_candidates[m.data] = (m, None, set(range(len(m.subset)))) - # Check in/out candidates for cand in in_candidates.keys() & out_candidates.keys(): s1, nstate1, ind1 = in_candidates[cand] @@ -1000,7 +984,7 @@ def _check_cand(candidates, outer_edges): continue # Check w.r.t. loops - if nstate is not None and len(nstate.ranges) > 0: + if len(nstate.ranges) > 0: # Re-annotate loop ranges, in case someone changed them # TODO: Move out of here! for ns in nsdfg.sdfg.states(): diff --git a/tests/transformations/refine_nested_access_test.py b/tests/transformations/refine_nested_access_test.py deleted file mode 100644 index 30d2a3e77e..0000000000 --- a/tests/transformations/refine_nested_access_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" Tests for the RefineNestedAccess transformation. """ - -import dace -import numpy as np - -from dace.transformation.interstate import RefineNestedAccess - - -def test_refine_dataflow(): - - i = dace.symbol('i') - j = dace.symbol('j') - - @dace.program - def inner_sdfg(A: dace.int32[5, 5], B: dace.int32[5, 5]): - B[i, j] = A[j, i] - - sdfg = dace.SDFG('refine_dataflow') - sdfg.add_array('A', [5, 5], dace.int32) - sdfg.add_array('B', [5, 5], dace.int32) - - state = sdfg.add_state() - A = state.add_access('A') - B = state.add_access('B') - me, mx = state.add_map('m', dict(i='0:5', j='0:5')) - nsdfg = state.add_nested_sdfg(inner_sdfg.to_sdfg(), sdfg, {'A'}, {'B'}, {'i': 'i', 'j': 'j'}) - state.add_memlet_path(A, me, nsdfg, dst_conn='A', memlet=dace.Memlet.from_array('A', sdfg.arrays['A'])) - state.add_memlet_path(nsdfg, mx, B, src_conn='B', memlet=dace.Memlet.from_array('B', sdfg.arrays['B'])) - - num = sdfg.apply_transformations_repeated(RefineNestedAccess) - assert num == 1 - - for edge in state.out_edges(me): - assert edge.data.subset == dace.subsets.Range([(j, j, 1), (i, i, 1)]) - for edge in state.in_edges(mx): - assert edge.data.subset == dace.subsets.Range([(i, i, 1), (j, j, 1)]) - - A = np.arange(25, dtype=np.int32).reshape(5, 5).copy() - B = np.empty((5, 5), dtype=np.int32) - sdfg(A=A, B=B) - assert np.allclose(B, A.T) - - -def test_refine_interstate(): - - i = dace.symbol('i') - j = dace.symbol('j') - - @dace.program - def inner_sdfg(A: dace.int32[5, 5], B: dace.int32[5, 5], select: dace.bool[5, 5]): - if select[i, j]: - B[i, j] = A[j, i] - else: - B[i, j] = A[i, j] - - sdfg = dace.SDFG('refine_dataflow') - sdfg.add_array('A', [5, 5], dace.int32) - sdfg.add_array('B', [5, 5], dace.int32) - sdfg.add_array('select', [5, 5], dace.bool) - - state = sdfg.add_state() - A = state.add_access('A') - B = state.add_access('B') - select = state.add_access('select') - me, mx = state.add_map('m', dict(i='0:5', j='0:5')) - nsdfg = state.add_nested_sdfg(inner_sdfg.to_sdfg(), sdfg, {'A', 'select'}, {'B'}, {'i': 'i', 'j': 'j'}) - state.add_memlet_path(A, me, nsdfg, dst_conn='A', memlet=dace.Memlet.from_array('A', sdfg.arrays['A'])) - state.add_memlet_path(select, - me, - nsdfg, - dst_conn='select', - memlet=dace.Memlet.from_array('select', sdfg.arrays['select'])) - state.add_memlet_path(nsdfg, mx, B, src_conn='B', memlet=dace.Memlet.from_array('B', sdfg.arrays['B'])) - - num = sdfg.apply_transformations_repeated(RefineNestedAccess) - assert num == 1 - - for edge in state.out_edges(me): - if edge.data.data == 'A': - expr = dace.symbolic.pystr_to_symbolic('Max(i, j)') - assert edge.data.subset == dace.subsets.Range([(0, expr, 1), (0, expr, 1)]) - else: - assert edge.data.subset == dace.subsets.Range([(i, i, 1), (j, j, 1)]) - for edge in state.in_edges(mx): - assert edge.data.subset == dace.subsets.Range([(i, i, 1), (j, j, 1)]) - - A = np.arange(25, dtype=np.int32).reshape(5, 5).copy() - B = np.empty((5, 5), dtype=np.int32) - select = np.empty((5, 5), dtype=np.bool_) - select[:] = True - upper = np.triu(select, k=0) - sdfg(A=A, B=B, select=upper) - lower = np.tril(A, k=0) - diag = np.diag(np.diag(A)) - assert np.allclose(B, lower.T + lower - diag) - - -if __name__ == '__main__': - test_refine_dataflow() - test_refine_interstate()