Skip to content

Commit

Permalink
Merge pull request #1307 from spcl/fix-subgraph-fusion-intermediate-n…
Browse files Browse the repository at this point in the history
…ode-removal

Fixes Intermediate Node Removal in SubgraphFusion
  • Loading branch information
alexnick83 authored Jul 15, 2023
2 parents f6a19de + bf9790a commit 531b0ae
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 72 deletions.
16 changes: 8 additions & 8 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,22 +841,22 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg
return out_edge
if not src_is_data and not dst_is_data:
return None

# Check if there is a 'views' connector
if in_edge.dst_conn and in_edge.dst_conn == 'views':
return in_edge
if out_edge.src_conn and out_edge.src_conn == 'views':
return out_edge

# If both sides lead to access nodes, if one memlet's data points to the
# view it cannot point to the viewed node.
# TODO: This sounds arbitrary and is not well communicated to the frontends. Revisit in the future.
# If both sides lead to access nodes, if one memlet's data points to the view it cannot point to the viewed node.
if in_edge.data.data == view.data and out_edge.data.data != view.data:
return out_edge
if in_edge.data.data != view.data and out_edge.data.data == view.data:
return in_edge
if in_edge.data.data == view.data and out_edge.data.data == view.data:
return None

# Check if there is a 'views' connector
if in_edge.dst_conn and in_edge.dst_conn == 'views':
return in_edge
if out_edge.src_conn and out_edge.src_conn == 'views':
return out_edge

# If both memlets' data are the respective access nodes, the access
# node at the highest scope is the one that is viewed.
if isinstance(in_edge.src, nd.EntryNode):
Expand Down
181 changes: 119 additions & 62 deletions dace/transformation/subgraph/subgraph_fusion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
""" This module contains classes that implement subgraph fusion
"""
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" This module contains classes that implement subgraph fusion. """
import dace
import networkx as nx

from dace import dtypes, registry, symbolic, subsets, data
from dace.sdfg import nodes, utils, replace, SDFG, scope_contains_scope
Expand Down Expand Up @@ -1144,78 +1144,135 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s

# Try to remove intermediate nodes that are not contained in the subgraph
# by reconnecting their adjacent edges to nodes outside the subgraph.
for node in intermediate_nodes:
# NOTE: Currently limited to cases where there is a single source and sink
# if there are multiple intermediate accesses for the same data.

# Sort intermediate nodes by data name
intermediate_data = dict()
for acc in intermediate_nodes:
if acc.data in intermediate_data:
intermediate_data[acc.data].append(acc)
else:
intermediate_data[acc.data] = [acc]

filtered_intermediate_data = dict()
intermediate_sources = dict()
intermediate_sinks = dict()
for dname, accesses in intermediate_data.items():

sources = set(accesses)
sinks = set(accesses)

# Find sinks
for acc0 in accesses:
for acc1 in set(sinks):
if acc0 is acc1:
continue
if nx.has_path(graph.nx, acc0, acc1):
sinks.remove(acc0)
break
if len(sinks) > 1:
continue
# Find sources
for acc0 in accesses:
for acc1 in set(sources):
if acc0 is acc1:
continue
if nx.has_path(graph.nx, acc1, acc0):
sources.remove(acc0)
break
if len(sources) > 1:
continue

filtered_intermediate_data[dname] = accesses
intermediate_sources[dname] = sources
intermediate_sinks[dname] = sinks

edges_to_remove = set()

for dname, accesses in filtered_intermediate_data.items():

# Checking if data are contained in the subgraph
if not subgraph_contains_data[node.data]:
if not subgraph_contains_data[dname]:
# Find existing outer access nodes
inode, onode = None, None
for e in graph.in_edges(global_map_entry):
if isinstance(e.src, nodes.AccessNode) and node.data == e.src.data:
if isinstance(e.src, nodes.AccessNode) and dname == e.src.data:
inode = e.src
break
for e in graph.out_edges(global_map_exit):
if isinstance(e.dst, nodes.AccessNode) and node.data == e.dst.data:
if isinstance(e.dst, nodes.AccessNode) and dname == e.dst.data:
onode = e.dst
break

to_remove = set()

# Compute the union of all incoming subsets.
# TODO: Do we expect this operation to ever fail?
in_subset: subsets.Subset = None
for ie in graph.in_edges(node):
if in_subset:
in_subset = subsets.union(in_subset, ie.data.dst_subset)
else:
in_subset = ie.data.dst_subset
first_subset: subsets.Subset = None
for acc in accesses:
for ie in graph.in_edges(acc):
if in_subset:
in_subset = subsets.union(in_subset, ie.data.dst_subset)
else:
in_subset = ie.data.dst_subset
first_subset = ie.data.dst_subset

# Create transient data corresponding to the union of the incoming subsets.
desc = sdfg.arrays[node.data]
name, new_desc = sdfg.add_temp_transient(in_subset.bounding_box_size(), desc.dtype, desc.storage)
new_node = graph.add_access(name)

# Reconnect incoming edges through the transient data.
for ie in graph.in_edges(node):
mem = Memlet(data=name,
subset=ie.data.dst_subset.offset_new(in_subset, True),
other_subset=ie.data.src_subset)
new_edge = graph.add_edge(ie.src, ie.src_conn, new_node, None, mem)
to_remove.add(ie)
# Update memlet paths.
for e in graph.memlet_path(new_edge):
if e.data.data == node.data:
e.data.data = name
e.data.dst_subset.offset(in_subset, True)

# Reconnect outgoing edges through the transient data.
for oe in graph.out_edges(node):
if in_subset.covers(oe.data.src_subset):
mem = Memlet(data=name,
subset=oe.data.src_subset.offset_new(in_subset, True),
other_subset=oe.data.dst_subset)
new_edge = graph.add_edge(new_node, None, oe.dst, oe.dst_conn, mem)
desc = sdfg.arrays[dname]
new_name, _ = sdfg.add_temp_transient(in_subset.bounding_box_size(), desc.dtype, desc.storage)

for acc in accesses:

acc.data = new_name

# Reconnect incoming edges through the transient data.
for ie in graph.in_edges(acc):
mem = Memlet(data=new_name,
subset=ie.data.dst_subset.offset_new(in_subset, True),
other_subset=ie.data.src_subset)
# new_edge = graph.add_edge(ie.src, ie.src_conn, new_node, None, mem)
ie.data = mem
# Update memlet paths.
for e in graph.memlet_path(new_edge):
if e.data.data == node.data:
e.data.data = name
e.data.src_subset.offset(in_subset, True)
else:
# If the outgoing subset is not covered by the transient data, connect to the outer input node.
if not inode:
inode = graph.add_access(node.data)
graph.add_memlet_path(inode, global_map_entry, oe.dst, memlet=oe.data, dst_conn=oe.dst_conn)
to_remove.add(oe)

# Connect transient data to the outer output node.
if not onode:
onode = graph.add_access(node.data)
graph.add_memlet_path(new_node,
global_map_exit,
onode,
memlet=Memlet(data=node.data, subset=in_subset),
src_conn=None)

for e in to_remove:
graph.remove_edge(e)
if to_remove:
graph.remove_node(node)
for e in graph.memlet_path(ie):
if e.data.data == dname:
e.data.data = new_name
e.data.dst_subset.offset(in_subset, True)

# Reconnect outgoing edges through the transient data.
for oe in graph.out_edges(acc):
if in_subset.covers(oe.data.src_subset):
mem = Memlet(data=new_name,
subset=oe.data.src_subset.offset_new(in_subset, True),
other_subset=oe.data.dst_subset)
# new_edge = graph.add_edge(new_node, None, oe.dst, oe.dst_conn, mem)
oe.data = mem
# Update memlet paths.
for e in graph.memlet_path(oe):
if e.data.data == dname:
e.data.data = new_name
e.data.src_subset.offset(in_subset, True)
else:
# NOTE: For debugging purposes
intersect = subsets.intersects(in_subset, oe.data.src_subset)
if intersect is None:
warnings.warn(f'{dname}[{in_subset}] may intersect with {dname}[{oe.data.src_subset}]')
elif intersect:
raise ValueError(f'{dname}[{in_subset}] intersects with {dname}[{oe.data.src_subset}]')
# If the outgoing subset is not covered by the transient data, connect to the outer input node.
if not inode:
inode = graph.add_access(dname)
graph.add_memlet_path(inode, global_map_entry, oe.dst, memlet=oe.data, dst_conn=oe.dst_conn)
edges_to_remove.add(oe)

# Connect transient data to the outer output node.
if acc in intermediate_sinks[dname]:
if not onode:
onode = graph.add_access(dname)
graph.add_memlet_path(acc,
global_map_exit,
onode,
memlet=Memlet(data=dname, subset=in_subset),
src_conn=None)

for e in edges_to_remove:
graph.remove_edge(e)
60 changes: 58 additions & 2 deletions tests/transformations/subgraph_fusion/intermediate_mimo_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import copy
import dace
from dace.sdfg import nodes
from dace.sdfg.graph import SubgraphView
from dace.transformation.dataflow import MapFission
from dace.transformation.helpers import nest_state_subgraph
import numpy as np
import unittest
Expand Down Expand Up @@ -101,5 +100,62 @@ def test_mimo():
_test_quantitatively(sdfg)


def test_single_data_multiple_intermediate_accesses():
    """Check that MultiExpansion + SubgraphFusion preserve results for ``sdmi_accesses``.

    The program below touches ``ZSINKSUM`` in all three maps (initialized in
    the first, updated in the second, read in the third). The SDFG is executed
    once before and once after the two transformations are applied to the
    whole state, and the output arrays of both runs are compared.
    """

    # NOTE: The body of the DaCe program is kept statement-for-statement, since
    # it is parsed by the DaCe frontend and determines the generated SDFG.
    @dace.program
    def sdmi_accesses(ZSOLQA: dace.float64[1, 5, 5], ZEPSEC: dace.float64, ZQX: dace.float64[1, 137, 5],
                      LLINDEX3: dace.bool[1, 5, 5], ZRATIO: dace.float64[1, 5], ZSINKSUM: dace.float64[1, 5]):

        for i in dace.map[0:5]:
            ZSINKSUM[0, i] = 0.0
            for j in dace.map[0:5]:
                LLINDEX3[0, j, i] = False

        for i in dace.map[0:5]:
            for k in range(5):
                ZSINKSUM[0, i] = ZSINKSUM[0, i] - ZSOLQA[0, 0, k]

        for i in dace.map[0:5]:
            t0 = max(ZEPSEC, ZQX[0, 0, i])
            t1 = max(t0, ZSINKSUM[0, i])
            ZRATIO[0, i] = t0 / t1

    sdfg = sdmi_accesses.to_sdfg(simplify=True)
    # The transformations below are applied to one state only.
    assert len(sdfg.states()) == 1

    # Draw all random data from a single seeded generator. The draw order
    # (ZSOLQA, ZEPSEC, ZQX, LLINDEX3, ZRATIO, ZSINKSUM) fixes the values.
    rng = np.random.default_rng(42)
    read_only_args = {
        'ZSOLQA': rng.random((1, 5, 5)),
        'ZEPSEC': rng.random(),
        'ZQX': rng.random((1, 137, 5)),
    }
    reference_outputs = {
        'LLINDEX3': rng.random((1, 5, 5)) > 0.5,
        'ZRATIO': rng.random((1, 5)),
        'ZSINKSUM': rng.random((1, 5)),
    }
    # Independent copies of the initial output buffers for the second run.
    fused_outputs = {name: array.copy() for name, array in reference_outputs.items()}

    # Reference run on the untransformed SDFG.
    sdfg(**read_only_args, **reference_outputs)

    state = sdfg.states()[0]
    subgraph = SubgraphView(state, list(state.nodes()))

    # Apply MultiExpansion first, then SubgraphFusion, each on the same subgraph view.
    for transformation_cls in (MultiExpansion, SubgraphFusion):
        xform = transformation_cls()
        xform.setup_match(subgraph)
        assert xform.can_be_applied(sdfg, subgraph) == True
        xform.apply(sdfg)

    # Run the transformed SDFG and compare against the reference outputs.
    sdfg(**read_only_args, **fused_outputs)

    for name, expected in reference_outputs.items():
        assert np.allclose(expected, fused_outputs[name])


# Allow running the tests in this file directly as a script.
if __name__ == '__main__':
    test_mimo()
    test_single_data_multiple_intermediate_accesses()

0 comments on commit 531b0ae

Please sign in to comment.