Skip to content

Commit

Permalink
Merge branch 'master' into fix-subgraph-fusion-intermediate-node-removal
Browse files Browse the repository at this point in the history
  • Loading branch information
acalotoiu authored Jul 15, 2023
2 parents 064ef33 + f6a19de commit bf9790a
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 8 deletions.
15 changes: 12 additions & 3 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,11 +1082,20 @@ def _subscript_expr(self, slicenode: ast.AST, target: str) -> symbolic.SymbolicT
]

if isinstance(visited_slice, ast.Tuple):
if len(strides) != len(visited_slice.elts):
# If slice is multi-dimensional and writes to array with more than 1 elements, then:
# - Assume this is indirection (?)
# - Soft-squeeze the slice (remove unit-modes) to match the treatment of the strides above.
if target not in self.constants:
desc = self.sdfg.arrays[dname]
if isinstance(desc, data.Array) and data._prod(desc.shape) != 1:
elts = [e for i, e in enumerate(visited_slice.elts) if desc.shape[i] != 1]
else:
elts = visited_slice.elts
if len(strides) != len(elts):
raise SyntaxError('Invalid number of dimensions in expression (expected %d, '
'got %d)' % (len(strides), len(visited_slice.elts)))
'got %d)' % (len(strides), len(elts)))

return sum(symbolic.pystr_to_symbolic(unparse(elt)) * s for elt, s in zip(visited_slice.elts, strides))
return sum(symbolic.pystr_to_symbolic(unparse(elt)) * s for elt, s in zip(elts, strides))

if len(strides) != 1:
raise SyntaxError('Missing dimensions in expression (expected %d, got one)' % len(strides))
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2352,6 +2352,8 @@ def _visit_test(self, node: ast.Expr):
# Visit test-condition
if not is_test_simple:
parsed_node = self.visit(node)
if isinstance(parsed_node, (list, tuple)) and len(parsed_node) == 1:
parsed_node = parsed_node[0]
if isinstance(parsed_node, str) and parsed_node in self.sdfg.arrays:
datadesc = self.sdfg.arrays[parsed_node]
if isinstance(datadesc, data.Array):
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -4370,6 +4370,8 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt
'outputs': ['__out'],
'code': "__out = dace.{}(__inp)".format(dtype.to_string())
}
if dtype in (dace.bool, dace.bool_):
impl['code'] = "__out = dace.bool_(__inp)"
tasklet_params = _set_tasklet_params(impl, [arg])

# Visitor input only needed when `has_where == True`.
Expand Down
21 changes: 21 additions & 0 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,27 @@ def free_symbols(self) -> Set[str]:
result |= self.dst_subset.free_symbols
return result

def get_free_symbols_by_indices(self, indices_src: List[int], indices_dst: List[int]) -> Set[str]:
"""
Returns set of free symbols used in this edges properties but only taking certain indices of the src and dst
subset into account
:param indices_src: The indices of the src subset to take into account
:type indices_src: List[int]
:param indices_dst: The indices of the dst subset to take into account
:type indices_dst: List[int]
:return: The set of free symbols
:rtype: Set[str]
"""
# Symbolic properties are in volume, and the two subsets
result = set()
result |= set(map(str, self.volume.free_symbols))
if self.src_subset:
result |= self.src_subset.get_free_symbols_by_indices(indices_src)
if self.dst_subset:
result |= self.dst_subset.get_free_symbols_by_indices(indices_dst)
return result

def get_stride(self, sdfg: 'dace.sdfg.SDFG', map: 'dace.sdfg.nodes.Map', dim: int = -1) -> 'dace.symbolic.SymExpr':
""" Returns the stride of the underlying memory when traversing a Map.
Expand Down
7 changes: 7 additions & 0 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,13 @@ def validate_state(state: 'dace.sdfg.SDFGState',
src_node = path[0].src
dst_node = path[-1].dst

# NestedSDFGs must connect to AccessNodes
if not e.data.is_empty():
if isinstance(src_node, nd.NestedSDFG) and not isinstance(dst_node, nd.AccessNode):
raise InvalidSDFGEdgeError("Nested SDFG source nodes must be AccessNodes", sdfg, state_id, eid)
if isinstance(dst_node, nd.NestedSDFG) and not isinstance(src_node, nd.AccessNode):
raise InvalidSDFGEdgeError("Nested SDFG destination nodes must be AccessNodes", sdfg, state_id, eid)

# Set up memlet-specific SDFG context
memlet_context = copy.copy(context)
for pe in path:
Expand Down
16 changes: 16 additions & 0 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,22 @@ def free_symbols(self) -> Set[str]:
result |= symbolic.symlist(d).keys()
return result

def get_free_symbols_by_indices(self, indices: List[int]) -> Set[str]:
"""
Get set of free symbols by only looking at the dimension given by the indices list
:param indices: The indices of the dimensions to look at
:type indices: List[int]
:return: The set of free symbols
:rtype: Set[str]
"""
result = set()
for i, dim in enumerate(self.ranges):
if i in indices:
for d in dim:
result |= symbolic.symlist(d).keys()
return result

def reorder(self, order):
""" Re-orders the dimensions in-place according to a permutation list.
Expand Down
15 changes: 12 additions & 3 deletions dace/transformation/dataflow/map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,16 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None)
graph.node_id(edge.dst),
edge.dst_conn,
)
# Add intermediate memory between subgraphs. If a scalar,
# uses direct connection. If an array, adds a transient node
if edge.data.subset.num_elements() == 1:
# Add intermediate memory between subgraphs.
# If a scalar, uses direct connection. If an array, adds a transient node.
# NOTE: If any of the src/dst nodes is a nested SDFG, treat it as an array.
is_scalar = edge.data.subset.num_elements() == 1
accesses = (
[graph.memlet_path(e1)[0].src for e0 in graph.in_edges(access_node) for e1 in graph.memlet_tree(e0)] +
[graph.memlet_path(e1)[-1].dst for e0 in graph.out_edges(access_node) for e1 in graph.memlet_tree(e0)])
if any(isinstance(a, nodes.NestedSDFG) for a in accesses):
is_scalar = False
if is_scalar:
local_name, _ = sdfg.add_scalar(
local_name,
dtype=access_node.desc(graph).dtype,
Expand Down Expand Up @@ -520,5 +527,7 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None)
# Modify data and memlets on all surrounding edges to match array
for neighbor in graph.all_edges(local_node):
for e in graph.memlet_tree(neighbor):
if e.data.data == local_name:
continue
e.data.data = local_name
e.data.subset.offset(old_edge.data.subset, negative=True)
4 changes: 2 additions & 2 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def _check_cand(candidates, outer_edges):
continue

# Check w.r.t. loops
if len(nstate.ranges) > 0:
if nstate is not None and len(nstate.ranges) > 0:
# Re-annotate loop ranges, in case someone changed them
# TODO: Move out of here!
for ns in nsdfg.sdfg.states():
Expand All @@ -1022,7 +1022,7 @@ def _check_cand(candidates, outer_edges):

# If there are any symbols here that are not defined
# in "defined_symbols"
missing_symbols = (memlet.free_symbols - set(nsdfg.symbol_mapping.keys()))
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - set(nsdfg.symbol_mapping.keys()))
if missing_symbols:
ignore.add(cname)
continue
Expand Down
14 changes: 14 additions & 0 deletions tests/python_frontend/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ def if_return_chain(i: dace.int64):
assert if_return_chain(15)[0] == 4


def test_if_test_call():

@dace.program
def if_test_call(a, b):
if bool(a):
return a
else:
return b

assert if_test_call(0, 2)[0] == if_test_call.f(0, 2)
assert if_test_call(1, 2)[0] == if_test_call.f(1, 2)


if __name__ == "__main__":
test_simple_if()
test_call_if()
Expand All @@ -169,3 +182,4 @@ def if_return_chain(i: dace.int64):
test_call_while()
test_if_return_both()
test_if_return_chain()
test_if_test_call()
22 changes: 22 additions & 0 deletions tests/python_frontend/indirections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,27 @@ def test_spmv():
assert (np.allclose(y, ref))


def test_indirection_size_1():

def compute_index(scal: dc.int32[5]):
result = 0
with dace.tasklet:
s << scal
r >> result
r = s[1] + 1 - 1
return result

@dc.program
def tester(a: dc.float64[1, 2, 3], scal: dc.int32[5]):
ind = compute_index(scal)
a[0, ind, 0] = 1

arr = np.random.rand(1, 2, 3)
scal = np.array([1, 1, 1, 1, 1], dtype=np.int32)
tester(arr, scal)
assert arr[0, 1, 0] == 1


if __name__ == "__main__":
test_indirection_scalar()
test_indirection_scalar_assign()
Expand All @@ -412,3 +433,4 @@ def test_spmv():
test_indirection_array_nested()
test_indirection_array_nested_nsdfg()
test_spmv()
test_indirection_size_1()
61 changes: 61 additions & 0 deletions tests/transformations/mapfusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,71 @@ def inner_product(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]):
assert np.allclose(val[0], ref)


def test_fusion_with_nested_sdfg_0():

@dace.program
def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]):
tmp = np.empty([10], dtype=np.int32)
for i in dace.map[0:10]:
if C[i] < 0:
tmp[i] = B[i] - A[i]
else:
tmp[i] = B[i] + A[i]
for i in dace.map[0:10]:
A[i] = tmp[i] * 2

sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True)
sdfg.apply_transformations(MapFusion)

for sd in sdfg.all_sdfgs_recursive():
if sd is not sdfg:
node = sd.parent_nsdfg_node
state = sd.parent
for e0 in state.out_edges(node):
for e1 in state.memlet_tree(e0):
dst = state.memlet_path(e1)[-1].dst
assert isinstance(dst, dace.nodes.AccessNode)


def test_fusion_with_nested_sdfg_1():

@dace.program
def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]):
tmp = np.empty([10], dtype=np.int32)
for i in dace.map[0:10]:
with dace.tasklet:
a << A[i]
b << B[i]
t >> tmp[i]
t = b - a
for i in dace.map[0:10]:
if C[i] < 0:
A[i] = tmp[i] * 2
else:
B[i] = tmp[i] * 2

sdfg = fusion_with_nested_sdfg_1.to_sdfg(simplify=True)
sdfg.apply_transformations(MapFusion)

if len(sdfg.states()) != 1:
return

for sd in sdfg.all_sdfgs_recursive():
if sd is not sdfg:
node = sd.parent_nsdfg_node
state = sd.parent
for e0 in state.in_edges(node):
for e1 in state.memlet_tree(e0):
src = state.memlet_path(e1)[0].src
assert isinstance(src, dace.nodes.AccessNode)


if __name__ == '__main__':
test_fusion_simple()
test_multiple_fusions()
test_fusion_chain()
test_fusion_with_transient()
test_fusion_with_inverted_indices()
test_fusion_with_empty_memlet()
test_fusion_with_nested_sdfg_0()
test_fusion_with_nested_sdfg_1()
32 changes: 32 additions & 0 deletions tests/transformations/refine_nested_access_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,38 @@ def inner_sdfg(A: dace.int32[5, 5], B: dace.int32[5, 5], select: dace.bool[5, 5]
assert np.allclose(B, lower.T + lower - diag)


def test_free_sybmols_only_by_indices():
i = dace.symbol('i')
idx_a = dace.symbol('idx_a')
idx_b = dace.symbol('idx_b')
sdfg = dace.SDFG('refine_free_symbols_only_by_indices')
sdfg.add_array('A', [5], dace.int32)
sdfg.add_array('B', [5, 5], dace.int32)

@dace.program
def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int):
if A[i] > 0.5:
B[i, idx_a] = 1
else:
B[i, idx_b] = 0

state = sdfg.add_state()
A = state.add_access('A')
B = state.add_access('B')
map_entry, map_exit = state.add_map('map', dict(i='0:5'))
nsdfg = state.add_nested_sdfg(inner_sdfg.to_sdfg(simplify=False), sdfg, {'A'}, {'B'}, {'i': 'i'})
state.add_memlet_path(A, map_entry, nsdfg, dst_conn='A', memlet=dace.Memlet.from_array('A', sdfg.arrays['A']))
state.add_memlet_path(nsdfg, map_exit, B, src_conn='B', memlet=dace.Memlet.from_array('B', sdfg.arrays['B']))

num = sdfg.apply_transformations_repeated(RefineNestedAccess)
assert num == 1

assert len(state.in_edges(map_exit)) == 1
edge = state.in_edges(map_exit)[0]
assert edge.data.subset == dace.subsets.Range([(i, i, 1), (0, 4, 1)])


if __name__ == '__main__':
test_refine_dataflow()
test_refine_interstate()
test_free_sybmols_only_by_indices()

0 comments on commit bf9790a

Please sign in to comment.