From c942f4b95fe67e079497606e039c889420a33af5 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 28 Jun 2023 20:23:44 -0700 Subject: [PATCH] Run reference fix pass on SDFG after deepcopy --- dace/sdfg/sdfg.py | 10 ++++- dace/transformation/passes/fusion_inline.py | 43 ++++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index adebe51c9b..8d9d442bbc 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -491,7 +491,14 @@ def __deepcopy__(self, memo): setattr(result, '_transformation_hist', copy.deepcopy(self._transformation_hist, memo)) result._sdfg_list = [] if self._parent_sdfg is None: + # Avoid import loops + from dace.transformation.passes.fusion_inline import FixNestedSDFGReferences + result._sdfg_list = result.reset_sdfg_list() + fixed = FixNestedSDFGReferences().apply_pass(result, {}) + if fixed: + warnings.warn(f'Fixed {fixed} nested SDFG parent references during deep copy.') + return result @property @@ -2615,7 +2622,8 @@ def apply_gpu_transformations(self, self.apply_transformations(GPUTransformSDFG, options=dict(sequential_innermaps=sequential_innermaps, - register_trans=register_transients, simplify=simplify), + register_trans=register_transients, + simplify=simplify), validate=validate, validate_all=validate_all, permissive=permissive, diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index abb4a9fe74..74f73e3c93 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Optional from dace import SDFG, properties +from dace.sdfg import nodes from dace.sdfg.utils import fuse_states, inline_sdfgs from dace.transformation import pass_pipeline as ppl @@ -20,7 +21,7 @@ class FuseStates(ppl.Pass): CATEGORY: str = 'Simplification' - permissive = properties.Property(dtype=bool, default=False, desc='If True, ignores some race conditions checks.') + permissive = properties.Property(dtype=bool, default=False, desc='If True, ignores some race condition checks.') progress = properties.Property(dtype=bool, default=None, allow_none=True, @@ -82,3 +83,43 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: def report(self, pass_retval: int) -> str: return f'Inlined {pass_retval} SDFGs.' + + +@dataclass(unsafe_hash=True) +@properties.make_properties +class FixNestedSDFGReferences(ppl.Pass): + """ + Fixes nested SDFG references to parent state/SDFG/node + """ + + CATEGORY: str = 'Simplification' + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & (ppl.Modifies.States | ppl.Modifies.NestedSDFGs) + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.NestedSDFGs + + def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: + modified = 0 + for node, state in sdfg.all_nodes_recursive(): + if not isinstance(node, nodes.NestedSDFG): + continue + was_modified = False + if node.sdfg.parent_nsdfg_node is not node: + was_modified = True + node.sdfg.parent_nsdfg_node = node + if node.sdfg.parent is not state: + was_modified = True + node.sdfg.parent = state + if node.sdfg.parent_sdfg is not state.parent: + was_modified = True + node.sdfg.parent_sdfg = state.parent + + if was_modified: + modified += 1 + + return modified or None + + def report(self, pass_retval: int) -> str: + return f'Fixed {pass_retval} nested SDFG references.'