Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add state replication and if raising transformations #1639

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dace/transformation/interstate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@
from .trivial_loop_elimination import TrivialLoopElimination
from .multistate_inline import InlineMultistateSDFG
from .move_assignment_outside_if import MoveAssignmentOutsideIf
from .state_replication import StateReplication
from .if_raising import IfRaising
67 changes: 67 additions & 0 deletions dace/transformation/interstate/if_raising.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" If raising transformation """

from dace import data as dt, sdfg as sd
from dace.sdfg import InterstateEdge
from dace.sdfg import utils as sdutil
from dace.sdfg.state import SDFGState
from dace.transformation import transformation
from dace.properties import make_properties

@make_properties
class IfRaising(transformation.MultiStateTransformation):
"""
Duplicates an if guard and anticipates the evaluation of the condition
"""

if_guard = transformation.PatternNode(sd.SDFGState)

@staticmethod
def annotates_memlets():
return False

@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.if_guard)]

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if_guard: SDFGState = self.if_guard

out_edges = graph.out_edges(if_guard)

if len(out_edges) != 2:
return False

if if_guard.is_empty():
return False

# check that condition does not depend on computations in the state
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be missing something I do not know about here. Also I don't know if this can result in spurious failures

condition_symbols = out_edges[0].data.condition.get_free_symbols()
_, wset = if_guard.read_and_write_sets()
if len(condition_symbols.intersection(wset)) != 0:
return False

return True


def apply(self, _, sdfg: sd.SDFG):
if_guard: SDFGState = self.if_guard

raised_if_guard = sdfg.add_state('raised_if_guard')
sdutil.change_edge_dest(sdfg, if_guard, raised_if_guard)

replica = sd.SDFGState.from_json(if_guard.to_json(), context={'sdfg': sdfg})
all_block_names = set([s.label for s in sdfg.nodes()])
replica.label = dt.find_new_name(replica.label, all_block_names)
sdfg.add_node(replica)

# move conditional edges up
if_branch, else_branch = sdfg.out_edges(if_guard)
sdfg.remove_edge(if_branch)
sdfg.remove_edge(else_branch)

sdfg.add_edge(if_guard, if_branch.dst, InterstateEdge(assignments=if_branch.data.assignments))
sdfg.add_edge(replica, else_branch.dst, InterstateEdge(assignments=else_branch.data.assignments))

sdfg.add_edge(raised_if_guard, if_guard, InterstateEdge(condition=if_branch.data.condition))
sdfg.add_edge(raised_if_guard, replica, InterstateEdge(condition=else_branch.data.condition))
78 changes: 78 additions & 0 deletions dace/transformation/interstate/state_replication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" State replication transformation """

from dace import data as dt, sdfg as sd
from dace.sdfg import utils as sdutil
from dace.sdfg.state import SDFGState
from dace.transformation import transformation
from copy import deepcopy
from dace.transformation.interstate.loop_detection import DetectLoop
from dace.properties import make_properties

@make_properties
class StateReplication(transformation.MultiStateTransformation):
"""
Creates a copy of a state for each of its incoming edge. Then, redirects every edge to a different copy.
This results in states with only one incoming edge.
"""

target_state = transformation.PatternNode(sd.SDFGState)

@staticmethod
def annotates_memlets():
return True

@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.target_state)]


def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
target_state: SDFGState = self.target_state

out_edges = graph.out_edges(target_state)
in_edges = graph.in_edges(target_state)

if len(in_edges) < 2:
return False

# avoid useless replications
if target_state.is_empty() and len(out_edges) < 2:
return False

# make sure this is not a loop guard
if len(out_edges) == 2:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Application on loops results in the addition of useless states (but the SDFG is still correct). This will not get rid of the loop, meaning that apply_transformations_repeated will never halt.

detect = DetectLoop()
detect.loop_guard = target_state
detect.loop_begin = out_edges[0].dst
detect.exit_state = out_edges[1].dst
if detect.can_be_applied(graph, 0, sdfg):
return False
detect.exit_state = out_edges[0].dst
detect.loop_begin = out_edges[1].dst
if detect.can_be_applied(graph, 0, sdfg):
return False

return True

def apply(self, _, sdfg: sd.SDFG):
target_state: SDFGState = self.target_state

if len(sdfg.out_edges(target_state)) == 0:
sdfg.add_state_after(target_state)

state_names = set(s.label for s in sdfg.nodes())

root_blueprint = target_state.to_json()
for e in sdfg.in_edges(target_state)[1:]:
state_copy = sd.SDFGState.from_json(root_blueprint, context={'sdfg': sdfg})
state_copy.label = dt.find_new_name(state_copy.label, state_names)
state_names.add(state_copy.label)
sdfg.add_node(state_copy)

sdfg.remove_edge(e)
sdfg.add_edge(e.src, state_copy, e.data)

# connect out edges
for oe in sdfg.out_edges(target_state):
sdfg.add_edge(state_copy, oe.dst, deepcopy(oe.data))
62 changes: 62 additions & 0 deletions tests/transformations/raise_and_duplicate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import dace
from dace.sdfg import nodes
from dace.transformation.interstate import IfRaising, StateReplication
from dace.transformation.dataflow import OTFMapFusion
import numpy as np


def test_raise_and_duplicate_and_fusions():
N = dace.symbol('N', dace.int64)
@dace.program
def program(flag: dace.bool, in_arr: dace.float64[N], arr: dace.float64[N]):
tmp1 = np.empty_like(arr)
tmp2 = np.empty_like(arr)
for i in dace.map[0:N]:
tmp1[i] = in_arr[i]
if flag:
for i in dace.map[0:N]:
tmp2[i] = tmp1[i]
else:
for i in dace.map[0:N]:
tmp2[i] = tmp1[i]
for i in dace.map[0:N]:
arr[i] = tmp2[i]

sdfg = program.to_sdfg(simplify=True)
sdfg.apply_transformations([IfRaising, StateReplication])
sdfg.simplify()
sdfg.apply_transformations_repeated([OTFMapFusion])

assert len(sdfg.nodes()) == 4
assert sdfg.start_state.is_empty()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the tests fails here but it works on my machine. Is there something specific to the setup in the pipeline?


entries = 0
for state in sdfg.nodes():
for node in state.nodes():
if isinstance(node, nodes.MapEntry):
entries += 1

assert entries == 2


def test_if_raise_dependency():
N = dace.symbol('N', dace.int64)
@dace.program
def program(arr: dace.float64[N]):
flag = np.sum(arr)
if flag:
return 1
else:
return 0

sdfg = program.to_sdfg(simplify=True)

transform = IfRaising()
transform.if_guard = sdfg.start_state

assert not transform.can_be_applied(sdfg, 0, sdfg)


if __name__ == '__main__':
test_raise_and_duplicate_and_fusions()
test_if_raise_dependency()
Loading