Skip to content

Commit

Permalink
Some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Oct 2, 2024
1 parent 87ba3aa commit c1b5a36
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 14 deletions.
4 changes: 2 additions & 2 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG, cfg: Con
self.transient_mode=True
self.translate(self.startpoint.execution_part.execution, sdfg, cfg)

def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_Stmt_Node, sdfg: SDFG):
def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion):
"""
This function is responsible for translating Fortran pointer assignments into a SDFG.
:param node: The node to be translated
Expand Down Expand Up @@ -533,7 +533,7 @@ def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_S
self.unallocated_arrays.remove(i)
self.name_mapping[sdfg][node.name_pointer.name] = self.name_mapping[sdfg][node.name_target.name]

def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, sdfg: SDFG):
def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, sdfg: SDFG, cfg: ControlFlowRegion):
"""
This function is responsible for registering Fortran derived type declarations into a SDFG as nested data types.
:param node: The node to be translated
Expand Down
10 changes: 5 additions & 5 deletions dace/transformation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,8 +752,8 @@ def state_fission_after(state: SDFGState, node: nodes.Node, label: Optional[str]
state.add_node(node_)
new_nodes[node] = node_

if isinstance(n, nodes.AccessNode) and isinstance(state.sdfg.arrays[n.data], data.View):
for view_node in get_all_view_nodes(state, n):
if isinstance(node, nodes.AccessNode) and isinstance(state.sdfg.arrays[node.data], data.View):
for view_node in get_all_view_nodes(state, node):
node_ = copy.deepcopy(view_node)
state.add_node(node_)
new_nodes[view_node] = node_
Expand All @@ -770,10 +770,10 @@ def state_fission_after(state: SDFGState, node: nodes.Node, label: Optional[str]
# Move nodes
state.remove_nodes_from(nodes_to_move)

for n in nodes_to_move:
if isinstance(n, nodes.NestedSDFG):
for node in nodes_to_move:
if isinstance(node, nodes.NestedSDFG):
# Set the new parent state
n.sdfg.parent = newstate
node.sdfg.parent = newstate

newstate.add_nodes_from(nodes_to_move)

Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/interstate/inline_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):
new_start_state = matching_edge.dst

# Remove unreachable states
branch_subgraph = set([e.dst for e in branch_nsdfg.bfs_edges(new_start_state)])
branch_subgraph = set([e.dst for e in branch_nsdfg.edge_bfs(new_start_state)])
branch_subgraph.add(new_start_state)
states_to_remove = set(branch_nsdfg.states()) - branch_subgraph
branch_nsdfg.remove_nodes_from(states_to_remove)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):
state.add_edge(viewed, oedge.src_conn, oedge.dst, oedge.dst_conn, Memlet(memlet))
state.remove_edge(oedge)

for edge in state.bfs_edges(viewed):
for edge in state.edge_bfs(viewed):
if edge.data.data == view.data:
memlet = edge.data.data.replace(view.data, viewed.data) + f"[{edge.data.subset}]"
state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn, Memlet(memlet))
Expand All @@ -75,7 +75,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):
state.add_edge(iedge.src, iedge.src_conn, viewed, iedge.dst_conn, Memlet(memlet))
state.remove_edge(iedge)

for edge in state.bfs_edges(viewed, reverse=True):
for edge in state.edge_bfs(viewed, reverse=True):
if edge.data.data == view.data:
memlet = edge.data.data.replace(view.data, viewed.data) + f"[{edge.data.subset}]"
state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn, Memlet(memlet))
Expand Down
10 changes: 7 additions & 3 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,11 @@ def apply(self, state: SDFGState, sdfg: SDFG):
struct_views : Dict[str, str] = {}

for e in list(state.in_edges(nsdfg_node)):

# Structure treatment
outer_dataname = state.memlet_path(e)[-1].data.data
if outer_dataname is None:
# Empty memlet, no data.
continue
outer_tokens = outer_dataname.split('.')
outer_dataname = outer_tokens[0]
outer_descriptor = sdfg.arrays[outer_dataname]
Expand Down Expand Up @@ -314,9 +316,11 @@ def apply(self, state: SDFGState, sdfg: SDFG):
views[d] = (arr, mem)

for e in list(state.out_edges(nsdfg_node)):

# Structure treatment
outer_dataname = state.memlet_path(e)[0].data.data
if outer_dataname is None:
# Empty memlet, no data.
continue
outer_tokens = outer_dataname.split('.')
outer_dataname = outer_tokens[0]
outer_descriptor = sdfg.arrays[outer_dataname]
Expand Down Expand Up @@ -353,7 +357,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):
symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict)

# Access nodes that need to be reshaped
reshapes: Set(str) = set()
reshapes: Set[str] = set()
for aname, array in nsdfg.arrays.items():
if array.transient:
continue
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/passes/scalar_to_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]:
input = in_edge.src

# There is only zero or one incoming edges by definition
tasklet_inputs = [e.src for e in state.bfs_edges(input, reverse=True)]
tasklet_inputs = [e.src for e in state.edge_bfs(input, reverse=True)]
# Step 2.1
new_state = xfh.state_fission(gr.SubgraphView(state, set([input, node] + tasklet_inputs)))
new_isedge: sd.InterstateEdge = new_state.parent_graph.out_edges(new_state)[0]
Expand Down

0 comments on commit c1b5a36

Please sign in to comment.