From c1b5a36120f2b79fcf2b08e6d239b17cf1737018 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 2 Oct 2024 08:44:29 +0200 Subject: [PATCH] Some fixes --- dace/frontend/fortran/fortran_parser.py | 4 ++-- dace/transformation/helpers.py | 10 +++++----- dace/transformation/interstate/inline_map.py | 2 +- .../interstate/remove_trivial_structure_view.py | 4 ++-- dace/transformation/interstate/sdfg_nesting.py | 10 +++++++--- dace/transformation/passes/scalar_to_symbol.py | 2 +- 6 files changed, 18 insertions(+), 14 deletions(-) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 219a5ab204..915395b949 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -470,7 +470,7 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG, cfg: Con self.transient_mode=True self.translate(self.startpoint.execution_part.execution, sdfg, cfg) - def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_Stmt_Node, sdfg: SDFG): + def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran pointer assignments into a SDFG. :param node: The node to be translated @@ -533,7 +533,7 @@ def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_S self.unallocated_arrays.remove(i) self.name_mapping[sdfg][node.name_pointer.name] = self.name_mapping[sdfg][node.name_target.name] - def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, sdfg: SDFG): + def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for registering Fortran derived type declarations into a SDFG as nested data types. :param node: The node to be translated diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 774cd663d6..6d02803344 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -752,8 +752,8 @@ def state_fission_after(state: SDFGState, node: nodes.Node, label: Optional[str] state.add_node(node_) new_nodes[node] = node_ - if isinstance(n, nodes.AccessNode) and isinstance(state.sdfg.arrays[n.data], data.View): - for view_node in get_all_view_nodes(state, n): + if isinstance(node, nodes.AccessNode) and isinstance(state.sdfg.arrays[node.data], data.View): + for view_node in get_all_view_nodes(state, node): node_ = copy.deepcopy(view_node) state.add_node(node_) new_nodes[view_node] = node_ @@ -770,10 +770,10 @@ def state_fission_after(state: SDFGState, node: nodes.Node, label: Optional[str] # Move nodes state.remove_nodes_from(nodes_to_move) - for n in nodes_to_move: - if isinstance(n, nodes.NestedSDFG): + for node in nodes_to_move: + if isinstance(node, nodes.NestedSDFG): # Set the new parent state - n.sdfg.parent = newstate + node.sdfg.parent = newstate newstate.add_nodes_from(nodes_to_move) diff --git a/dace/transformation/interstate/inline_map.py b/dace/transformation/interstate/inline_map.py index f497fe31d8..2256fe40dc 100644 --- a/dace/transformation/interstate/inline_map.py +++ b/dace/transformation/interstate/inline_map.py @@ -413,7 +413,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): new_start_state = matching_edge.dst # Remove unreachable states - branch_subgraph = set([e.dst for e in branch_nsdfg.bfs_edges(new_start_state)]) + branch_subgraph = set([e.dst for e in branch_nsdfg.edge_bfs(new_start_state)]) branch_subgraph.add(new_start_state) states_to_remove = set(branch_nsdfg.states()) - branch_subgraph branch_nsdfg.remove_nodes_from(states_to_remove) diff --git a/dace/transformation/interstate/remove_trivial_structure_view.py b/dace/transformation/interstate/remove_trivial_structure_view.py index e3e5dcfab6..099f528b00 100644 --- a/dace/transformation/interstate/remove_trivial_structure_view.py +++ b/dace/transformation/interstate/remove_trivial_structure_view.py @@ -62,7 +62,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): state.add_edge(viewed, oedge.src_conn, oedge.dst, oedge.dst_conn, Memlet(memlet)) state.remove_edge(oedge) - for edge in state.bfs_edges(viewed): + for edge in state.edge_bfs(viewed): if edge.data.data == view.data: memlet = edge.data.data.replace(view.data, viewed.data) + f"[{edge.data.subset}]" state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn, Memlet(memlet)) @@ -75,7 +75,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): state.add_edge(iedge.src, iedge.src_conn, viewed, iedge.dst_conn, Memlet(memlet)) state.remove_edge(iedge) - for edge in state.bfs_edges(viewed, reverse=True): + for edge in state.edge_bfs(viewed, reverse=True): if edge.data.data == view.data: memlet = edge.data.data.replace(view.data, viewed.data) + f"[{edge.data.subset}]" state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn, Memlet(memlet)) diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 78f77f8778..f265790c0a 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -281,9 +281,11 @@ def apply(self, state: SDFGState, sdfg: SDFG): struct_views : Dict[str, str] = {} for e in list(state.in_edges(nsdfg_node)): - # Structure treatment outer_dataname = state.memlet_path(e)[-1].data.data + if outer_dataname is None: + # Empty memlet, no data. + continue outer_tokens = outer_dataname.split('.') outer_dataname = outer_tokens[0] outer_descriptor = sdfg.arrays[outer_dataname] @@ -314,9 +316,11 @@ def apply(self, state: SDFGState, sdfg: SDFG): views[d] = (arr, mem) for e in list(state.out_edges(nsdfg_node)): - # Structure treatment outer_dataname = state.memlet_path(e)[0].data.data + if outer_dataname is None: + # Empty memlet, no data. + continue outer_tokens = outer_dataname.split('.') outer_dataname = outer_tokens[0] outer_descriptor = sdfg.arrays[outer_dataname] @@ -353,7 +357,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict) # Access nodes that need to be reshaped - reshapes: Set(str) = set() + reshapes: Set[str] = set() for aname, array in nsdfg.arrays.items(): if array.transient: continue diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index ea3d057411..0d3e4c6ad4 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -671,7 +671,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: input = in_edge.src # There is only zero or one incoming edges by definition - tasklet_inputs = [e.src for e in state.bfs_edges(input, reverse=True)] + tasklet_inputs = [e.src for e in state.edge_bfs(input, reverse=True)] # Step 2.1 new_state = xfh.state_fission(gr.SubgraphView(state, set([input, node] + tasklet_inputs))) new_isedge: sd.InterstateEdge = new_state.parent_graph.out_edges(new_state)[0]