Skip to content

Commit

Permalink
Run reference fix pass on SDFG after deepcopy
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Jun 29, 2023
1 parent a6093b9 commit c942f4b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
10 changes: 9 additions & 1 deletion dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,14 @@ def __deepcopy__(self, memo):
setattr(result, '_transformation_hist', copy.deepcopy(self._transformation_hist, memo))
result._sdfg_list = []
if self._parent_sdfg is None:
# Avoid import loops
from dace.transformation.passes.fusion_inline import FixNestedSDFGReferences

result._sdfg_list = result.reset_sdfg_list()
fixed = FixNestedSDFGReferences().apply_pass(result, {})
if fixed:
warnings.warn(f'Fixed {fixed} nested SDFG parent references during deep copy.')

return result

@property
Expand Down Expand Up @@ -2615,7 +2622,8 @@ def apply_gpu_transformations(self,

self.apply_transformations(GPUTransformSDFG,
options=dict(sequential_innermaps=sequential_innermaps,
register_trans=register_transients, simplify=simplify),
register_trans=register_transients,
simplify=simplify),
validate=validate,
validate_all=validate_all,
permissive=permissive,
Expand Down
43 changes: 42 additions & 1 deletion dace/transformation/passes/fusion_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Dict, Optional

from dace import SDFG, properties
from dace.sdfg import nodes
from dace.sdfg.utils import fuse_states, inline_sdfgs
from dace.transformation import pass_pipeline as ppl

Expand All @@ -20,7 +21,7 @@ class FuseStates(ppl.Pass):

CATEGORY: str = 'Simplification'

permissive = properties.Property(dtype=bool, default=False, desc='If True, ignores some race conditions checks.')
permissive = properties.Property(dtype=bool, default=False, desc='If True, ignores some race condition checks.')
progress = properties.Property(dtype=bool,
default=None,
allow_none=True,
Expand Down Expand Up @@ -82,3 +83,43 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]:

def report(self, pass_retval: int) -> str:
return f'Inlined {pass_retval} SDFGs.'


@dataclass(unsafe_hash=True)
@properties.make_properties
class FixNestedSDFGReferences(ppl.Pass):
"""
Fixes nested SDFG references to parent state/SDFG/node
"""

CATEGORY: str = 'Simplification'

def should_reapply(self, modified: ppl.Modifies) -> bool:
return modified & (ppl.Modifies.States | ppl.Modifies.NestedSDFGs)

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.NestedSDFGs

def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]:
modified = 0
for node, state in sdfg.all_nodes_recursive():
if not isinstance(node, nodes.NestedSDFG):
continue
was_modified = False
if node.sdfg.parent_nsdfg_node is not node:
was_modified = True
node.sdfg.parent_nsdfg_node = node
if node.sdfg.parent is not state:
was_modified = True
node.sdfg.parent = state
if node.sdfg.parent_sdfg is not state.parent:
was_modified = True
node.sdfg.parent_sdfg = state.parent

if was_modified:
modified += 1

return modified or None

def report(self, pass_retval: int) -> str:
return f'Fixed {pass_retval} nested SDFG references.'

0 comments on commit c942f4b

Please sign in to comment.