From 6de81ca9d199e597e5a22d51527c5857d516e1e0 Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Thu, 5 Sep 2024 11:52:43 +0200 Subject: [PATCH 1/3] add state replication and if raising transformations --- dace/transformation/interstate/__init__.py | 2 + dace/transformation/interstate/if_raising.py | 67 ++++++++++++++++ .../interstate/state_replication.py | 78 +++++++++++++++++++ 3 files changed, 147 insertions(+) create mode 100644 dace/transformation/interstate/if_raising.py create mode 100644 dace/transformation/interstate/state_replication.py diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index b8bcc716e6..e8bc42fdf2 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -16,3 +16,5 @@ from .trivial_loop_elimination import TrivialLoopElimination from .multistate_inline import InlineMultistateSDFG from .move_assignment_outside_if import MoveAssignmentOutsideIf +from .state_replication import StateReplication +from .if_raising import IfRaising diff --git a/dace/transformation/interstate/if_raising.py b/dace/transformation/interstate/if_raising.py new file mode 100644 index 0000000000..e5f4128847 --- /dev/null +++ b/dace/transformation/interstate/if_raising.py @@ -0,0 +1,67 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" If raising transformation """ + +from dace import data as dt, sdfg as sd +from dace.sdfg import InterstateEdge +from dace.sdfg import utils as sdutil +from dace.sdfg.state import SDFGState +from dace.transformation import transformation +from dace.properties import make_properties + +@make_properties +class IfRaising(transformation.MultiStateTransformation): + """ + Duplicates an if guard and anticipates the evaluation of the condition + """ + + if_guard = transformation.PatternNode(sd.SDFGState) + + @staticmethod + def annotates_memlets(): + return False + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.if_guard)] + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + if_guard: SDFGState = self.if_guard + + out_edges = graph.out_edges(if_guard) + + if len(out_edges) != 2: + return False + + if if_guard.is_empty(): + return False + + # check that condition does not depend on computations in the state + condition_symbols = out_edges[0].data.condition.get_free_symbols() + _, wset = if_guard.read_and_write_sets() + if len(condition_symbols.intersection(wset)) != 0: + return False + + return True + + + def apply(self, _, sdfg: sd.SDFG): + if_guard: SDFGState = self.if_guard + + raised_if_guard = sdfg.add_state('raised_if_guard') + sdutil.change_edge_dest(sdfg, if_guard, raised_if_guard) + + replica = sd.SDFGState.from_json(if_guard.to_json(), context={'sdfg': sdfg}) + all_block_names = set([s.label for s in sdfg.nodes()]) + replica.label = dt.find_new_name(replica.label, all_block_names) + sdfg.add_node(replica) + + # move conditional edges up + if_branch, else_branch = sdfg.out_edges(if_guard) + sdfg.remove_edge(if_branch) + sdfg.remove_edge(else_branch) + + sdfg.add_edge(if_guard, if_branch.dst, InterstateEdge(assignments=if_branch.data.assignments)) + sdfg.add_edge(replica, else_branch.dst, InterstateEdge(assignments=else_branch.data.assignments)) + + sdfg.add_edge(raised_if_guard, if_guard, InterstateEdge(condition=if_branch.data.condition)) + sdfg.add_edge(raised_if_guard, replica, InterstateEdge(condition=else_branch.data.condition)) diff --git a/dace/transformation/interstate/state_replication.py b/dace/transformation/interstate/state_replication.py new file mode 100644 index 0000000000..f1f43f3424 --- /dev/null +++ b/dace/transformation/interstate/state_replication.py @@ -0,0 +1,78 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" State replication transformation """ + +from dace import data as dt, sdfg as sd +from dace.sdfg import utils as sdutil +from dace.sdfg.state import SDFGState +from dace.transformation import transformation +from copy import deepcopy +from dace.transformation.interstate.loop_detection import DetectLoop +from dace.properties import make_properties + +@make_properties +class StateReplication(transformation.MultiStateTransformation): + """ + Creates a copy of a state for each of its incoming edge. Then, redirects every edge to a different copy. + This results in states with only one incoming edge. + """ + + target_state = transformation.PatternNode(sd.SDFGState) + + @staticmethod + def annotates_memlets(): + return True + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.target_state)] + + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + target_state: SDFGState = self.target_state + + out_edges = graph.out_edges(target_state) + in_edges = graph.in_edges(target_state) + + if len(in_edges) < 2: + return False + + # avoid useless replications + if target_state.is_empty() and len(out_edges) < 2: + return False + + # make sure this is not a loop guard + if len(out_edges) == 2: + detect = DetectLoop() + detect.loop_guard = target_state + detect.loop_begin = out_edges[0].dst + detect.exit_state = out_edges[1].dst + if detect.can_be_applied(graph, 0, sdfg): + return False + detect.exit_state = out_edges[0].dst + detect.loop_begin = out_edges[1].dst + if detect.can_be_applied(graph, 0, sdfg): + return False + + return True + + def apply(self, _, sdfg: sd.SDFG): + target_state: SDFGState = self.target_state + + if len(sdfg.out_edges(target_state)) == 0: + sdfg.add_state_after(target_state) + + state_names = set(s.label for s in sdfg.nodes()) + + root_blueprint = target_state.to_json() + for e in sdfg.in_edges(target_state)[1:]: + state_copy = sd.SDFGState.from_json(root_blueprint, context={'sdfg': sdfg}) + state_copy.label = dt.find_new_name(state_copy.label, state_names) + state_names.add(state_copy.label) + sdfg.add_node(state_copy) + + sdfg.remove_edge(e) + sdfg.add_edge(e.src, state_copy, e.data) + + # connect out edges + for oe in sdfg.out_edges(target_state): + sdfg.add_edge(state_copy, oe.dst, deepcopy(oe.data)) From 3495aeb178e6d9b17e17007a598161c4a18896a6 Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Thu, 5 Sep 2024 15:37:50 +0200 Subject: [PATCH 2/3] add some tests --- .../raise_and_duplicate_test.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/transformations/raise_and_duplicate_test.py diff --git a/tests/transformations/raise_and_duplicate_test.py b/tests/transformations/raise_and_duplicate_test.py new file mode 100644 index 0000000000..cc7769ee3a --- /dev/null +++ b/tests/transformations/raise_and_duplicate_test.py @@ -0,0 +1,62 @@ +import dace +from dace.sdfg import nodes +from dace.transformation.interstate import IfRaising, StateReplication +from dace.transformation.dataflow import OTFMapFusion +import numpy as np + + +def test_raise_and_duplicate_and_fusions(): + N = dace.symbol('N', dace.int64) + @dace.program + def program(flag: dace.bool, in_arr: dace.float64[N], arr: dace.float64[N]): + tmp1 = np.empty_like(arr) + tmp2 = np.empty_like(arr) + for i in dace.map[0:N]: + tmp1[i] = in_arr[i] + if flag: + for i in dace.map[0:N]: + tmp2[i] = tmp1[i] + else: + for i in dace.map[0:N]: + tmp2[i] = tmp1[i] + for i in dace.map[0:N]: + arr[i] = tmp2[i] + + sdfg = program.to_sdfg() + sdfg.apply_transformations([IfRaising, StateReplication]) + sdfg.simplify() + sdfg.apply_transformations_repeated([OTFMapFusion]) + + assert len(sdfg.nodes()) == 4 + assert sdfg.start_state.is_empty() + + entries = 0 + for state in sdfg.nodes(): + for node in state.nodes(): + if isinstance(node, nodes.MapEntry): + entries += 1 + + assert entries == 2 + + +def test_if_raise_dependency(): + N = dace.symbol('N', dace.int64) + @dace.program + def program(arr: dace.float64[N]): + flag = np.sum(arr) + if flag: + return 1 + else: + return 0 + + sdfg = program.to_sdfg() + + transform = IfRaising() + transform.if_guard = sdfg.start_state + + assert not transform.can_be_applied(sdfg, 0, sdfg) + + +if __name__ == '__main__': + test_raise_and_duplicate_and_fusions() + test_if_raise_dependency() From 355af5b5e3860544e297aac328c4aca371b37ed6 Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Fri, 13 Sep 2024 15:04:07 +0200 Subject: [PATCH 3/3] fix to_sdfg in test --- tests/transformations/raise_and_duplicate_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformations/raise_and_duplicate_test.py b/tests/transformations/raise_and_duplicate_test.py index cc7769ee3a..1306f01df0 100644 --- a/tests/transformations/raise_and_duplicate_test.py +++ b/tests/transformations/raise_and_duplicate_test.py @@ -22,7 +22,7 @@ def program(flag: dace.bool, in_arr: dace.float64[N], arr: dace.float64[N]): for i in dace.map[0:N]: arr[i] = tmp2[i] - sdfg = program.to_sdfg() + sdfg = program.to_sdfg(simplify=True) sdfg.apply_transformations([IfRaising, StateReplication]) sdfg.simplify() sdfg.apply_transformations_repeated([OTFMapFusion]) @@ -49,7 +49,7 @@ def program(arr: dace.float64[N]): else: return 0 - sdfg = program.to_sdfg() + sdfg = program.to_sdfg(simplify=True) transform = IfRaising() transform.if_guard = sdfg.start_state