Skip to content

Commit

Permalink
Merge pull request #1317 from Sajohn-CH/free_symbols_indices
Browse files Browse the repository at this point in the history
RefineNestedAccess take indices into account when checking for missing free symbols
  • Loading branch information
alexnick83 authored Jul 14, 2023
2 parents 28b93bd + 4638a0a commit f6a19de
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 2 deletions.
21 changes: 21 additions & 0 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,27 @@ def free_symbols(self) -> Set[str]:
result |= self.dst_subset.free_symbols
return result

def get_free_symbols_by_indices(self, indices_src: List[int], indices_dst: List[int]) -> Set[str]:
"""
Returns set of free symbols used in this edges properties but only taking certain indices of the src and dst
subset into account
:param indices_src: The indices of the src subset to take into account
:type indices_src: List[int]
:param indices_dst: The indices of the dst subset to take into account
:type indices_dst: List[int]
:return: The set of free symbols
:rtype: Set[str]
"""
# Symbolic properties are in volume, and the two subsets
result = set()
result |= set(map(str, self.volume.free_symbols))
if self.src_subset:
result |= self.src_subset.get_free_symbols_by_indices(indices_src)
if self.dst_subset:
result |= self.dst_subset.get_free_symbols_by_indices(indices_dst)
return result

def get_stride(self, sdfg: 'dace.sdfg.SDFG', map: 'dace.sdfg.nodes.Map', dim: int = -1) -> 'dace.symbolic.SymExpr':
""" Returns the stride of the underlying memory when traversing a Map.
Expand Down
16 changes: 16 additions & 0 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,22 @@ def free_symbols(self) -> Set[str]:
result |= symbolic.symlist(d).keys()
return result

def get_free_symbols_by_indices(self, indices: List[int]) -> Set[str]:
"""
Get set of free symbols by only looking at the dimension given by the indices list
:param indices: The indices of the dimensions to look at
:type indices: List[int]
:return: The set of free symbols
:rtype: Set[str]
"""
result = set()
for i, dim in enumerate(self.ranges):
if i in indices:
for d in dim:
result |= symbolic.symlist(d).keys()
return result

def reorder(self, order):
""" Re-orders the dimensions in-place according to a permutation list.
Expand Down
4 changes: 2 additions & 2 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def _check_cand(candidates, outer_edges):
continue

# Check w.r.t. loops
if len(nstate.ranges) > 0:
if nstate is not None and len(nstate.ranges) > 0:
# Re-annotate loop ranges, in case someone changed them
# TODO: Move out of here!
for ns in nsdfg.sdfg.states():
Expand All @@ -1022,7 +1022,7 @@ def _check_cand(candidates, outer_edges):

# If there are any symbols here that are not defined
# in "defined_symbols"
missing_symbols = (memlet.free_symbols - set(nsdfg.symbol_mapping.keys()))
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - set(nsdfg.symbol_mapping.keys()))
if missing_symbols:
ignore.add(cname)
continue
Expand Down
32 changes: 32 additions & 0 deletions tests/transformations/refine_nested_access_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,38 @@ def inner_sdfg(A: dace.int32[5, 5], B: dace.int32[5, 5], select: dace.bool[5, 5]
assert np.allclose(B, lower.T + lower - diag)


def test_free_sybmols_only_by_indices():
i = dace.symbol('i')
idx_a = dace.symbol('idx_a')
idx_b = dace.symbol('idx_b')
sdfg = dace.SDFG('refine_free_symbols_only_by_indices')
sdfg.add_array('A', [5], dace.int32)
sdfg.add_array('B', [5, 5], dace.int32)

@dace.program
def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int):
if A[i] > 0.5:
B[i, idx_a] = 1
else:
B[i, idx_b] = 0

state = sdfg.add_state()
A = state.add_access('A')
B = state.add_access('B')
map_entry, map_exit = state.add_map('map', dict(i='0:5'))
nsdfg = state.add_nested_sdfg(inner_sdfg.to_sdfg(simplify=False), sdfg, {'A'}, {'B'}, {'i': 'i'})
state.add_memlet_path(A, map_entry, nsdfg, dst_conn='A', memlet=dace.Memlet.from_array('A', sdfg.arrays['A']))
state.add_memlet_path(nsdfg, map_exit, B, src_conn='B', memlet=dace.Memlet.from_array('B', sdfg.arrays['B']))

num = sdfg.apply_transformations_repeated(RefineNestedAccess)
assert num == 1

assert len(state.in_edges(map_exit)) == 1
edge = state.in_edges(map_exit)[0]
assert edge.data.subset == dace.subsets.Range([(i, i, 1), (0, 4, 1)])


if __name__ == '__main__':
test_refine_dataflow()
test_refine_interstate()
test_free_sybmols_only_by_indices()

0 comments on commit f6a19de

Please sign in to comment.