-
Notifications
You must be signed in to change notification settings - Fork 127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add state replication and if raising transformations #1639
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. | ||
""" If raising transformation """ | ||
|
||
from dace import data as dt, sdfg as sd | ||
from dace.sdfg import InterstateEdge | ||
from dace.sdfg import utils as sdutil | ||
from dace.sdfg.state import SDFGState | ||
from dace.transformation import transformation | ||
from dace.properties import make_properties | ||
|
||
@make_properties | ||
class IfRaising(transformation.MultiStateTransformation): | ||
""" | ||
Duplicates an if guard and anticipates the evaluation of the condition | ||
""" | ||
|
||
if_guard = transformation.PatternNode(sd.SDFGState) | ||
|
||
@staticmethod | ||
def annotates_memlets(): | ||
return False | ||
|
||
@classmethod | ||
def expressions(cls): | ||
return [sdutil.node_path_graph(cls.if_guard)] | ||
|
||
def can_be_applied(self, graph, expr_index, sdfg, permissive=False): | ||
if_guard: SDFGState = self.if_guard | ||
|
||
out_edges = graph.out_edges(if_guard) | ||
|
||
if len(out_edges) != 2: | ||
return False | ||
|
||
if if_guard.is_empty(): | ||
return False | ||
|
||
# check that condition does not depend on computations in the state | ||
condition_symbols = out_edges[0].data.condition.get_free_symbols() | ||
_, wset = if_guard.read_and_write_sets() | ||
if len(condition_symbols.intersection(wset)) != 0: | ||
return False | ||
|
||
return True | ||
|
||
|
||
def apply(self, _, sdfg: sd.SDFG): | ||
if_guard: SDFGState = self.if_guard | ||
|
||
raised_if_guard = sdfg.add_state('raised_if_guard') | ||
sdutil.change_edge_dest(sdfg, if_guard, raised_if_guard) | ||
|
||
replica = sd.SDFGState.from_json(if_guard.to_json(), context={'sdfg': sdfg}) | ||
all_block_names = set([s.label for s in sdfg.nodes()]) | ||
replica.label = dt.find_new_name(replica.label, all_block_names) | ||
sdfg.add_node(replica) | ||
|
||
# move conditional edges up | ||
if_branch, else_branch = sdfg.out_edges(if_guard) | ||
sdfg.remove_edge(if_branch) | ||
sdfg.remove_edge(else_branch) | ||
|
||
sdfg.add_edge(if_guard, if_branch.dst, InterstateEdge(assignments=if_branch.data.assignments)) | ||
sdfg.add_edge(replica, else_branch.dst, InterstateEdge(assignments=else_branch.data.assignments)) | ||
|
||
sdfg.add_edge(raised_if_guard, if_guard, InterstateEdge(condition=if_branch.data.condition)) | ||
sdfg.add_edge(raised_if_guard, replica, InterstateEdge(condition=else_branch.data.condition)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. | ||
""" State replication transformation """ | ||
|
||
from dace import data as dt, sdfg as sd | ||
from dace.sdfg import utils as sdutil | ||
from dace.sdfg.state import SDFGState | ||
from dace.transformation import transformation | ||
from copy import deepcopy | ||
from dace.transformation.interstate.loop_detection import DetectLoop | ||
from dace.properties import make_properties | ||
|
||
@make_properties | ||
class StateReplication(transformation.MultiStateTransformation): | ||
""" | ||
Creates a copy of a state for each of its incoming edge. Then, redirects every edge to a different copy. | ||
This results in states with only one incoming edge. | ||
""" | ||
|
||
target_state = transformation.PatternNode(sd.SDFGState) | ||
|
||
@staticmethod | ||
def annotates_memlets(): | ||
return True | ||
|
||
@classmethod | ||
def expressions(cls): | ||
return [sdutil.node_path_graph(cls.target_state)] | ||
|
||
|
||
def can_be_applied(self, graph, expr_index, sdfg, permissive=False): | ||
target_state: SDFGState = self.target_state | ||
|
||
out_edges = graph.out_edges(target_state) | ||
in_edges = graph.in_edges(target_state) | ||
|
||
if len(in_edges) < 2: | ||
return False | ||
|
||
# avoid useless replications | ||
if target_state.is_empty() and len(out_edges) < 2: | ||
return False | ||
|
||
# make sure this is not a loop guard | ||
if len(out_edges) == 2: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Application on loops results in the addition of useless states (but the SDFG is still correct). This will not get rid of the loop, meaning that |
||
detect = DetectLoop() | ||
detect.loop_guard = target_state | ||
detect.loop_begin = out_edges[0].dst | ||
detect.exit_state = out_edges[1].dst | ||
if detect.can_be_applied(graph, 0, sdfg): | ||
return False | ||
detect.exit_state = out_edges[0].dst | ||
detect.loop_begin = out_edges[1].dst | ||
if detect.can_be_applied(graph, 0, sdfg): | ||
return False | ||
|
||
return True | ||
|
||
def apply(self, _, sdfg: sd.SDFG): | ||
target_state: SDFGState = self.target_state | ||
|
||
if len(sdfg.out_edges(target_state)) == 0: | ||
sdfg.add_state_after(target_state) | ||
|
||
state_names = set(s.label for s in sdfg.nodes()) | ||
|
||
root_blueprint = target_state.to_json() | ||
for e in sdfg.in_edges(target_state)[1:]: | ||
state_copy = sd.SDFGState.from_json(root_blueprint, context={'sdfg': sdfg}) | ||
state_copy.label = dt.find_new_name(state_copy.label, state_names) | ||
state_names.add(state_copy.label) | ||
sdfg.add_node(state_copy) | ||
|
||
sdfg.remove_edge(e) | ||
sdfg.add_edge(e.src, state_copy, e.data) | ||
|
||
# connect out edges | ||
for oe in sdfg.out_edges(target_state): | ||
sdfg.add_edge(state_copy, oe.dst, deepcopy(oe.data)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import dace | ||
from dace.sdfg import nodes | ||
from dace.transformation.interstate import IfRaising, StateReplication | ||
from dace.transformation.dataflow import OTFMapFusion | ||
import numpy as np | ||
|
||
|
||
def test_raise_and_duplicate_and_fusions(): | ||
N = dace.symbol('N', dace.int64) | ||
@dace.program | ||
def program(flag: dace.bool, in_arr: dace.float64[N], arr: dace.float64[N]): | ||
tmp1 = np.empty_like(arr) | ||
tmp2 = np.empty_like(arr) | ||
for i in dace.map[0:N]: | ||
tmp1[i] = in_arr[i] | ||
if flag: | ||
for i in dace.map[0:N]: | ||
tmp2[i] = tmp1[i] | ||
else: | ||
for i in dace.map[0:N]: | ||
tmp2[i] = tmp1[i] | ||
for i in dace.map[0:N]: | ||
arr[i] = tmp2[i] | ||
|
||
sdfg = program.to_sdfg(simplify=True) | ||
sdfg.apply_transformations([IfRaising, StateReplication]) | ||
sdfg.simplify() | ||
sdfg.apply_transformations_repeated([OTFMapFusion]) | ||
|
||
assert len(sdfg.nodes()) == 4 | ||
assert sdfg.start_state.is_empty() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the tests fails here but it works on my machine. Is there something specific to the setup in the pipeline? |
||
|
||
entries = 0 | ||
for state in sdfg.nodes(): | ||
for node in state.nodes(): | ||
if isinstance(node, nodes.MapEntry): | ||
entries += 1 | ||
|
||
assert entries == 2 | ||
|
||
|
||
def test_if_raise_dependency(): | ||
N = dace.symbol('N', dace.int64) | ||
@dace.program | ||
def program(arr: dace.float64[N]): | ||
flag = np.sum(arr) | ||
if flag: | ||
return 1 | ||
else: | ||
return 0 | ||
|
||
sdfg = program.to_sdfg(simplify=True) | ||
|
||
transform = IfRaising() | ||
transform.if_guard = sdfg.start_state | ||
|
||
assert not transform.can_be_applied(sdfg, 0, sdfg) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_raise_and_duplicate_and_fusions() | ||
test_if_raise_dependency() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might be missing something I do not know about here. Also I don't know if this can result in spurious failures