From 036e4073f8662ea407f1939bd415ac7934a26c24 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 11:55:58 +0200 Subject: [PATCH 01/14] Add data dependency analyses --- .../analysis/writeset_underapproximation.py | 311 +++++++++--------- dace/transformation/pass_pipeline.py | 3 +- .../passes/analysis/__init__.py | 1 + .../passes/{ => analysis}/analysis.py | 125 +++++-- .../analysis/control_flow_region_analysis.py | 229 +++++++++++++ .../passes/analysis/loop_analysis.py | 213 ++++++++++++ 6 files changed, 691 insertions(+), 191 deletions(-) create mode 100644 dace/transformation/passes/analysis/__init__.py rename dace/transformation/passes/{ => analysis}/analysis.py (83%) create mode 100644 dace/transformation/passes/analysis/control_flow_region_analysis.py create mode 100644 dace/transformation/passes/analysis/loop_analysis.py diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index bfd5f4cb00..afc34add30 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -8,35 +8,22 @@ import copy import itertools import warnings -from typing import Any, Dict, List, Set, Tuple, Type, Union +from typing import Any, Dict, List, Set, Tuple, Type, TypedDict, Union import sympy import dace +from dace.sdfg.state import LoopRegion +from dace.transformation import transformation from dace.symbolic import issymbolic, pystr_to_symbolic, simplify from dace.transformation.pass_pipeline import Modifies, Pass from dace import registry, subsets, symbolic, dtypes, data, SDFG, Memlet from dace.sdfg.nodes import NestedSDFG, AccessNode from dace.sdfg import nodes, SDFGState, graph as gr -from dace.sdfg.analysis import cfg +from dace.sdfg.analysis import cfg as cfg_analysis from dace.transformation import pass_pipeline as ppl from dace.sdfg import graph from dace.sdfg import scope - -# dictionary mapping each edge to a copy of the memlet of that edge with its write set -# underapproximated -approximation_dict: Dict[graph.Edge, Memlet] = {} -# dictionary that maps loop headers to "border memlets" that are written to in the -# corresponding loop -loop_write_dict: Dict[SDFGState, Dict[str, Memlet]] = {} -# dictionary containing information about the for loops in the SDFG -loop_dict: Dict[SDFGState, Tuple[SDFGState, SDFGState, - List[SDFGState], str, subsets.Range]] = {} -# dictionary mapping each nested SDFG to the iteration variables surrounding it -iteration_variables: Dict[SDFG, Set[str]] = {} -# dictionary mapping each state to the iteration variables surrounding it -# (including the ones from surrounding SDFGs) -ranges_per_state: Dict[SDFGState, - Dict[str, subsets.Range]] = defaultdict(lambda: {}) +from dace.transformation.passes.analysis import loop_analysis @registry.make_registry @@ -417,7 +404,7 @@ def _find_unconditionally_executed_states(sdfg: SDFG) -> Set[SDFGState]: sdfg.add_edge(sink_node, dummy_sink, dace.sdfg.InterstateEdge()) # get all the nodes that are executed unconditionally in the state-machine a.k.a nodes # that dominate the sink states - dominators = cfg.all_dominators(sdfg) + dominators = cfg_analysis.all_dominators(sdfg) states = dominators[dummy_sink] # remove dummy state sdfg.remove_node(dummy_sink) @@ -689,21 +676,34 @@ def _merge_subsets(subset_a: subsets.Subset, subset_b: subsets.Subset) -> subset return subset_b +class UnderapproximateWritesDictT(TypedDict): + approximation: Dict[graph.Edge, Memlet] + loop_approximation: Dict[SDFGState, Dict[str, 
Memlet]] + loops: Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] + + +@transformation.experimental_cfg_block_compatible class UnderapproximateWrites(ppl.Pass): + # Dictionary mapping each edge to a copy of the memlet of that edge with its write set underapproximated. + approximation_dict: Dict[graph.Edge, Memlet] = {} + # Dictionary that maps loop headers to "border memlets" that are written to in the corresponding loop. + loop_write_dict: Dict[SDFGState, Dict[str, Memlet]] = {} + # Dictionary containing information about the for loops in the SDFG. + loop_dict: Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] = {} + # Dictionary mapping each nested SDFG to the iteration variables surrounding it. + iteration_variables: Dict[SDFG, Set[str]] = {} + # Mapping of state to the iteration variables surrounding them, including the ones from surrounding SDFGs. + ranges_per_state: Dict[SDFGState, Dict[str, subsets.Range]] = defaultdict(lambda: {}) + def modifies(self) -> Modifies: - return ppl.Modifies.Everything + return ppl.Modifies.States def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply - return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - - def apply_pass( - self, sdfg: dace.SDFG, pipeline_results: Dict[str, Any] - ) -> Dict[str, Union[ - Dict[graph.Edge, Memlet], - Dict[SDFGState, Dict[str, Memlet]], - Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]]]]: + # If anything was modified, reapply. + return modified & ppl.Modifies.Everything + + def apply_pass(self, top_sdfg: dace.SDFG, _) -> Dict[int, UnderapproximateWritesDictT]: """ Applies the pass to the given SDFG. @@ -725,42 +725,49 @@ def apply_pass( :notes: The only modification this pass performs on the SDFG is splitting interstate edges. """ - # clear the global dictionaries - approximation_dict.clear() - loop_write_dict.clear() - loop_dict.clear() - iteration_variables.clear() - ranges_per_state.clear() - - # fill the approximation dictionary with the original edges as keys and the edges with the - # approximated memlets as values - for (edge, parent) in sdfg.all_edges_recursive(): - if isinstance(parent, SDFGState): - approximation_dict[edge] = copy.deepcopy(edge.data) - if not isinstance(approximation_dict[edge].subset, - subsets.SubsetUnion) and approximation_dict[edge].subset: - approximation_dict[edge].subset = subsets.SubsetUnion( - [approximation_dict[edge].subset]) - if not isinstance(approximation_dict[edge].dst_subset, - subsets.SubsetUnion) and approximation_dict[edge].dst_subset: - approximation_dict[edge].dst_subset = subsets.SubsetUnion( - [approximation_dict[edge].dst_subset]) - if not isinstance(approximation_dict[edge].src_subset, - subsets.SubsetUnion) and approximation_dict[edge].src_subset: - approximation_dict[edge].src_subset = subsets.SubsetUnion( - [approximation_dict[edge].src_subset]) - - self._underapproximate_writes_sdfg(sdfg) - - # Replace None with empty SubsetUnion in each Memlet - for entry in approximation_dict.values(): - if entry.subset is None: - entry.subset = subsets.SubsetUnion([]) - return { - "approximation": approximation_dict, - "loop_approximation": loop_write_dict, - "loops": loop_dict - } + result = defaultdict(lambda: {'approximation': dict(), 'loop_approximation': dict(), 'loops': dict()}) + + for sdfg in top_sdfg.all_sdfgs_recursive(): + # Clear the global dictionaries. 
+ self.approximation_dict.clear() + self.loop_write_dict.clear() + self.loop_dict.clear() + self.iteration_variables.clear() + self.ranges_per_state.clear() + + # fill the approximation dictionary with the original edges as keys and the edges with the + # approximated memlets as values + for (edge, parent) in sdfg.all_edges_recursive(): + if isinstance(parent, SDFGState): + self.approximation_dict[edge] = copy.deepcopy(edge.data) + if not isinstance(self.approximation_dict[edge].subset, + subsets.SubsetUnion) and self.approximation_dict[edge].subset: + self.approximation_dict[edge].subset = subsets.SubsetUnion([ + self.approximation_dict[edge].subset + ]) + if not isinstance(self.approximation_dict[edge].dst_subset, + subsets.SubsetUnion) and self.approximation_dict[edge].dst_subset: + self.approximation_dict[edge].dst_subset = subsets.SubsetUnion([ + self.approximation_dict[edge].dst_subset + ]) + if not isinstance(self.approximation_dict[edge].src_subset, + subsets.SubsetUnion) and self.approximation_dict[edge].src_subset: + self.approximation_dict[edge].src_subset = subsets.SubsetUnion([ + self.approximation_dict[edge].src_subset + ]) + + self._underapproximate_writes_sdfg(sdfg) + + # Replace None with empty SubsetUnion in each Memlet + for entry in self.approximation_dict.values(): + if entry.subset is None: + entry.subset = subsets.SubsetUnion([]) + + result[sdfg.cfg_id]['approximation'] = self.approximation_dict + result[sdfg.cfg_id]['loop_approximation'] = self.loop_write_dict + result[sdfg.cfg_id]['loops'] = self.loop_dict + + return result def _underapproximate_writes_sdfg(self, sdfg: SDFG): """ @@ -770,10 +777,18 @@ def _underapproximate_writes_sdfg(self, sdfg: SDFG): split_interstate_edges(sdfg) loops = self._find_for_loops(sdfg) - loop_dict.update(loops) + self.loop_dict.update(loops) + + for region in sdfg.all_control_flow_regions(): + if isinstance(region, LoopRegion): + start = loop_analysis.get_init_assignment(region) + stop = loop_analysis.get_loop_end(region) + stride = loop_analysis.get_loop_stride(region) + for state in region.all_states(): + self.ranges_per_state[state][region.loop_variable] = subsets.Range([(start, stop, stride)]) - for state in sdfg.nodes(): - self._underapproximate_writes_state(sdfg, state) + for state in region.all_states(): + self._underapproximate_writes_state(sdfg, state) self._underapproximate_writes_loops(loops, sdfg) @@ -885,13 +900,12 @@ def _find_for_loops(self, sources=[begin], condition=lambda _, child: child != guard) - if itvar not in ranges_per_state[begin]: + if itvar not in self.ranges_per_state[begin]: for loop_state in loop_states: - ranges_per_state[loop_state][itervar] = subsets.Range([ - rng]) + self.ranges_per_state[loop_state][itervar] = subsets.Range([rng]) loop_state_list.append(loop_state) - ranges_per_state[guard][itervar] = subsets.Range([rng]) + self.ranges_per_state[guard][itervar] = subsets.Range([rng]) identified_loops[guard] = (begin, last_loop_state, loop_state_list, itvar, subsets.Range([rng])) @@ -934,8 +948,11 @@ def _underapproximate_writes_state(self, sdfg: SDFG, state: SDFGState): # approximation_dict # First, propagate nested SDFGs in a bottom-up fashion + dnodes: Set[nodes.AccessNode] = set() for node in state.nodes(): - if isinstance(node, nodes.NestedSDFG): + if isinstance(node, AccessNode): + dnodes.add(node) + elif isinstance(node, nodes.NestedSDFG): self._find_live_iteration_variables(node, sdfg, state) # Propagate memlets inside the nested SDFG. 
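Illustrative usage sketch, not part of the diff: with the rework above, the pass returns its results per SDFG, indexed by cfg_id, instead of exposing module-level globals. An already-built SDFG object named `sdfg` is assumed; the key names follow the UnderapproximateWritesDictT introduced above.

    from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites
    results = UnderapproximateWrites().apply_pass(sdfg, {})
    per_sdfg = results[sdfg.cfg_id]
    edge_writes = per_sdfg['approximation']        # Dict[graph.Edge, Memlet], underapproximated write set per edge
    loop_writes = per_sdfg['loop_approximation']   # border memlets written in each loop
    loop_info = per_sdfg['loops']                  # detected for-loop metadata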
@@ -947,6 +964,15 @@ def _underapproximate_writes_state(self, sdfg: SDFG, state: SDFGState): # Process scopes from the leaves upwards self._underapproximate_writes_scope(sdfg, state, state.scope_leaves()) + # Make sure any scalar writes are also added if they have not been processed yet. + for dn in dnodes: + desc = sdfg.data(dn.data) + if isinstance(desc, data.Scalar) or (isinstance(desc, data.Array) and desc.total_size == 1): + for iedge in state.in_edges(dn): + if not iedge in self.approximation_dict: + self.approximation_dict[iedge] = copy.deepcopy(iedge.data) + self.approximation_dict[iedge]._edge = iedge + def _find_live_iteration_variables(self, nsdfg: nodes.NestedSDFG, sdfg: SDFG, @@ -963,15 +989,14 @@ def symbol_map(mapping, symbol): return None map_iteration_variables = _collect_iteration_variables(state, nsdfg) - sdfg_iteration_variables = iteration_variables[ - sdfg] if sdfg in iteration_variables else set() - state_iteration_variables = ranges_per_state[state].keys() + sdfg_iteration_variables = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() + state_iteration_variables = self.ranges_per_state[state].keys() iteration_variables_local = (map_iteration_variables | sdfg_iteration_variables | state_iteration_variables) mapped_iteration_variables = set( map(lambda x: symbol_map(nsdfg.symbol_mapping, x), iteration_variables_local)) if mapped_iteration_variables: - iteration_variables[nsdfg.sdfg] = mapped_iteration_variables + self.iteration_variables[nsdfg.sdfg] = mapped_iteration_variables def _underapproximate_writes_nested_sdfg( self, @@ -1025,12 +1050,11 @@ def _init_border_memlet(template_memlet: Memlet, # Collect all memlets belonging to this access node memlets = [] for edge in edges: - inside_memlet = approximation_dict[edge] + inside_memlet = self.approximation_dict[edge] memlets.append(inside_memlet) # initialize border memlet if it does not exist already if border_memlet is None: - border_memlet = _init_border_memlet( - inside_memlet, node.label) + border_memlet = _init_border_memlet(inside_memlet, node.label) # Given all of this access nodes' memlets union all the subsets to one SubsetUnion if len(memlets) > 0: @@ -1042,18 +1066,16 @@ def _init_border_memlet(template_memlet: Memlet, border_memlet.subset, subset) # collect the memlets for each loop in the NSDFG - if state in loop_write_dict: - for node_label, loop_memlet in loop_write_dict[state].items(): + if state in self.loop_write_dict: + for node_label, loop_memlet in self.loop_write_dict[state].items(): if node_label not in border_memlets: continue border_memlet = border_memlets[node_label] # initialize border memlet if it does not exist already if border_memlet is None: - border_memlet = _init_border_memlet( - loop_memlet, node_label) + border_memlet = _init_border_memlet(loop_memlet, node_label) # compute the union of the ranges to merge the subsets. - border_memlet.subset = _merge_subsets( - border_memlet.subset, loop_memlet.subset) + border_memlet.subset = _merge_subsets(border_memlet.subset, loop_memlet.subset) # Make sure any potential NSDFG symbol mapping is correctly reversed # when propagating out. @@ -1068,17 +1090,16 @@ def _init_border_memlet(template_memlet: Memlet, # Propagate the inside 'border' memlets outside the SDFG by # offsetting, and unsqueezing if necessary. 
for edge in parent_state.out_edges(nsdfg_node): - out_memlet = approximation_dict[edge] + out_memlet = self.approximation_dict[edge] if edge.src_conn in border_memlets: internal_memlet = border_memlets[edge.src_conn] if internal_memlet is None: out_memlet.subset = None out_memlet.dst_subset = None - approximation_dict[edge] = out_memlet + self.approximation_dict[edge] = out_memlet continue - out_memlet = _unsqueeze_memlet_subsetunion(internal_memlet, out_memlet, parent_sdfg, - nsdfg_node) - approximation_dict[edge] = out_memlet + out_memlet = _unsqueeze_memlet_subsetunion(internal_memlet, out_memlet, parent_sdfg, nsdfg_node) + self.approximation_dict[edge] = out_memlet def _underapproximate_writes_loop(self, sdfg: SDFG, @@ -1099,9 +1120,7 @@ def _underapproximate_writes_loop(self, propagate_memlet_loop will be called recursively on the outermost loopheaders """ - def _init_border_memlet(template_memlet: Memlet, - node_label: str - ): + def _init_border_memlet(template_memlet: Memlet, node_label: str): ''' Creates a Memlet with the same data as the template_memlet, stores it in the border_memlets dictionary and returns it. @@ -1111,8 +1130,7 @@ def _init_border_memlet(template_memlet: Memlet, border_memlets[node_label] = border_memlet return border_memlet - def filter_subsets(itvar: str, itrange: subsets.Range, - memlet: Memlet) -> List[subsets.Subset]: + def filter_subsets(itvar: str, itrange: subsets.Range, memlet: Memlet) -> List[subsets.Subset]: # helper method that filters out subsets that do not depend on the iteration variable # if the iteration range is symbolic @@ -1134,7 +1152,7 @@ def filter_subsets(itvar: str, itrange: subsets.Range, if rng.num_elements() == 0: return # make sure there is no break out of the loop - dominators = cfg.all_dominators(sdfg) + dominators = cfg_analysis.all_dominators(sdfg) if any(begin not in dominators[s] and not begin is s for s in loop_states): return border_memlets = defaultdict(None) @@ -1159,7 +1177,7 @@ def filter_subsets(itvar: str, itrange: subsets.Range, # collect all the subsets of the incoming memlets for the current access node for edge in edges: - inside_memlet = copy.copy(approximation_dict[edge]) + inside_memlet = copy.copy(self.approximation_dict[edge]) # filter out subsets that could become empty depending on assignments # of symbols filtered_subsets = filter_subsets( @@ -1177,35 +1195,27 @@ def filter_subsets(itvar: str, itrange: subsets.Range, self._underapproximate_writes_loop_subset(sdfg, memlets, border_memlet, sdfg.arrays[node.label], itvar, rng) - if state not in loop_write_dict: + if state not in self.loop_write_dict: continue # propagate the border memlets of nested loop - for node_label, other_border_memlet in loop_write_dict[state].items(): + for node_label, other_border_memlet in self.loop_write_dict[state].items(): # filter out subsets that could become empty depending on symbol assignments - filtered_subsets = filter_subsets( - itvar, rng, other_border_memlet) + filtered_subsets = filter_subsets(itvar, rng, other_border_memlet) if not filtered_subsets: continue - other_border_memlet.subset = subsets.SubsetUnion( - filtered_subsets) + other_border_memlet.subset = subsets.SubsetUnion(filtered_subsets) border_memlet = border_memlets.get(node_label) if border_memlet is None: - border_memlet = _init_border_memlet( - other_border_memlet, node_label) + border_memlet = _init_border_memlet(other_border_memlet, node_label) self._underapproximate_writes_loop_subset(sdfg, [other_border_memlet], border_memlet, 
sdfg.arrays[node_label], itvar, rng) - loop_write_dict[loop_header] = border_memlets + self.loop_write_dict[loop_header] = border_memlets - def _underapproximate_writes_loop_subset(self, - sdfg: dace.SDFG, - memlets: List[Memlet], - dst_memlet: Memlet, - arr: dace.data.Array, - itvar: str, - rng: subsets.Subset, + def _underapproximate_writes_loop_subset(self, sdfg: dace.SDFG, memlets: List[Memlet], dst_memlet: Memlet, + arr: dace.data.Array, itvar: str, rng: subsets.Subset, loop_nest_itvars: Union[Set[str], None] = None): """ Helper function that takes a list of (border) memlets, propagates them out of a @@ -1223,16 +1233,11 @@ def _underapproximate_writes_loop_subset(self, if len(memlets) > 0: params = [itvar] # get all the other iteration variables surrounding this memlet - surrounding_itvars = iteration_variables[sdfg] if sdfg in iteration_variables else set( - ) + surrounding_itvars = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() if loop_nest_itvars: surrounding_itvars |= loop_nest_itvars - subset = self._underapproximate_subsets(memlets, - arr, - params, - rng, - use_dst=True, + subset = self._underapproximate_subsets(memlets, arr, params, rng, use_dst=True, surrounding_itvars=surrounding_itvars).subset if subset is None or len(subset.subset_list) == 0: @@ -1240,9 +1245,7 @@ def _underapproximate_writes_loop_subset(self, # compute the union of the ranges to merge the subsets. dst_memlet.subset = _merge_subsets(dst_memlet.subset, subset) - def _underapproximate_writes_scope(self, - sdfg: SDFG, - state: SDFGState, + def _underapproximate_writes_scope(self, sdfg: SDFG, state: SDFGState, scopes: Union[scope.ScopeTree, List[scope.ScopeTree]]): """ Propagate memlets from the given scopes outwards. @@ -1253,8 +1256,7 @@ def _underapproximate_writes_scope(self, """ # for each map scope find the iteration variables of surrounding maps - surrounding_map_vars: Dict[scope.ScopeTree, - Set[str]] = _collect_itvars_scope(scopes) + surrounding_map_vars: Dict[scope.ScopeTree, Set[str]] = _collect_itvars_scope(scopes) if isinstance(scopes, scope.ScopeTree): scopes_to_process = [scopes] else: @@ -1272,8 +1274,7 @@ def _underapproximate_writes_scope(self, sdfg, state, surrounding_map_vars) - self._underapproximate_writes_node( - state, scope_node.exit, surrounding_iteration_variables) + self._underapproximate_writes_node(state, scope_node.exit, surrounding_iteration_variables) # Add parent to next frontier next_scopes.add(scope_node.parent) scopes_to_process = next_scopes @@ -1286,9 +1287,8 @@ def _collect_iteration_variables_scope_node(self, surrounding_map_vars: Dict[scope.ScopeTree, Set[str]]) -> Set[str]: map_iteration_variables = surrounding_map_vars[ scope_node] if scope_node in surrounding_map_vars else set() - sdfg_iteration_variables = iteration_variables[ - sdfg] if sdfg in iteration_variables else set() - loop_iteration_variables = ranges_per_state[state].keys() + sdfg_iteration_variables = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() + loop_iteration_variables = self.ranges_per_state[state].keys() surrounding_iteration_variables = (map_iteration_variables | sdfg_iteration_variables | loop_iteration_variables) @@ -1308,12 +1308,8 @@ def _underapproximate_writes_node(self, :param surrounding_itvars: Iteration variables that surround the map scope """ if isinstance(node, nodes.EntryNode): - internal_edges = [ - e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_') - ] - external_edges = [ - e for e 
in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_') - ] + internal_edges = [e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] + external_edges = [e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] def geticonn(e): return e.src_conn[4:] @@ -1323,12 +1319,8 @@ def geteconn(e): use_dst = False else: - internal_edges = [ - e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_') - ] - external_edges = [ - e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_') - ] + internal_edges = [e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] + external_edges = [e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] def geticonn(e): return e.dst_conn[3:] @@ -1339,21 +1331,17 @@ def geteconn(e): use_dst = True for edge in external_edges: - if approximation_dict[edge].is_empty(): + if self.approximation_dict[edge].is_empty(): new_memlet = Memlet() else: internal_edge = next( e for e in internal_edges if geticonn(e) == geteconn(edge)) - aligned_memlet = self._align_memlet( - dfg_state, internal_edge, dst=use_dst) - new_memlet = self._underapproximate_memlets(dfg_state, - aligned_memlet, - node, - True, - connector=geteconn( - edge), + aligned_memlet = self._align_memlet(dfg_state, internal_edge, dst=use_dst) + new_memlet = self._underapproximate_memlets(dfg_state, aligned_memlet, node, True, + connector=geteconn(edge), surrounding_itvars=surrounding_itvars) - approximation_dict[edge] = new_memlet + new_memlet._edge = edge + self.approximation_dict[edge] = new_memlet def _align_memlet(self, state: SDFGState, @@ -1373,16 +1361,16 @@ def _align_memlet(self, is_src = edge.data._is_data_src # Memlet is already aligned if is_src is None or (is_src and not dst) or (not is_src and dst): - res = approximation_dict[edge] + res = self.approximation_dict[edge] return res # Data<->Code memlets always have one data container mpath = state.memlet_path(edge) if not isinstance(mpath[0].src, AccessNode) or not isinstance(mpath[-1].dst, AccessNode): - return approximation_dict[edge] + return self.approximation_dict[edge] # Otherwise, find other data container - result = copy.deepcopy(approximation_dict[edge]) + result = copy.deepcopy(self.approximation_dict[edge]) if dst: node = mpath[-1].dst else: @@ -1390,8 +1378,8 @@ def _align_memlet(self, # Fix memlet fields result.data = node.data - result.subset = approximation_dict[edge].other_subset - result.other_subset = approximation_dict[edge].subset + result.subset = self.approximation_dict[edge].other_subset + result.other_subset = self.approximation_dict[edge].subset result._is_data_src = not is_src return result @@ -1448,9 +1436,9 @@ def _underapproximate_memlets(self, # and union their subsets if union_inner_edges: aggdata = [ - approximation_dict[e] + self.approximation_dict[e] for e in neighboring_edges - if approximation_dict[e].data == memlet.data and approximation_dict[e] != memlet + if self.approximation_dict[e].data == memlet.data and self.approximation_dict[e] != memlet ] else: aggdata = [] @@ -1459,8 +1447,7 @@ def _underapproximate_memlets(self, if arr is None: if memlet.data not in sdfg.arrays: - raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' % - memlet.data) + raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' 
% memlet.data) # FIXME: A memlet alone (without an edge) cannot figure out whether it is data<->data or data<->code # so this test cannot be used diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 494f9c39ae..0da8a96165 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -22,6 +22,7 @@ class Modifies(Flag): Symbols = auto() #: Symbols were modified States = auto() #: The number of SDFG states and their connectivity (not their contents) were modified InterstateEdges = auto() #: Contents (conditions/assignments) or existence of inter-state edges were modified + CFG = States | InterstateEdges #: A CFG (any level) was modified (connectivity or number of control flow blocks, but not their contents) AccessNodes = auto() #: Access nodes' existence or properties were modified Scopes = auto() #: Scopes (e.g., Map, Consume, Pipeline) or associated properties were created/removed/modified Tasklets = auto() #: Tasklets were created/removed or their contents were modified @@ -29,7 +30,7 @@ class Modifies(Flag): Memlets = auto() #: Memlets' existence, contents, or properties were modified Nodes = AccessNodes | Scopes | Tasklets | NestedSDFGs #: Modification of any dataflow node (contained in an SDFG state) was made Edges = InterstateEdges | Memlets #: Any edge (memlet or inter-state) was modified - Everything = Descriptors | Symbols | States | InterstateEdges | Nodes | Memlets #: Modification to arbitrary parts of SDFGs (nodes, edges, or properties) + Everything = Descriptors | Symbols | CFG | Nodes | Memlets #: Modification to arbitrary parts of SDFGs (nodes, edges, or properties) @properties.make_properties diff --git a/dace/transformation/passes/analysis/__init__.py b/dace/transformation/passes/analysis/__init__.py new file mode 100644 index 0000000000..5bc1f6e3f3 --- /dev/null +++ b/dace/transformation/passes/analysis/__init__.py @@ -0,0 +1 @@ +from .analysis import * diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis/analysis.py similarity index 83% rename from dace/transformation/passes/analysis.py rename to dace/transformation/passes/analysis/analysis.py index c8bb0b7a9c..b230425d00 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -1,7 +1,8 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict -from dace.transformation import pass_pipeline as ppl +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, LoopRegion +from dace.transformation import pass_pipeline as ppl, transformation from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic from dace.sdfg.graph import Edge from dace.sdfg import nodes as nd @@ -16,6 +17,7 @@ @properties.make_properties +@transformation.experimental_cfg_block_compatible class StateReachability(ppl.Pass): """ Evaluates state reachability (which other states can be executed after each state). 
@@ -28,25 +30,84 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply - return modified & ppl.Modifies.States - - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: + return modified & ppl.Modifies.CFG + + def depends_on(self) -> Set[ppl.Pass | ppl.Pass]: + return {ControlFlowBlockReachability} + + def _region_closure(self, region: ControlFlowRegion, + block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]) -> Set[SDFGState]: + closure: Set[SDFGState] = set() + if isinstance(region, LoopRegion): + # Any point inside the loop may reach any other point inside the loop again. + # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. + closure.update(region.all_states()) + + # Add all states that this region can reach in its parent graph to the closure. + for reached_block in block_reach[region.parent_graph.cfg_id][region]: + if isinstance(reached_block, ControlFlowRegion): + closure.update(reached_block.all_states()) + elif isinstance(reached_block, SDFGState): + closure.add(reached_block) + + # Walk up the parent tree. + pivot = region.parent_graph + while pivot and not isinstance(pivot, SDFG): + closure.update(self._region_closure(pivot, block_reach)) + pivot = pivot.parent_graph + return closure + + def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: """ :return: A dictionary mapping each state to its other reachable states. """ + # Ensure control flow block reachability is run if not run within a pipeline. + if not ControlFlowBlockReachability.__name__ in pipeline_res: + cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) + else: + cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Set[SDFGState]] = {} + result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) + for state in sdfg.states(): + for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: + if isinstance(reached, ControlFlowRegion): + result[state].update(reached.all_states()) + elif isinstance(reached, SDFGState): + result[state].add(reached) + if state.parent_graph is not sdfg: + result[state].update(self._region_closure(state.parent_graph, cf_block_reach_dict)) + reachable[sdfg.cfg_id] = result + return reachable - # In networkx this is currently implemented naively for directed graphs. - # The implementation below is faster - # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) - for n, v in reachable_nodes(sdfg.nx): - result[n] = set(v) +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class ControlFlowBlockReachability(ppl.Pass): + """ + Evaluates control flow block reachability (which control flow block can be executed after each control flow block) + """ - reachable[sdfg.cfg_id] = result + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: + """ + :return: For each control flow region, a dictionary mapping each control flow block to its other reachable + control flow blocks in the same region. 
+ """ + reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict(lambda: defaultdict(set)) + for cfg in top_sdfg.all_control_flow_regions(recursive=True): + # In networkx this is currently implemented naively for directed graphs. + # The implementation below is faster + # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + for n, v in reachable_nodes(cfg.nx): + reachable[cfg.cfg_id][n] = set(v) return reachable @@ -99,6 +160,7 @@ def reachable_nodes(G): @properties.make_properties +@transformation.experimental_cfg_block_compatible class SymbolAccessSets(ppl.Pass): """ Evaluates symbol access sets (which symbols are read/written in each state or interstate edge). @@ -116,25 +178,27 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ - :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. + :return: A dictionary mapping each state and interstate edge to a tuple of its (read, written) symbols. """ - top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} + top_result: Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - adesc = set(sdfg.arrays.keys()) - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.nodes(): - readset = state.free_symbols - # No symbols may be written to inside states. - result[state] = (readset, set()) - for oedge in sdfg.out_edges(state): - edge_readset = oedge.data.read_symbols() - adesc - edge_writeset = set(oedge.data.assignments.keys()) - result[oedge] = (edge_readset, edge_writeset) - top_result[sdfg.cfg_id] = result + for cfg in sdfg.all_control_flow_regions(): + adesc = set(sdfg.arrays.keys()) + result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} + for block in cfg.nodes(): + if isinstance(block, SDFGState): + # No symbols may be written to inside states. + result[block] = (block.free_symbols, set()) + for oedge in cfg.out_edges(block): + edge_readset = oedge.data.read_symbols() - adesc + edge_writeset = set(oedge.data.assignments.keys()) + result[oedge] = (edge_readset, edge_writeset) + top_result[cfg.cfg_id] = result return top_result @properties.make_properties +@transformation.experimental_cfg_block_compatible class AccessSets(ppl.Pass): """ Evaluates memory access sets (which arrays/data descriptors are read/written in each state). @@ -179,6 +243,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[s @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindAccessStates(ppl.Pass): """ For each data descriptor, creates a set of states in which access nodes of that data are used. 
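Illustrative sketch, not part of the diff: with the reworked passes above, state reachability can be queried standalone, since StateReachability runs ControlFlowBlockReachability itself when no pipeline results are supplied (see the hunk above). The SDFG object `sdfg` and the import via the new analysis package are assumptions here.

    from dace.transformation.passes.analysis import StateReachability
    state_reach = StateReachability().apply_pass(sdfg, {})[sdfg.cfg_id]
    # state_reach: Dict[SDFGState, Set[SDFGState]] -- states reachable after each state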
@@ -201,13 +266,13 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Set[SDFGState]] = defaultdict(set) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): result[anode.data].add(state) # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): fsyms = e.data.free_symbols & anames for access in fsyms: result[access].update({e.src, e.dst}) @@ -217,6 +282,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindAccessNodes(ppl.Pass): """ For each data descriptor, creates a dictionary mapping states to all read and write access nodes with the given @@ -242,7 +308,7 @@ def apply_pass(self, top_sdfg: SDFG, for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = defaultdict( lambda: defaultdict(lambda: [set(), set()])) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): if state.in_degree(anode) > 0: result[anode.data][state][1].add(anode) @@ -508,6 +574,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i @properties.make_properties +@transformation.experimental_cfg_block_compatible class AccessRanges(ppl.Pass): """ For each data descriptor, finds all memlets used to access it (read/write ranges). @@ -544,6 +611,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]: @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindReferenceSources(ppl.Pass): """ For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used @@ -586,6 +654,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, @properties.make_properties +@transformation.experimental_cfg_block_compatible class DeriveSDFGConstraints(ppl.Pass): CATEGORY: str = 'Analysis' diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py new file mode 100644 index 0000000000..e11aa945a8 --- /dev/null +++ b/dace/transformation/passes/analysis/control_flow_region_analysis.py @@ -0,0 +1,229 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from collections import defaultdict +from typing import Any, Dict, List, Set, Tuple + +import networkx as nx + +from dace import SDFG, SDFGState +from dace import data as dt +from dace import properties +from dace.memlet import Memlet +from dace.sdfg import nodes, propagation +from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDictT +from dace.sdfg.graph import MultiConnectorEdge +from dace.sdfg.scope import ScopeTree +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion +from dace.subsets import Range +from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation.passes.analysis import AccessRanges, ControlFlowBlockReachability + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class StateDataDependence(ppl.Pass): + """ + Analyze the input dependencies and the underapproximated outputs of states. 
+ """ + + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & (ppl.Modifies.Nodes | ppl.Modifies.Memlets) + + def depends_on(self): + return {UnderapproximateWrites, AccessRanges} + + def _gather_reads_scope(self, state: SDFGState, scope: ScopeTree, + writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]], + not_covered_reads: Set[Memlet], scope_ranges: Dict[str, Range]): + scope_nodes = state.scope_children()[scope.entry] + data_nodes_in_scope: Set[nodes.AccessNode] = set([n for n in scope_nodes if isinstance(nodes.AccessNode)]) + if scope.entry is not None: + # propagate + pass + + for anode in data_nodes_in_scope: + for oedge in state.out_edges(anode): + if not oedge.data.is_empty(): + root_edge = state.memlet_tree(oedge).root().edge + read_subset = root_edge.data.src_subset + covered = False + for [write, to] in writes[anode.data]: + if write.subset.covers_precise(read_subset) and nx.has_path(state.nx, to, anode): + covered = True + break + if not covered: + not_covered_reads.add(root_edge.data) + + def _state_get_deps(self, state: SDFGState, + underapproximated_writes: UnderapproximateWritesDictT) -> Tuple[Set[Memlet], Set[Memlet]]: + # Collect underapproximated write memlets. + writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]] = defaultdict(lambda: []) + for anode in state.data_nodes(): + for iedge in state.in_edges(anode): + if not iedge.data.is_empty(): + root_edge = state.memlet_tree(iedge).root().edge + if root_edge in underapproximated_writes['approximation']: + writes[anode.data].append([underapproximated_writes['approximation'][root_edge], anode]) + else: + writes[anode.data].append([root_edge.data, anode]) + + # Go over (overapproximated) reads and check if they are covered by writes. + not_covered_reads: List[Tuple[MultiConnectorEdge[Memlet], Memlet]] = [] + for anode in state.data_nodes(): + for oedge in state.out_edges(anode): + if not oedge.data.is_empty(): + if oedge.data.data != anode.data: + # Special case for memlets copying data out of the scope, which are by default aligned with the + # outside data container. In this case, the source container must either be a scalar, or the + # read subset is contained in the memlet's `other_subset` property. + # See `dace.sdfg.propagation.align_memlet` for more. + desc = state.sdfg.data(anode.data) + if oedge.data.other_subset is not None: + read_subset = oedge.data.other_subset + elif isinstance(desc, dt.Scalar) or (isinstance(desc, dt.Array) and desc.total_size == 1): + read_subset = Range([(0, 0, 1)] * len(desc.shape)) + else: + raise RuntimeError('Invalid memlet range detected in StateDataDependence analysis') + else: + read_subset = oedge.data.src_subset or oedge.data.subset + covered = False + for [write, to] in writes[anode.data]: + if write.subset.covers_precise(read_subset) and nx.has_path(state.nx, to, anode): + covered = True + break + if not covered: + #root_edge = state.memlet_tree(oedge).root().edge + #not_covered_reads.append([root_edge, root_edge.data]) + not_covered_reads.append([oedge, oedge.data]) + # Make sure all reads are propagated if they happen inside maps. We do not need to do this for writes, because + # it is already taken care of by the write underapproximation analysis pass. 
+ self._recursive_propagate_reads(state, state.scope_tree()[None], not_covered_reads) + + write_set = set() + for data in writes: + for memlet, _ in writes[data]: + write_set.add(memlet) + + read_set = set() + for reads in not_covered_reads: + read_set.add(reads[1]) + + return read_set, write_set + + def _recursive_propagate_reads(self, state: SDFGState, scope: ScopeTree, + read_edges: Set[Tuple[MultiConnectorEdge[Memlet], Memlet]]): + for child in scope.children: + self._recursive_propagate_reads(state, child, read_edges) + + if scope.entry is not None: + if isinstance(scope.entry, nodes.MapEntry): + for read_tuple in read_edges: + read_edge, read_memlet = read_tuple + for param in scope.entry.map.params: + if param in read_memlet.free_symbols: + aligned_memlet = propagation.align_memlet(state, read_edge, True) + propagated_memlet = propagation.propagate_memlet(state, aligned_memlet, scope.entry, True) + read_tuple[1] = propagated_memlet + + def apply_pass(self, top_sdfg: SDFG, + pipeline_results: Dict[str, Any]) -> Dict[int, Dict[SDFGState, Tuple[Set[Memlet], Set[Memlet]]]]: + """ + :return: For each SDFG, a dictionary mapping states to sets of their input and output memlets. + """ + + results = defaultdict(lambda: defaultdict(lambda: [set(), set()])) + + underapprox_writes_dict: Dict[int, Any] = pipeline_results[UnderapproximateWrites.__name__] + for sdfg in top_sdfg.all_sdfgs_recursive(): + uapprox_writes = underapprox_writes_dict[sdfg.cfg_id] + for state in sdfg.states(): + input_dependencies, output_dependencies = self._state_get_deps(state, uapprox_writes) + results[sdfg.cfg_id][state] = [input_dependencies, output_dependencies] + + return results + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class CFGDataDependence(ppl.Pass): + """ + Analyze the input dependencies and the underapproximated outputs of control flow graphs / regions. + """ + + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def depends_on(self): + return {StateDataDependence, ControlFlowBlockReachability} + + def _recursive_get_deps_region(self, cfg: ControlFlowRegion, + results: Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]], + state_deps: Dict[int, Dict[SDFGState, Tuple[Set[Memlet], Set[Memlet]]]], + cfg_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] + ) -> Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]: + # Collect all individual reads and writes happening inside the region. + region_reads: Dict[str, List[Tuple[Memlet, ControlFlowBlock]]] = defaultdict(list) + region_writes: Dict[str, List[Tuple[Memlet, ControlFlowBlock]]] = defaultdict(list) + for node in cfg.nodes(): + if isinstance(node, SDFGState): + for read in state_deps[node.sdfg.cfg_id][node][0]: + region_reads[read.data].append([read, node]) + for write in state_deps[node.sdfg.cfg_id][node][1]: + region_writes[write.data].append([write, node]) + elif isinstance(node, ControlFlowRegion): + sub_reads, sub_writes = self._recursive_get_deps_region(node, results, state_deps, cfg_reach) + for data in sub_reads: + for read in sub_reads[data]: + region_reads[data].append([read, node]) + for data in sub_writes: + for write in sub_writes[data]: + region_writes[data].append([write, node]) + + # Through reachability analysis, check which writes cover which reads. 
+ # TODO: make sure this doesn't cover up reads if we have a cycle in the CFG. + not_covered_reads: Dict[str, Set[Memlet]] = defaultdict(set) + for data in region_reads: + for read, read_block in region_reads[data]: + covered = False + for write, write_block in region_writes[data]: + if (write.subset.covers_precise(read.src_subset or read.subset) and + write_block is not read_block and + nx.has_path(cfg.nx, write_block, read_block)): + covered = True + break + if not covered: + not_covered_reads[data].add(read) + + write_set: Dict[str, Set[Memlet]] = defaultdict(set) + for data in region_writes: + for memlet, _ in region_writes[data]: + write_set[data].add(memlet) + + results[cfg.cfg_id] = [not_covered_reads, write_set] + + return not_covered_reads, write_set + + def apply_pass(self, top_sdfg: SDFG, + pipeline_res: Dict[str, Any]) -> Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]]: + """ + :return: For each SDFG, a dictionary mapping states to sets of their input and output memlets. + """ + + results = defaultdict(lambda: defaultdict(lambda: [defaultdict(set), defaultdict(set)])) + + state_deps_dict = pipeline_res[StateDataDependence.__name__] + cfb_reachability_dict = pipeline_res[ControlFlowBlockReachability.__name__] + for sdfg in top_sdfg.all_sdfgs_recursive(): + self._recursive_get_deps_region(sdfg, results, state_deps_dict, cfb_reachability_dict) + + return results diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py new file mode 100644 index 0000000000..293021de9c --- /dev/null +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -0,0 +1,213 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import ast +from collections import defaultdict +from typing import Any, Dict, Optional, Set, Tuple + +import sympy + +from dace import SDFG, properties, symbolic, transformation +from dace.memlet import Memlet +from dace.sdfg.state import LoopRegion +from dace.subsets import Range, SubsetUnion +from dace.transformation import pass_pipeline as ppl +from dace.transformation.pass_pipeline import Pass +from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class LoopCarryDependencyAnalysis(ppl.Pass): + """ + Analyze the data dependencies between loop iterations for loop regions. + """ + + CATEGORY: str = 'Analysis' + + _non_analyzable_loops: Set[LoopRegion] + + def __init__(self): + self._non_analyzable_loops = set() + super().__init__() + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def depends_on(self) -> Set[type[Pass] | Pass]: + return {CFGDataDependence} + + def _intersects(self, loop: LoopRegion, write_subset: Range, read_subset: Range, update: sympy.Basic) -> bool: + """ + Check if a write subset intersects a read subset after being offset by the loop stride. The offset is performed + based on the symbolic loop update assignment expression. 
+ """ + offset = update - symbolic.symbol(loop.loop_variable) + offset_list = [] + for i in range(write_subset.dims()): + if loop.loop_variable in write_subset.get_free_symbols_by_indices([i]): + offset_list.append(offset) + else: + offset_list.append(0) + offset_write = write_subset.offset_new(offset_list, True) + return offset_write.intersects(read_subset) + + def apply_pass(self, top_sdfg: SDFG, + pipeline_results: Dict[str, Any]) -> Dict[int, Dict[LoopRegion, Dict[Memlet, Set[Memlet]]]]: + """ + :return: For each SDFG, a dictionary mapping loop regions to a dictionary that resolves reads to writes in the + same loop, from which they may carry a RAW dependency. + """ + results = defaultdict(lambda: defaultdict(dict)) + + cfg_dependency_dict: Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]] = pipeline_results[ + CFGDataDependence.__name__ + ] + for cfg in top_sdfg.all_control_flow_regions(recursive=True): + if isinstance(cfg, LoopRegion): + loop_inputs, loop_outputs = cfg_dependency_dict[cfg.cfg_id] + update_assignment = None + loop_dependencies: Dict[Memlet, Set[Memlet]] = dict() + + for data in loop_inputs: + if not data in loop_outputs: + continue + + for input in loop_inputs[data]: + read_subset = input.src_subset or input.subset + dep_candidates: Set[Memlet] = set() + if cfg.loop_variable and cfg.loop_variable in input.free_symbols: + # If the iteration variable is involved in an access, we need to first offset it by the loop + # stride and then check for an overlap/intersection. If one is found after offsetting, there + # is a RAW loop carry dependency. + for output in loop_outputs[data]: + # Get and cache the update assignment for the loop. + if update_assignment is None and not cfg in self._non_analyzable_loops: + update_assignment = get_update_assignment(cfg) + if update_assignment is None: + self._non_analyzable_loops(cfg) + + if isinstance(output.subset, SubsetUnion): + if any([self._intersects(cfg, s, read_subset, update_assignment) + for s in output.subset.subset_list]): + dep_candidates.add(output) + elif self._intersects(cfg, output.subset, read_subset, update_assignment): + dep_candidates.add(output) + else: + # Check for basic overlaps/intersections in RAW loop carry dependencies, when there is no + # iteration variable involved. + for output in loop_outputs[data]: + if isinstance(output.subset, SubsetUnion): + if any([s.intersects(read_subset) for s in output.subset.subset_list]): + dep_candidates.add(output) + elif output.subset.intersects(read_subset): + dep_candidates.add(output) + loop_dependencies[input] = dep_candidates + results[cfg.sdfg.cfg_id][cfg] = loop_dependencies + + return results + + +class FindAssignment(ast.NodeVisitor): + + assignments: Dict[str, str] + multiple: bool + + def __init__(self): + self.assignments = {} + self.multiple = False + + def visit_Assign(self, node: ast.Assign) -> Any: + for tgt in node.targets: + if isinstance(tgt, ast.Name): + if tgt.id in self.assignments: + self.multiple = True + self.assignments[tgt.id] = ast.unparse(node.value) + return self.generic_visit(node) + + +def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). 
+ """ + end: Optional[symbolic.SymbolicType] = None + a = sympy.Wild('a') + condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) + itersym = symbolic.pystr_to_symbolic(loop.loop_variable) + match = condition.match(itersym < a) + if match: + end = match[a] - 1 + if end is None: + match = condition.match(itersym <= a) + if match: + end = match[a] + if end is None: + match = condition.match(itersym > a) + if match: + end = match[a] + 1 + if end is None: + match = condition.match(itersym >= a) + if match: + end = match[a] + return end + + +def get_init_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's init statement to identify the exact init assignment expression. + """ + init_stmt = loop.init_statement + if init_stmt is None: + return None + + init_codes_list = init_stmt.code if isinstance(init_stmt.code, list) else [init_stmt.code] + assignments: Dict[str, str] = {} + for code in init_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_update_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's update statement to identify the exact update assignment expression. + """ + update_stmt = loop.update_statement + if update_stmt is None: + return None + + update_codes_list = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + assignments: Dict[str, str] = {} + for code in update_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + update_assignment = get_update_assignment(loop) + if update_assignment: + return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable) + return None From 4aa13eda11d990185cc73bf7595f847f713e0d4b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 13:53:02 +0200 Subject: [PATCH 02/14] Fix type --- dace/transformation/passes/analysis/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index b230425d00..d0fc8decdc 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -32,7 +32,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.CFG - def depends_on(self) -> Set[ppl.Pass | ppl.Pass]: + def depends_on(self): return {ControlFlowBlockReachability} def _region_closure(self, region: ControlFlowRegion, From a228f34a8d1e448c31ae97bfa15f6c7de3a5a535 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 14:28:14 +0200 Subject: [PATCH 03/14] Fix types --- .../analysis/writeset_underapproximation.py | 30 ++++++++++++------- .../passes/analysis/analysis.py | 2 +- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git 
a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index afc34add30..3dcdbf3473 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -4,25 +4,33 @@ an SDFG. """ -from collections import defaultdict import copy import itertools +import sys import warnings -from typing import Any, Dict, List, Set, Tuple, Type, TypedDict, Union +from collections import defaultdict +from typing import Dict, List, Set, Tuple, Union + +if sys.version >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + import sympy import dace +from dace import SDFG, Memlet, data, dtypes, registry, subsets, symbolic +from dace.sdfg import SDFGState +from dace.sdfg import graph +from dace.sdfg import graph as gr +from dace.sdfg import nodes, scope +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.nodes import AccessNode, NestedSDFG from dace.sdfg.state import LoopRegion -from dace.transformation import transformation from dace.symbolic import issymbolic, pystr_to_symbolic, simplify -from dace.transformation.pass_pipeline import Modifies, Pass -from dace import registry, subsets, symbolic, dtypes, data, SDFG, Memlet -from dace.sdfg.nodes import NestedSDFG, AccessNode -from dace.sdfg import nodes, SDFGState, graph as gr -from dace.sdfg.analysis import cfg as cfg_analysis from dace.transformation import pass_pipeline as ppl -from dace.sdfg import graph -from dace.sdfg import scope +from dace.transformation import transformation +from dace.transformation.pass_pipeline import Modifies from dace.transformation.passes.analysis import loop_analysis @@ -807,8 +815,8 @@ def _find_for_loops(self, """ # We import here to avoid cyclic imports. - from dace.transformation.interstate.loop_detection import find_for_loop from dace.sdfg import utils as sdutils + from dace.transformation.interstate.loop_detection import find_for_loop # dictionary mapping loop headers to beginstate, loopstates, looprange identified_loops = {} diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index d0fc8decdc..1a4ab01b88 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -62,7 +62,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGS :return: A dictionary mapping each state to its other reachable states. """ # Ensure control flow block reachability is run if not run within a pipeline. 
- if not ControlFlowBlockReachability.__name__ in pipeline_res: + if pipeline_res is None or not ControlFlowBlockReachability.__name__ in pipeline_res: cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) else: cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] From 10c3b6c74ec8074cd30a74aace2614d197b21e77 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 16:56:30 +0200 Subject: [PATCH 04/14] Update tests --- .../analysis/writeset_underapproximation.py | 4 +- .../analysis/control_flow_region_analysis.py | 4 +- .../analysis/control_flow_region_analysis.py | 80 +++++++++++++++++++ 3 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 tests/passes/analysis/control_flow_region_analysis.py diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index 3dcdbf3473..c4a685e62a 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -11,7 +11,7 @@ from collections import defaultdict from typing import Dict, List, Set, Tuple, Union -if sys.version >= (3, 8): +if sys.version_info >= (3, 8): from typing import TypedDict else: from typing_extensions import TypedDict @@ -31,7 +31,6 @@ from dace.transformation import pass_pipeline as ppl from dace.transformation import transformation from dace.transformation.pass_pipeline import Modifies -from dace.transformation.passes.analysis import loop_analysis @registry.make_registry @@ -782,6 +781,7 @@ def _underapproximate_writes_sdfg(self, sdfg: SDFG): Underapproximates write-sets of loops, maps and nested SDFGs in the given SDFG. """ from dace.transformation.helpers import split_interstate_edges + from dace.transformation.passes.analysis import loop_analysis split_interstate_edges(sdfg) loops = self._find_for_loops(sdfg) diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py index e11aa945a8..92b2badecf 100644 --- a/dace/transformation/passes/analysis/control_flow_region_analysis.py +++ b/dace/transformation/passes/analysis/control_flow_region_analysis.py @@ -16,7 +16,7 @@ from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.subsets import Range from dace.transformation import pass_pipeline as ppl, transformation -from dace.transformation.passes.analysis import AccessRanges, ControlFlowBlockReachability +from dace.transformation.passes.analysis.analysis import AccessRanges, ControlFlowBlockReachability @properties.make_properties @@ -97,8 +97,6 @@ def _state_get_deps(self, state: SDFGState, covered = True break if not covered: - #root_edge = state.memlet_tree(oedge).root().edge - #not_covered_reads.append([root_edge, root_edge.data]) not_covered_reads.append([oedge, oedge.data]) # Make sure all reads are propagated if they happen inside maps. We do not need to do this for writes, because # it is already taken care of by the write underapproximation analysis pass. diff --git a/tests/passes/analysis/control_flow_region_analysis.py b/tests/passes/analysis/control_flow_region_analysis.py new file mode 100644 index 0000000000..64461edd85 --- /dev/null +++ b/tests/passes/analysis/control_flow_region_analysis.py @@ -0,0 +1,80 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests analysis passes related to control flow regions (control_flow_region_analysis.py). 
""" + + +import dace +from dace.memlet import Memlet +from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import LoopRegion, SDFGState +from dace.transformation.pass_pipeline import Pipeline +from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence, StateDataDependence + + +def test_simple_state_data_dependence_with_self_contained_read(): + N = dace.symbol('N') + + @dace.program + def myprog(A: dace.float64[N], B: dace.float64): + for i in dace.map[0:N/2]: + with dace.tasklet: + in1 << B[i] + out1 >> A[i] + out1 = in1 + 1 + with dace.tasklet: + in1 << B[i] + out1 >> B[N - (i + 1)] + out1 = in1 - 1 + for i in dace.map[0:N/2]: + with dace.tasklet: + in1 << A[i] + out1 >> B[i] + out1 = in1 * 2 + + sdfg = myprog.to_sdfg() + + res = {} + Pipeline([StateDataDependence()]).apply_pass(sdfg, res) + state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] + + assert len(state_data_deps[0]) == 1 + read_memlet: Memlet = list(state_data_deps[0])[0] + assert read_memlet.data == 'B' + assert read_memlet.subset[0][0] == 0 + assert read_memlet.subset[0][1] == 0.5 * N - 1 or read_memlet.subset[0][1] == N / 2 - 1 + + assert len(state_data_deps[1]) == 3 + + +''' +def test_nested_cf_region_data_dependence(): + N = dace.symbol('N') + + @dace.program + def myprog(A: dace.float64[N], B: dace.float64): + for i in range(N): + with dace.tasklet: + in1 << B[i] + out1 >> A[i] + out1 = in1 + 1 + for i in range(N): + with dace.tasklet: + in1 << A[i] + out1 >> B[i] + out1 = in1 * 2 + + myprog.use_experimental_cfg_blocks = True + + sdfg = myprog.to_sdfg() + + res = {} + pipeline = Pipeline([CFGDataDependence()]) + pipeline.__experimental_cfg_block_compatible__ = True + pipeline.apply_pass(sdfg, res) + + print(sdfg) + ''' + + +if __name__ == '__main__': + test_simple_state_data_dependence_with_self_contained_read() + #test_nested_cf_region_data_dependence() From 77ca17f4db1984a2adc7da7440ec8cc340c16543 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 17:51:50 +0200 Subject: [PATCH 05/14] Fixes --- dace/frontend/python/parser.py | 2 + dace/transformation/helpers.py | 4 +- .../passes/analysis/analysis.py | 80 ++++++++++++------- .../passes/analysis/loop_analysis.py | 2 +- ...y => control_flow_region_analysis_test.py} | 0 5 files changed, 56 insertions(+), 32 deletions(-) rename tests/passes/analysis/{control_flow_region_analysis.py => control_flow_region_analysis_test.py} (100%) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index e55829933c..e0900c749b 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -498,6 +498,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF sdutils.inline_control_flow_regions(sdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks + sdfg.reset_cfg_list() + # Apply simplification pass automatically if not cached and (simplify == True or (simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))): diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 0d583236cb..6c17538a37 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -379,7 +379,7 @@ def nest_state_subgraph(sdfg: SDFG, SDFG. :raise ValueError: The subgraph is contained in more than one scope. 
""" - if state.parent != sdfg: + if state.sdfg != sdfg: raise KeyError('State does not belong to given SDFG') if subgraph is not state and subgraph.graph is not state: raise KeyError('Subgraph does not belong to given state') @@ -433,7 +433,7 @@ def nest_state_subgraph(sdfg: SDFG, # top-level graph) data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode)) # Find other occurrences in SDFG - other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes() + other_nodes = set(n.data for s in sdfg.states() for n in s.nodes() if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes()) subgraph_transients = set() for data in data_in_subgraph: diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 1a4ab01b88..095319f807 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -35,28 +35,6 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {ControlFlowBlockReachability} - def _region_closure(self, region: ControlFlowRegion, - block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]) -> Set[SDFGState]: - closure: Set[SDFGState] = set() - if isinstance(region, LoopRegion): - # Any point inside the loop may reach any other point inside the loop again. - # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. - closure.update(region.all_states()) - - # Add all states that this region can reach in its parent graph to the closure. - for reached_block in block_reach[region.parent_graph.cfg_id][region]: - if isinstance(reached_block, ControlFlowRegion): - closure.update(reached_block.all_states()) - elif isinstance(reached_block, SDFGState): - closure.add(reached_block) - - # Walk up the parent tree. - pivot = region.parent_graph - while pivot and not isinstance(pivot, SDFG): - closure.update(self._region_closure(pivot, block_reach)) - pivot = pivot.parent_graph - return closure - def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: """ :return: A dictionary mapping each state to its other reachable states. 
@@ -71,12 +49,8 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGS result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) for state in sdfg.states(): for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: - if isinstance(reached, ControlFlowRegion): - result[state].update(reached.all_states()) - elif isinstance(reached, SDFGState): + if isinstance(reached, SDFGState): result[state].add(reached) - if state.parent_graph is not sdfg: - result[state].update(self._region_closure(state.parent_graph, cf_block_reach_dict)) reachable[sdfg.cfg_id] = result return reachable @@ -90,24 +64,72 @@ class ControlFlowBlockReachability(ppl.Pass): CATEGORY: str = 'Analysis' + contain_to_single_level = properties.Property(dtype=bool, default=False) + + def __init__(self, contain_to_single_level=False) -> None: + super().__init__() + + self.contain_to_single_level = contain_to_single_level + def modifies(self) -> ppl.Modifies: return ppl.Modifies.Nothing def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG + def _region_closure(self, region: ControlFlowRegion, + block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]) -> Set[SDFGState]: + closure: Set[SDFGState] = set() + if isinstance(region, LoopRegion): + # Any point inside the loop may reach any other point inside the loop again. + # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. + closure.update(region.all_control_flow_blocks()) + + # Add all states that this region can reach in its parent graph to the closure. + for reached_block in block_reach[region.parent_graph.cfg_id][region]: + if isinstance(reached_block, ControlFlowRegion): + closure.update(reached_block.all_control_flow_blocks()) + closure.add(reached_block) + + # Walk up the parent tree. + pivot = region.parent_graph + while pivot and not isinstance(pivot, SDFG): + closure.update(self._region_closure(pivot, block_reach)) + pivot = pivot.parent_graph + return closure + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: """ :return: For each control flow region, a dictionary mapping each control flow block to its other reachable control flow blocks in the same region. """ - reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict(lambda: defaultdict(set)) + single_level_reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict( + lambda: defaultdict(set) + ) for cfg in top_sdfg.all_control_flow_regions(recursive=True): # In networkx this is currently implemented naively for directed graphs. 
# The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) for n, v in reachable_nodes(cfg.nx): - reachable[cfg.cfg_id][n] = set(v) + single_level_reachable[cfg.cfg_id][n] = set(v) + if isinstance(cfg, LoopRegion): + single_level_reachable[cfg.cfg_id][n].update(cfg.nodes()) + + if self.contain_to_single_level: + return single_level_reachable + + reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = {} + for sdfg in top_sdfg.all_sdfgs_recursive(): + for cfg in sdfg.all_control_flow_regions(): + result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set) + for block in cfg.nodes(): + for reached in single_level_reachable[block.parent_graph.cfg_id][block]: + if isinstance(reached, ControlFlowRegion): + result[block].update(reached.all_control_flow_blocks()) + result[block].add(reached) + if block.parent_graph is not sdfg: + result[block].update(self._region_closure(block.parent_graph, single_level_reachable)) + reachable[cfg.cfg_id] = result return reachable diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py index 293021de9c..dd8c5f7446 100644 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -36,7 +36,7 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG - def depends_on(self) -> Set[type[Pass] | Pass]: + def depends_on(self): return {CFGDataDependence} def _intersects(self, loop: LoopRegion, write_subset: Range, read_subset: Range, update: sympy.Basic) -> bool: diff --git a/tests/passes/analysis/control_flow_region_analysis.py b/tests/passes/analysis/control_flow_region_analysis_test.py similarity index 100% rename from tests/passes/analysis/control_flow_region_analysis.py rename to tests/passes/analysis/control_flow_region_analysis_test.py From 4e6035d68e41ef33a4e06253bb472af662fe2e2f Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 17 Sep 2024 14:03:19 +0200 Subject: [PATCH 06/14] Fix tests --- .../analysis/writeset_underapproximation.py | 13 ++- dace/sdfg/propagation.py | 23 ++--- .../analysis/control_flow_region_analysis.py | 2 +- .../control_flow_region_analysis_test.py | 45 ++++----- .../writeset_underapproximation_test.py | 94 ++++++++++++------- 5 files changed, 103 insertions(+), 74 deletions(-) diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index c4a685e62a..e1b88f9401 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -1,7 +1,6 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ -Pass derived from ``propagation.py`` that under-approximates write-sets of for-loops and Maps in -an SDFG. +Pass derived from ``propagation.py`` that under-approximates write-sets of for-loops and Maps in an SDFG. """ import copy @@ -736,11 +735,11 @@ def apply_pass(self, top_sdfg: dace.SDFG, _) -> Dict[int, UnderapproximateWrites for sdfg in top_sdfg.all_sdfgs_recursive(): # Clear the global dictionaries. 
- self.approximation_dict.clear() - self.loop_write_dict.clear() - self.loop_dict.clear() - self.iteration_variables.clear() - self.ranges_per_state.clear() + self.approximation_dict = {} + self.loop_write_dict = {} + self.loop_dict = {} + self.iteration_variables = {} + self.ranges_per_state = defaultdict(lambda: {}) # fill the approximation dictionary with the original edges as keys and the edges with the # approximated memlets as values diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 1c038dd2e4..6447d8f89b 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -4,21 +4,22 @@ from internal memory accesses and scope ranges). """ -from collections import deque import copy -from dace.symbolic import issymbolic, pystr_to_symbolic, simplify -import itertools import functools +import itertools +import warnings +from collections import deque +from typing import List, Set + import sympy -from sympy import ceiling, Symbol +from sympy import Symbol, ceiling from sympy.concrete.summations import Sum -import warnings -import networkx as nx -from dace import registry, subsets, symbolic, dtypes, data +from dace import data, dtypes, registry, subsets, symbolic from dace.memlet import Memlet -from dace.sdfg import nodes, graph as gr -from typing import List, Set +from dace.sdfg import graph as gr +from dace.sdfg import nodes +from dace.symbolic import issymbolic, pystr_to_symbolic, simplify @registry.make_registry @@ -569,8 +570,8 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): """ # We import here to avoid cyclic imports. - from dace.transformation.interstate.loop_detection import find_for_loop from dace.sdfg import utils as sdutils + from dace.transformation.interstate.loop_detection import find_for_loop condition_edges = {} @@ -739,8 +740,8 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: # We import here to avoid cyclic imports. from dace.sdfg import InterstateEdge - from dace.transformation.helpers import split_interstate_edges from dace.sdfg.analysis import cfg + from dace.transformation.helpers import split_interstate_edges # Reset the state edge annotations (which may have changed due to transformations) reset_state_annotations(sdfg) diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py index 92b2badecf..265c6465ba 100644 --- a/dace/transformation/passes/analysis/control_flow_region_analysis.py +++ b/dace/transformation/passes/analysis/control_flow_region_analysis.py @@ -214,7 +214,7 @@ def _recursive_get_deps_region(self, cfg: ControlFlowRegion, def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict[str, Any]) -> Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]]: """ - :return: For each SDFG, a dictionary mapping states to sets of their input and output memlets. + :return: For each CFG, a dictionary mapping control flow regions to sets of their input and output memlets. 
""" results = defaultdict(lambda: defaultdict(lambda: [defaultdict(set), defaultdict(set)])) diff --git a/tests/passes/analysis/control_flow_region_analysis_test.py b/tests/passes/analysis/control_flow_region_analysis_test.py index 64461edd85..d1ea5161bf 100644 --- a/tests/passes/analysis/control_flow_region_analysis_test.py +++ b/tests/passes/analysis/control_flow_region_analysis_test.py @@ -4,33 +4,36 @@ import dace from dace.memlet import Memlet -from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.propagation import propagate_memlets_sdfg +from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.state import LoopRegion, SDFGState from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence, StateDataDependence def test_simple_state_data_dependence_with_self_contained_read(): + sdfg = SDFG('myprog') N = dace.symbol('N') - - @dace.program - def myprog(A: dace.float64[N], B: dace.float64): - for i in dace.map[0:N/2]: - with dace.tasklet: - in1 << B[i] - out1 >> A[i] - out1 = in1 + 1 - with dace.tasklet: - in1 << B[i] - out1 >> B[N - (i + 1)] - out1 = in1 - 1 - for i in dace.map[0:N/2]: - with dace.tasklet: - in1 << A[i] - out1 >> B[i] - out1 = in1 * 2 - - sdfg = myprog.to_sdfg() + sdfg.add_array('A', (N,), dace.float32) + sdfg.add_array('B', (N,), dace.float32) + mystate = sdfg.add_state('mystate', is_start_block=True) + b_read = mystate.add_access('B') + b_write_second_half = mystate.add_access('B') + b_write_first_half = mystate.add_access('B') + a_read_write = mystate.add_access('A') + first_entry, first_exit = mystate.add_map('map_one', {'i': '0:0.5*N'}) + second_entry, second_exit = mystate.add_map('map_two', {'i': '0:0.5*N'}) + t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 + 1.0') + t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') + t3 = mystate.add_tasklet('t3', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') + mystate.add_memlet_path(b_read, first_entry, t1, memlet=Memlet('B[i]'), dst_conn='i1') + mystate.add_memlet_path(b_read, first_entry, t2, memlet=Memlet('B[i]'), dst_conn='i1') + mystate.add_memlet_path(t1, first_exit, a_read_write, memlet=Memlet('A[i]'), src_conn='o1') + mystate.add_memlet_path(t2, first_exit, b_write_second_half, memlet=Memlet('B[N - (i + 1)]'), src_conn='o1') + mystate.add_memlet_path(a_read_write, second_entry, t3, memlet=Memlet('A[i]'), dst_conn='i1') + mystate.add_memlet_path(t3, second_exit, b_write_first_half, memlet=Memlet('B[i]'), src_conn='o1') + + propagate_memlets_sdfg(sdfg) res = {} Pipeline([StateDataDependence()]).apply_pass(sdfg, res) @@ -40,7 +43,7 @@ def myprog(A: dace.float64[N], B: dace.float64): read_memlet: Memlet = list(state_data_deps[0])[0] assert read_memlet.data == 'B' assert read_memlet.subset[0][0] == 0 - assert read_memlet.subset[0][1] == 0.5 * N - 1 or read_memlet.subset[0][1] == N / 2 - 1 + assert read_memlet.subset[0][1] == 0.5 * N - 1 assert len(state_data_deps[1]) == 3 diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 7d5272d80a..d0c0e03209 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -9,8 +9,6 @@ M = dace.symbol("M") K = dace.symbol("K") -pipeline = Pipeline([UnderapproximateWrites()]) - def test_2D_map_overwrites_2D_array(): """ @@ -33,9 +31,10 @@ def test_2D_map_overwrites_2D_array(): output_nodes={'B': a1}, external_edges=True) + pipeline = 
Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results['approximation'] + result = results[sdfg.cfg_id]['approximation'] edge = map_state.in_edges(a1)[0] result_subset_list = result[edge].subset.subset_list result_subset = result_subset_list[0] @@ -65,9 +64,10 @@ def test_2D_map_added_indices(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -94,9 +94,10 @@ def test_2D_map_multiplied_indices(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -121,9 +122,10 @@ def test_1D_map_one_index_multiple_dims(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -146,9 +148,10 @@ def test_1D_map_one_index_squared(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -185,9 +188,10 @@ def test_map_tree_full_write(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] expected_subset_outer_edge = Range.from_string("0:M, 0:N") expected_subset_inner_edge = Range.from_string("0:M, _i") result_inner_edge_0 = result[inner_edge_0].subset.subset_list[0] @@ -230,9 +234,10 @@ def test_map_tree_no_write_multiple_indices(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] result_inner_edge_0 = result[inner_edge_0].subset.subset_list result_inner_edge_1 = result[inner_edge_1].subset.subset_list result_outer_edge = result[outer_edge].subset.subset_list @@ -273,9 +278,10 @@ def test_map_tree_multiple_indices_per_dimension(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - 
result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] expected_subset_outer_edge = Range.from_string("0:M, 0:N") expected_subset_inner_edge_1 = Range.from_string("0:M, _i") result_inner_edge_1 = result[inner_edge_1].subset.subset_list[0] @@ -300,11 +306,12 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] nsdfg = sdfg.cfg_list[1].parent_nsdfg_node map_state = sdfg.states()[0] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.out_edges(nsdfg)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -323,11 +330,12 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] map_state = sdfg.states()[0] edge = map_state.in_edges(map_state.data_nodes()[0])[0] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] expected_subset = Range.from_string("0:N, 0:M") assert (str(result[edge].subset.subset_list[0]) == str(expected_subset)) @@ -357,9 +365,10 @@ def test_map_in_loop(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id]["loop_approximation"] expected_subset = Range.from_string("0:N, 0:M") assert (str(result[guard]["B"].subset.subset_list[0]) == str(expected_subset)) @@ -390,9 +399,10 @@ def test_map_in_loop_multiplied_indices_first_dimension(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id]["loop_approximation"] assert (guard not in result.keys() or len(result[guard]) == 0) @@ -421,9 +431,10 @@ def test_map_in_loop_multiplied_indices_second_dimension(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id]["loop_approximation"] assert (guard not in result.keys() or len(result[guard]) == 0) @@ -444,8 +455,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] # find write set accessnode = None write_set = None @@ -478,9 +490,10 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] # find write set accessnode = None write_set = None @@ -510,15 +523,16 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] # find write 
set accessnode = None write_set = None - for node, _ in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.AccessNode): - if node.data == "A": + if node.data == "A" and parent.out_degree(node) == 0: accessnode = node for edge, memlet in write_approx.items(): if edge.dst is accessnode: @@ -542,15 +556,16 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] # find write set accessnode = None write_set = None - for node, _ in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.AccessNode): - if node.data == "A": + if node.data == "A" and parent.out_degree(node) == 0: accessnode = node for edge, memlet in write_approx.items(): if edge.dst is accessnode: @@ -574,7 +589,8 @@ def test_simple_loop_overwrite(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (str(result[guard]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) @@ -598,7 +614,8 @@ def test_loop_2D_overwrite(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (str(result[guard1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard2]["A"].subset) == "j, 0:N") @@ -629,7 +646,8 @@ def test_loop_2D_propagation_gap_symbolic(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert ("A" not in result[guard1].keys()) assert ("A" not in result[guard2].keys()) @@ -657,7 +675,8 @@ def test_2_loops_overwrite(): loop_tasklet_2 = loop_body_2.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_2.add_edge(loop_tasklet_2, "a", a1, None, dace.Memlet("A[i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (str(result[guard_1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard_2]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) @@ -687,7 +706,8 @@ def test_loop_2D_overwrite_propagation_gap_non_empty(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = 
pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (str(result[guard1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard2]["A"].subset) == "j, 0:N") @@ -717,7 +737,8 @@ def test_loop_nest_multiplied_indices(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i,i*j]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 not in result.keys() or "A" not in result[guard2].keys()) @@ -748,7 +769,8 @@ def test_loop_nest_empty_nested_loop(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 not in result.keys() or "A" not in result[guard2].keys()) @@ -779,7 +801,8 @@ def test_loop_nest_inner_loop_conditional(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[k]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 in result.keys() and "A" in result[guard2].keys() and str(result[guard2]['A'].subset) == "0:N") @@ -799,9 +822,10 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] write_set = None accessnode = None for node, _ in sdfg.all_nodes_recursive(): @@ -828,10 +852,11 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] # find write set - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] accessnode = None write_set = None for node, _ in sdfg.all_nodes_recursive(): @@ -864,9 +889,10 @@ def test_loop_break(): loop_tasklet = loop_body_1.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_1.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id]["loop_approximation"] assert (guard3 not in result.keys() or "A" not in result[guard3].keys()) From 
b61a283e1c624da2c9ef4b8634e2081eb5b15159 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 17 Sep 2024 16:09:27 +0200 Subject: [PATCH 07/14] Add tests --- .../analysis/control_flow_region_analysis.py | 2 +- .../control_flow_region_analysis_test.py | 100 ++++++++++++------ 2 files changed, 68 insertions(+), 34 deletions(-) diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py index 265c6465ba..377765c31b 100644 --- a/dace/transformation/passes/analysis/control_flow_region_analysis.py +++ b/dace/transformation/passes/analysis/control_flow_region_analysis.py @@ -35,7 +35,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & (ppl.Modifies.Nodes | ppl.Modifies.Memlets) def depends_on(self): - return {UnderapproximateWrites, AccessRanges} + return {UnderapproximateWrites} def _gather_reads_scope(self, state: SDFGState, scope: ScopeTree, writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]], diff --git a/tests/passes/analysis/control_flow_region_analysis_test.py b/tests/passes/analysis/control_flow_region_analysis_test.py index d1ea5161bf..bf0742f3f1 100644 --- a/tests/passes/analysis/control_flow_region_analysis_test.py +++ b/tests/passes/analysis/control_flow_region_analysis_test.py @@ -1,21 +1,18 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests analysis passes related to control flow regions (control_flow_region_analysis.py). """ - import dace from dace.memlet import Memlet -from dace.sdfg.propagation import propagate_memlets_sdfg -from dace.sdfg.sdfg import SDFG, InterstateEdge -from dace.sdfg.state import LoopRegion, SDFGState +from dace.sdfg.sdfg import SDFG from dace.transformation.pass_pipeline import Pipeline -from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence, StateDataDependence +from dace.transformation.passes.analysis.control_flow_region_analysis import StateDataDependence -def test_simple_state_data_dependence_with_self_contained_read(): +def test_state_data_dependence_with_contained_read(): sdfg = SDFG('myprog') N = dace.symbol('N') - sdfg.add_array('A', (N,), dace.float32) - sdfg.add_array('B', (N,), dace.float32) + sdfg.add_array('A', (N, ), dace.float32) + sdfg.add_array('B', (N, ), dace.float32) mystate = sdfg.add_state('mystate', is_start_block=True) b_read = mystate.add_access('B') b_write_second_half = mystate.add_access('B') @@ -33,8 +30,6 @@ def test_simple_state_data_dependence_with_self_contained_read(): mystate.add_memlet_path(a_read_write, second_entry, t3, memlet=Memlet('A[i]'), dst_conn='i1') mystate.add_memlet_path(t3, second_exit, b_write_first_half, memlet=Memlet('B[i]'), src_conn='o1') - propagate_memlets_sdfg(sdfg) - res = {} Pipeline([StateDataDependence()]).apply_pass(sdfg, res) state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] @@ -48,36 +43,75 @@ def test_simple_state_data_dependence_with_self_contained_read(): assert len(state_data_deps[1]) == 3 -''' -def test_nested_cf_region_data_dependence(): +def test_state_data_dependence_with_contained_read_in_map(): + sdfg = SDFG('myprog') N = dace.symbol('N') + sdfg.add_array('A', (N, ), dace.float32) + sdfg.add_transient('tmp', (N, ), dace.float32) + sdfg.add_array('B', (N, ), dace.float32) + mystate = sdfg.add_state('mystate', is_start_block=True) + a_read = mystate.add_access('A') + tmp = mystate.add_access('tmp') + b_write = mystate.add_access('B') + m_entry, m_exit = 
mystate.add_map('my_map', {'i': 'N'}) + t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') + t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') + mystate.add_memlet_path(a_read, m_entry, t1, memlet=Memlet('A[i]'), dst_conn='i1') + mystate.add_memlet_path(t1, tmp, memlet=Memlet('tmp[i]'), src_conn='o1') + mystate.add_memlet_path(tmp, t2, memlet=Memlet('tmp[i]'), dst_conn='i1') + mystate.add_memlet_path(t2, m_exit, b_write, memlet=Memlet('B[i]'), src_conn='o1') - @dace.program - def myprog(A: dace.float64[N], B: dace.float64): - for i in range(N): - with dace.tasklet: - in1 << B[i] - out1 >> A[i] - out1 = in1 + 1 - for i in range(N): - with dace.tasklet: - in1 << A[i] - out1 >> B[i] - out1 = in1 * 2 + res = {} + Pipeline([StateDataDependence()]).apply_pass(sdfg, res) + state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] - myprog.use_experimental_cfg_blocks = True + assert len(state_data_deps[0]) == 1 + read_memlet: Memlet = list(state_data_deps[0])[0] + assert read_memlet.data == 'A' - sdfg = myprog.to_sdfg() + assert len(state_data_deps[1]) == 2 + out_containers = [m.data for m in state_data_deps[1]] + assert 'B' in out_containers + assert 'tmp' in out_containers + assert 'A' not in out_containers + + +def test_state_data_dependence_with_non_contained_read_in_map(): + sdfg = SDFG('myprog') + N = dace.symbol('N') + sdfg.add_array('A', (N, ), dace.float32) + sdfg.add_array('tmp', (N, ), dace.float32) + sdfg.add_array('B', (N, ), dace.float32) + mystate = sdfg.add_state('mystate', is_start_block=True) + a_read = mystate.add_access('A') + tmp = mystate.add_access('tmp') + b_write = mystate.add_access('B') + m_entry, m_exit = mystate.add_map('my_map', {'i': '0:ceil(N/2)'}) + t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') + t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') + mystate.add_memlet_path(a_read, m_entry, t1, memlet=Memlet('A[i]'), dst_conn='i1') + mystate.add_memlet_path(t1, tmp, memlet=Memlet('tmp[i]'), src_conn='o1') + mystate.add_memlet_path(tmp, t2, memlet=Memlet('tmp[i+ceil(N/2)]'), dst_conn='i1') + mystate.add_memlet_path(t2, m_exit, b_write, memlet=Memlet('B[i]'), src_conn='o1') res = {} - pipeline = Pipeline([CFGDataDependence()]) - pipeline.__experimental_cfg_block_compatible__ = True - pipeline.apply_pass(sdfg, res) + Pipeline([StateDataDependence()]).apply_pass(sdfg, res) + state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] + + assert len(state_data_deps[0]) == 2 + in_containers = [m.data for m in state_data_deps[0]] + assert 'A' in in_containers + assert 'tmp' in in_containers + assert 'B' not in in_containers - print(sdfg) - ''' + assert len(state_data_deps[1]) == 2 + out_containers = [m.data for m in state_data_deps[1]] + assert 'B' in out_containers + assert 'tmp' in out_containers + assert 'A' not in out_containers if __name__ == '__main__': - test_simple_state_data_dependence_with_self_contained_read() - #test_nested_cf_region_data_dependence() + test_state_data_dependence_with_contained_read() + test_state_data_dependence_with_contained_read_in_map() + test_state_data_dependence_with_non_contained_read_in_map() From 05b1c28847af2a2f222ed36342fd8da0cbaefb32 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 18 Sep 2024 18:19:28 +0200 Subject: [PATCH 08/14] Add loop lifting capabilities --- dace/codegen/control_flow.py | 13 +- dace/sdfg/state.py | 21 ++- .../interstate/loop_detection.py | 53 ++++-- .../transformation/interstate/loop_lifting.py | 112 ++++++++++++ 
.../simplification/control_flow_raising.py | 22 +++ .../interstate/loop_lifting_test.py | 164 ++++++++++++++++++ 6 files changed, 358 insertions(+), 27 deletions(-) create mode 100644 dace/transformation/interstate/loop_lifting.py create mode 100644 dace/transformation/passes/simplification/control_flow_raising.py create mode 100644 tests/transformations/interstate/loop_lifting_test.py diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index ae9351fc43..d170d04e77 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -270,10 +270,17 @@ def as_cpp(self, codegen, symbols) -> str: for i, elem in enumerate(self.elements): expr += elem.as_cpp(codegen, symbols) # In a general block, emit transitions and assignments after each individual block or region. - if isinstance(elem, BasicCFBlock) or (isinstance(elem, GeneralBlock) and elem.region): - cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph + if (isinstance(elem, BasicCFBlock) or (isinstance(elem, GeneralBlock) and elem.region) or + (isinstance(elem, GeneralLoopScope) and elem.loop)): + if isinstance(elem, BasicCFBlock): + g_elem = elem.state + elif isinstance(elem, GeneralBlock): + g_elem = elem.region + else: + g_elem = elem.loop + cfg = g_elem.parent_graph sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg - out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region) + out_edges = cfg.out_edges(g_elem) for j, e in enumerate(out_edges): if e not in self.gotos_to_ignore: # Skip gotos to immediate successors diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index e8a8161747..7fcdc34e3e 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2965,26 +2965,35 @@ class LoopRegion(ControlFlowRegion): def __init__(self, label: str, - condition_expr: Optional[str] = None, + condition_expr: Optional[Union[str, CodeBlock]] = None, loop_var: Optional[str] = None, - initialize_expr: Optional[str] = None, - update_expr: Optional[str] = None, + initialize_expr: Optional[Union[str, CodeBlock]] = None, + update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, sdfg: Optional['SDFG'] = None): super(LoopRegion, self).__init__(label, sdfg) if initialize_expr is not None: - self.init_statement = CodeBlock(initialize_expr) + if isinstance(initialize_expr, CodeBlock): + self.init_statement = initialize_expr + else: + self.init_statement = CodeBlock(initialize_expr) else: self.init_statement = None if condition_expr: - self.loop_condition = CodeBlock(condition_expr) + if isinstance(condition_expr, CodeBlock): + self.loop_condition = condition_expr + else: + self.loop_condition = CodeBlock(condition_expr) else: self.loop_condition = CodeBlock('True') if update_expr is not None: - self.update_statement = CodeBlock(update_expr) + if isinstance(update_expr, CodeBlock): + self.update_statement = update_expr + else: + self.update_statement = CodeBlock(update_expr) else: self.update_statement = None diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 93c2f6ea1c..de3ed9c04b 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
""" Loop detection transformation """ import sympy as sp @@ -77,19 +77,20 @@ def can_be_applied(self, sdfg: sd.SDFG, permissive: bool = False) -> bool: if expr_index == 0: - return self.detect_loop(graph, False) is not None + return self.detect_loop(graph, False, permissive) is not None elif expr_index == 1: - return self.detect_loop(graph, True) is not None + return self.detect_loop(graph, True, permissive) is not None elif expr_index == 2: - return self.detect_rotated_loop(graph, False) is not None + return self.detect_rotated_loop(graph, False, permissive) is not None elif expr_index == 3: - return self.detect_rotated_loop(graph, True) is not None + return self.detect_rotated_loop(graph, True, permissive) is not None elif expr_index == 4: - return self.detect_self_loop(graph) is not None + return self.detect_self_loop(graph, permissive) is not None raise ValueError(f'Invalid expression index {expr_index}') - def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -159,13 +160,19 @@ def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Option # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) - def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -234,13 +241,18 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) - def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: + def detect_self_loop(self, graph: ControlFlowRegion, accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -288,9 +300,14 @@ def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py new file mode 100644 index 0000000000..52e6e6e540 --- /dev/null +++ 
b/dace/transformation/interstate/loop_lifting.py
@@ -0,0 +1,112 @@
+# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
+
+from dace import properties
+from dace.sdfg.sdfg import SDFG, InterstateEdge
+from dace.sdfg.state import ControlFlowRegion, LoopRegion
+from dace.transformation import transformation
+from dace.transformation.interstate.loop_detection import DetectLoop
+
+
+@properties.make_properties
+@transformation.experimental_cfg_block_compatible
+class LoopLifting(DetectLoop, transformation.MultiStateTransformation):
+
+    def can_be_applied(self, graph: transformation.ControlFlowRegion, expr_index: int, sdfg: transformation.SDFG,
+                       permissive: bool = False) -> bool:
+        # Check loop detection with permissive = True, which allows loops where no iteration variable could be detected.
+        # We want this to detect while loops.
+        if not super().can_be_applied(graph, expr_index, sdfg, permissive=True):
+            return False
+
+        # Check that there's a condition edge; that's the only requirement for lifting it into a loop region.
+        cond_edge = self.loop_condition_edge()
+        if not cond_edge or cond_edge.data.condition is None:
+            return False
+        return True
+
+    def apply(self, graph: ControlFlowRegion, sdfg: SDFG):
+        first_state = self.loop_guard if self.expr_index <= 1 else self.loop_begin
+        after = self.exit_state
+
+        loop_info = self.loop_information()
+
+        body = self.loop_body()
+        meta = self.loop_meta_states()
+        full_body = set(body)
+        full_body.update(meta)
+        cond_edge = self.loop_condition_edge()
+        incr_edge = self.loop_increment_edge()
+        inverted = cond_edge is incr_edge
+        init_edge = self.loop_init_edge()
+        exit_edge = self.loop_exit_edge()
+
+        label = 'loop_' + first_state.label
+        if loop_info is None:
+            itvar = None
+            init_expr = None
+            incr_expr = None
+        else:
+            incr_expr = f'{loop_info[0]} = {incr_edge.data.assignments[loop_info[0]]}'
+            init_expr = f'{loop_info[0]} = {init_edge.data.assignments[loop_info[0]]}'
+            itvar = loop_info[0]
+
+        left_over_assignments = {}
+        for k in init_edge.data.assignments.keys():
+            if k != itvar:
+                left_over_assignments[k] = init_edge.data.assignments[k]
+        left_over_incr_assignments = {}
+        for k in incr_edge.data.assignments.keys():
+            if k != itvar:
+                left_over_incr_assignments[k] = incr_edge.data.assignments[k]
+
+        # TODO(later): In the case of inverted loops with non-loop-variable assignments AND the loop latch condition on
+        # the backedge, do not perform lifting for now. Note that the functionality in the lifting is there (see below,
+        # where left over increment assignments are used), but a bug in our control-flow-detection in codegen currently
+        # leads to wrong code being generated for this niche case. Remove the following check if the bug is fixed, and
+        # then these loops will also be lifted correctly.
+        if left_over_incr_assignments != {} and inverted:
+            return
+
+        loop = LoopRegion(label, condition_expr=cond_edge.data.condition, loop_var=itvar, initialize_expr=init_expr,
+                          update_expr=incr_expr, inverted=inverted, sdfg=sdfg)
+
+        graph.add_node(loop)
+        graph.add_edge(init_edge.src, loop,
+                       InterstateEdge(condition=init_edge.data.condition, assignments=left_over_assignments))
+        graph.add_edge(loop, after, InterstateEdge(assignments=exit_edge.data.assignments))
+
+        loop.add_node(first_state, is_start_block=True)
+        for n in full_body:
+            if n is not first_state:
+                loop.add_node(n)
+        added = set()
+        for e in graph.all_edges(*full_body):
+            if e.src in full_body and e.dst in full_body:
+                if not e in added:
+                    added.add(e)
+                    if e is incr_edge:
+                        if left_over_incr_assignments != {}:
+                            # If there are left over increments in an inverted loop, only execute them if the condition
+                            # still holds. This is due to SDFG semantics, where interstate assignments are only executed
+                            # if the condition on the edge holds (i.e., the edge is taken). This must be reflected in
+                            # the raised loop. This is a very niche case - specifically, a do-while, where there is a
+                            # non-loop-variable assignment AND the loop latch condition on the back-edge.
+                            left_over_increment_cond = None
+                            if inverted:
+                                left_over_increment_cond = cond_edge.data.condition
+
+                            loop.add_edge(e.src, loop.add_state(label + '_tail'),
+                                          InterstateEdge(assignments=left_over_incr_assignments,
+                                                         condition=left_over_increment_cond))
+                    elif e is cond_edge:
+                        e.data.condition = properties.CodeBlock('1')
+                        loop.add_edge(e.src, e.dst, e.data)
+                    else:
+                        loop.add_edge(e.src, e.dst, e.data)
+
+        # Remove old loop.
+        for n in full_body:
+            graph.remove_node(n)
+
+        sdfg.recheck_using_experimental_blocks()
+        sdfg.reset_cfg_list()
diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py
new file mode 100644
index 0000000000..5cad716176
--- /dev/null
+++ b/dace/transformation/passes/simplification/control_flow_raising.py
@@ -0,0 +1,22 @@
+# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
+
+from dace import properties
+from dace.transformation import pass_pipeline as ppl, transformation
+from dace.transformation.interstate.loop_lifting import LoopLifting
+
+
+@properties.make_properties
+@transformation.experimental_cfg_block_compatible
+class ControlFlowRaising(ppl.Pass):
+
+    CATEGORY: str = 'Simplification'
+
+    def modifies(self) -> ppl.Modifies:
+        return ppl.Modifies.CFG
+
+    def should_reapply(self, modified: ppl.Modifies) -> bool:
+        return modified & ppl.Modifies.CFG
+
+    def apply_pass(self, top_sdfg: ppl.SDFG, _) -> ppl.Any | None:
+        for sdfg in top_sdfg.all_sdfgs_recursive():
+            sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False)
diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py
new file mode 100644
index 0000000000..e7a026a812
--- /dev/null
+++ b/tests/transformations/interstate/loop_lifting_test.py
@@ -0,0 +1,164 @@
+# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
+""" Tests loop raising transformations.
""" + +import numpy as np +import dace +from dace.memlet import Memlet +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import LoopRegion +from dace.transformation.interstate.loop_lifting import LoopLifting + + +def test_lift_regular_for_loop(): + sdfg = SDFG('regular_for') + N = dace.symbol('N') + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + start_state = sdfg.add_state('start', is_start_block=True) + init_state = sdfg.add_state('start', is_start_block=True) + guard_state = sdfg.add_state('guard') + main_state = sdfg.add_state('loop_state') + loop_exit = sdfg.add_state('exit') + final_state = sdfg.add_state('final') + sdfg.add_edge(start_state, init_state, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(init_state, guard_state, InterstateEdge(assignments={'i': 0, 'k': 0})) + sdfg.add_edge(guard_state, main_state, InterstateEdge(condition='i < N')) + sdfg.add_edge(main_state, guard_state, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(guard_state, loop_exit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(loop_exit, final_state, InterstateEdge()) + a_access = main_state.add_access('A') + w_tasklet = main_state.add_tasklet('t1', {}, {'out'}, 'out = 1') + main_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + a_access_2 = loop_exit.add_access('A') + w_tasklet_2 = loop_exit.add_tasklet('t1', {}, {'out'}, 'out = k') + loop_exit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = final_state.add_access('A') + w_tasklet_3 = final_state.add_tasklet('t1', {}, {'out'}, 'out = j') + final_state.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +def test_lift_loop_llvm_canonical(): + sdfg = dace.SDFG('llvm_canonical') + N = dace.symbol('N') + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state('guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + loopexit = sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, guard, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) + sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) + sdfg.add_edge(preheader, body, InterstateEdge(assignments={'i': 0, 'k': 0})) + sdfg.add_edge(body, latch, InterstateEdge()) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N', assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(loopexit, exitstate, InterstateEdge()) + + a_access = body.add_access('A') + w_tasklet = body.add_tasklet('t1', {}, {'out'}, 'out = 1') + body.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + a_access_2 = loopexit.add_access('A') + w_tasklet_2 = loopexit.add_tasklet('t1', {}, {'out'}, 'out = k') + loopexit.add_edge(w_tasklet_2, 
'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = exitstate.add_access('A') + w_tasklet_3 = exitstate.add_tasklet('t1', {}, {'out'}, 'out = j') + exitstate.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +def test_lift_loop_llvm_canonical_while(): + sdfg = dace.SDFG('llvm_canonical_while') + N = dace.symbol('N') + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + sdfg.add_scalar('i', dace.int32, transient=True) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state('guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + loopexit = sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, guard, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) + sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) + sdfg.add_edge(preheader, body, InterstateEdge(assignments={'k': 0})) + sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(loopexit, exitstate, InterstateEdge()) + + i_init_write = entry.add_access('i') + iw_init_tasklet = entry.add_tasklet('ti', {}, {'out'}, 'out = 0') + entry.add_edge(iw_init_tasklet, 'out', i_init_write, None, Memlet('i[0]')) + a_access = body.add_access('A') + w_tasklet = body.add_tasklet('t1', {}, {'out'}, 'out = 1') + body.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + i_read = body.add_access('i') + i_write = body.add_access('i') + iw_tasklet = body.add_tasklet('t2', {'in1'}, {'out'}, 'out = in1 + 2') + body.add_edge(i_read, None, iw_tasklet, 'in1', Memlet('i[0]')) + body.add_edge(iw_tasklet, 'out', i_write, None, Memlet('i[0]')) + a_access_2 = loopexit.add_access('A') + w_tasklet_2 = loopexit.add_tasklet('t1', {}, {'out'}, 'out = k') + loopexit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = exitstate.add_access('A') + w_tasklet_3 = exitstate.add_tasklet('t1', {}, {'out'}, 'out = j') + exitstate.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +if __name__ == '__main__': + test_lift_regular_for_loop() + test_lift_loop_llvm_canonical() + test_lift_loop_llvm_canonical_while() From f08d95e858ca093eef4e34ce257082691b08b587 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 19 Sep 2024 12:44:22 +0200 Subject: [PATCH 09/14] Adjust loop detection to LLVM canonical semantics --- .../analysis/writeset_underapproximation.py | 21 +++--- dace/sdfg/propagation.py | 15 ++-- .../interstate/loop_detection.py | 71 ++++++++++++------- .../transformation/interstate/loop_lifting.py | 22 +----- 
.../interstate/loop_lifting_test.py | 4 +- tests/transformations/loop_detection_test.py | 8 +-- 6 files changed, 71 insertions(+), 70 deletions(-) diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index e1b88f9401..0d2fd989a3 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -153,19 +153,16 @@ def propagate(self, array, expressions, node_range): dexprs = [] for expr in expressions: - if isinstance(expr[i], symbolic.SymExpr): - dexprs.append(expr[i].expr) - elif isinstance(expr[i], tuple): - dexprs.append(( - expr[i][0].expr if isinstance( - expr[i][0], symbolic.SymExpr) else expr[i][0], - expr[i][1].expr if isinstance( - expr[i][1], symbolic.SymExpr) else expr[i][1], - expr[i][2].expr if isinstance( - expr[i][2], symbolic.SymExpr) else expr[i][2], - expr.tile_sizes[i])) + expr_i = expr[i] + if isinstance(expr_i, symbolic.SymExpr): + dexprs.append(expr_i.expr) + elif isinstance(expr_i, tuple): + dexprs.append((expr_i[0].expr if isinstance(expr_i[0], symbolic.SymExpr) else expr_i[0], + expr_i[1].expr if isinstance(expr_i[1], symbolic.SymExpr) else expr_i[1], + expr_i[2].expr if isinstance(expr_i[2], symbolic.SymExpr) else expr_i[2], + expr.tile_sizes[i])) else: - dexprs.append(expr[i]) + dexprs.append(expr_i) result[i] = smpattern.propagate(array, dexprs, node_range) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 6447d8f89b..1de7ce3977 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -94,15 +94,16 @@ def propagate(self, array, expressions, node_range): dexprs = [] for expr in expressions: - if isinstance(expr[i], symbolic.SymExpr): - dexprs.append(expr[i].approx) - elif isinstance(expr[i], tuple): - dexprs.append((expr[i][0].approx if isinstance(expr[i][0], symbolic.SymExpr) else expr[i][0], - expr[i][1].approx if isinstance(expr[i][1], symbolic.SymExpr) else expr[i][1], - expr[i][2].approx if isinstance(expr[i][2], symbolic.SymExpr) else expr[i][2], + expr_i = expr[i] + if isinstance(expr_i, symbolic.SymExpr): + dexprs.append(expr_i.approx) + elif isinstance(expr_i, tuple): + dexprs.append((expr_i[0].approx if isinstance(expr_i[0], symbolic.SymExpr) else expr_i[0], + expr_i[1].approx if isinstance(expr_i[1], symbolic.SymExpr) else expr_i[1], + expr_i[2].approx if isinstance(expr_i[2], symbolic.SymExpr) else expr_i[2], expr.tile_sizes[i])) else: - dexprs.append(expr[i]) + dexprs.append(expr_i) result[i] = smpattern.propagate(array, dexprs, overapprox_range) diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index de3ed9c04b..95056eb344 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
""" Loop detection transformation """ +from re import I import sympy as sp import networkx as nx from typing import AnyStr, Optional, Tuple, List, Set @@ -199,14 +200,10 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if len(latch_outedges) != 2: return None - # All incoming edges to the start of the loop must set the same variable - itvar = None - for iedge in begin_inedges: - if itvar is None: - itvar = set(iedge.data.assignments.keys()) - else: - itvar &= iedge.data.assignments.keys() - if itvar is None: + # A for-loop latch can further only have one incoming edge (the increment edge). A while-loop, i.e., a loop + # with no explicit iteration variable, may have more than that. + latch_inedges = graph.in_edges(latch) + if not accept_missing_itvar and len(latch_inedges) != 1: return None # Outgoing edges must be a negation of each other @@ -238,8 +235,22 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if backedge is None: return None - # The backedge must reassign the iteration variable - itvar &= backedge.data.assignments.keys() + # The iteration variable must be reassigned on all incoming edges to the latch block. + # If an assignment overlap of exactly one variable is found between the initialization edge and the edges + # going into the latch block, that will be the iteration variable. + itvar_edge_set: Set[gr.Edge[InterstateEdge]] = set() + itvar_edge_set.update(begin_inedges) + itvar_edge_set.update(latch_inedges) + itvar = None + for iedge in itvar_edge_set: + if iedge is backedge: + continue + if itvar is None: + itvar = set(iedge.data.assignments.keys()) + else: + itvar &= iedge.data.assignments.keys() + if itvar is None: + return None if len(itvar) != 1: if not accept_missing_itvar: # Either no consistent iteration variable found, or too many consistent iteration variables found @@ -430,7 +441,7 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: return next(e for e in graph.in_edges(guard) if e.src in body) elif self.expr_index in (2, 3): body = self.loop_body() - return next(e for e in graph.in_edges(begin) if e.src in body) + return graph.in_edges(self.loop_latch)[0] elif self.expr_index == 4: return graph.edges_between(begin, begin)[0] @@ -554,8 +565,7 @@ def find_rotated_for_loop( """ Finds rotated loop range from state machine. - :param latch: State from which the outgoing edges detect whether to exit - the loop or not. + :param latch: State from which the outgoing edges detect whether to exit the loop or not. :param entry: First state in the loop body. :param itervar: An optional field that overrides the analyzed iteration variable. :return: (iteration variable, (start, end, stride), @@ -565,11 +575,20 @@ def find_rotated_for_loop( # Extract state transition edge information entry_inedges = graph.in_edges(entry) condition_edge = graph.edges_between(latch, entry)[0] + latch_inedges = graph.in_edges(latch) - # All incoming edges to the loop entry must set the same variable + self_loop = latch is entry if itervar is None: + # The iteration variable must be reassigned on all incoming edges to the latch block. + # If an assignment overlap of exactly one variable is found between the initialization edge and the edges + # going into the latch block, that will be the iteration variable. 
+ itvar_edge_set: Set[gr.Edge[InterstateEdge]] = set() + itvar_edge_set.update(entry_inedges) + itvar_edge_set.update(latch_inedges) itervars = None - for iedge in entry_inedges: + for iedge in itvar_edge_set: + if iedge is condition_edge and not self_loop: + continue if itervars is None: itervars = set(iedge.data.assignments.keys()) else: @@ -587,18 +606,12 @@ def find_rotated_for_loop( # have one assignment. init_edges = [] init_assignment = None - step_edge = None itersym = symbolic.symbol(itervar) for iedge in entry_inedges: + if iedge is condition_edge: + continue assignment = iedge.data.assignments[itervar] - if itersym in symbolic.pystr_to_symbolic(assignment).free_symbols: - if step_edge is None: - step_edge = iedge - else: - # More than one edge with the iteration variable as a free - # symbol, which is not legal. Invalid for loop. - return None - else: + if itersym not in symbolic.pystr_to_symbolic(assignment).free_symbols: if init_assignment is None: init_assignment = assignment init_edges.append(iedge) @@ -608,10 +621,18 @@ def find_rotated_for_loop( return None else: init_edges.append(iedge) - if step_edge is None or len(init_edges) == 0 or init_assignment is None: + if len(init_edges) == 0 or init_assignment is None: # Less than two assignment variations, can't be a valid for loop. return None + if self_loop: + step_edge = condition_edge + else: + step_edge = None if len(latch_inedges) != 1 else latch_inedges[0] + if step_edge is None: + # No explicit step edge found. + return None + # Get the init expression and the stride. start = symbolic.pystr_to_symbolic(init_assignment) stride = (symbolic.pystr_to_symbolic(step_edge.data.assignments[itervar]) - itersym) diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py index 52e6e6e540..54363dd8e2 100644 --- a/dace/transformation/interstate/loop_lifting.py +++ b/dace/transformation/interstate/loop_lifting.py @@ -36,7 +36,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): full_body.update(meta) cond_edge = self.loop_condition_edge() incr_edge = self.loop_increment_edge() - inverted = cond_edge is incr_edge + inverted = self.expr_index in (2, 3) init_edge = self.loop_init_edge() exit_edge = self.loop_exit_edge() @@ -59,14 +59,6 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): if k != itvar: left_over_incr_assignments[k] = incr_edge.data.assignments[k] - # TODO(later): In the case of inverted loops with non-loop-variable assignmentes AND the loop latch condition on - # the backedge, do not perform lifting for now. Note, the functionality in the lifting is there (see below, - # where left over increment assignments are used), but a bug in our control-flow-detection in codegen currently - # leads to wrong code being generated by this niche case. Remove the following check if the bug is fixed, and - # then these loops will also be lifted correctly. - if left_over_incr_assignments != {} and inverted: - return - loop = LoopRegion(label, condition_expr=cond_edge.data.condition, loop_var=itvar, initialize_expr=init_expr, update_expr=incr_expr, inverted=inverted, sdfg=sdfg) @@ -86,18 +78,8 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): added.add(e) if e is incr_edge: if left_over_incr_assignments != {}: - # If there are left over increments in an inverted loop, only execute them if the condition - # still holds. This is due to SDFG semantics, where interstate assignments are only executed - # if the condition on the edge holds (i.e., the edge is taken). 
This must be reflected in - # the raised loop. This is a very niche case - specifically, a do-while, where there is a - # non-loop-variable assignment AND the loop latch condition on the back-edge. - left_over_increment_cond = None - if inverted: - left_over_increment_cond = cond_edge.data.condition - loop.add_edge(e.src, loop.add_state(label + '_tail'), - InterstateEdge(assignments=left_over_incr_assignments, - condition=left_over_increment_cond)) + InterstateEdge(assignments=left_over_incr_assignments)) elif e is cond_edge: e.data.condition = properties.CodeBlock('1') loop.add_edge(e.src, e.dst, e.data) diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py index e7a026a812..e3098d4e5c 100644 --- a/tests/transformations/interstate/loop_lifting_test.py +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -72,8 +72,8 @@ def test_lift_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) sdfg.add_edge(preheader, body, InterstateEdge(assignments={'i': 0, 'k': 0})) - sdfg.add_edge(body, latch, InterstateEdge()) - sdfg.add_edge(latch, body, InterstateEdge(condition='i < N', assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) sdfg.add_edge(loopexit, exitstate, InterstateEdge()) diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py index 5469f45762..891d520f41 100644 --- a/tests/transformations/loop_detection_test.py +++ b/tests/transformations/loop_detection_test.py @@ -37,8 +37,8 @@ def test_loop_rotated(): exitstate = sdfg.add_state('exitstate') sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2'))) + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 2'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) xform = CountLoops() @@ -106,8 +106,8 @@ def test_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) From 1d903463b9277b0d1eaad713d26e682b7964f1a6 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 19 Sep 2024 13:17:20 +0200 Subject: [PATCH 10/14] Test fix --- tests/transformations/loop_to_map_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 2cab97da78..12d4898858 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -741,8 +741,8 @@ def 
test_rotated_loop_to_map(simplify): sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) From 6b5ef0ce115ce33a34c489bc42c479cb7d9f5f6a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 19 Sep 2024 17:40:25 +0200 Subject: [PATCH 11/14] Remove unnecessary imports --- .../analysis/writeset_underapproximation.py | 24 +++++++++---------- dace/sdfg/propagation.py | 16 ++++++------- dace/transformation/subgraph/expansion.py | 9 ++----- dace/transformation/subgraph/helpers.py | 17 ++++--------- .../writeset_underapproximation_test.py | 1 + 5 files changed, 26 insertions(+), 41 deletions(-) diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index 0d2fd989a3..557ee8a73b 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -82,27 +82,26 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): for dim in range(data_dims): dexprs = [] for expr in expressions: - if isinstance(expr[dim], symbolic.SymExpr): - dexprs.append(expr[dim].expr) - elif isinstance(expr[dim], tuple): - dexprs.append( - (expr[dim][0].expr if isinstance(expr[dim][0], symbolic.SymExpr) else - expr[dim][0], expr[dim][1].expr if isinstance( - expr[dim][1], symbolic.SymExpr) else expr[dim][1], expr[dim][2].expr - if isinstance(expr[dim][2], symbolic.SymExpr) else expr[dim][2])) + expr_dim = expr[dim] + if isinstance(expr_dim, symbolic.SymExpr): + dexprs.append(expr_dim.expr) + elif isinstance(expr_dim, tuple): + dexprs.append((expr_dim[0].expr if isinstance(expr_dim[0], symbolic.SymExpr) else expr_dim[0], + expr_dim[1].expr if isinstance(expr_dim[1], symbolic.SymExpr) else expr_dim[1], + expr_dim[2].expr if isinstance(expr_dim[2], symbolic.SymExpr) else expr_dim[2])) else: - dexprs.append(expr[dim]) + dexprs.append(expr_dim) for pattern_class in SeparableUnderapproximationMemletPattern.extensions().keys(): smpattern = pattern_class() - if smpattern.can_be_applied(dexprs, variable_context, node_range, orig_edges, dim, - data_dims): + if smpattern.can_be_applied(dexprs, variable_context, node_range, orig_edges, dim, data_dims): self.patterns_per_dim[dim] = smpattern break return None not in self.patterns_per_dim def _iteration_variables_appear_multiple_times(self, data_dims, expressions, other_params, params): + # TODO: This name implies exactly the inverse of the returned value.. 
for expr in expressions: for param in params: occured_before = False @@ -139,8 +138,7 @@ def _iteration_variables_appear_multiple_times(self, data_dims, expressions, oth def _make_range(self, node_range): return subsets.Range([(rb.expr if isinstance(rb, symbolic.SymExpr) else rb, - re.expr if isinstance( - re, symbolic.SymExpr) else re, + re.expr if isinstance(re, symbolic.SymExpr) else re, rs.expr if isinstance(rs, symbolic.SymExpr) else rs) for rb, re, rs in node_range]) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 1de7ce3977..f62bb6eb58 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -62,17 +62,17 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): for rb, re, rs in node_range]) for dim in range(data_dims): - dexprs = [] for expr in expressions: - if isinstance(expr[dim], symbolic.SymExpr): - dexprs.append(expr[dim].approx) - elif isinstance(expr[dim], tuple): - dexprs.append((expr[dim][0].approx if isinstance(expr[dim][0], symbolic.SymExpr) else expr[dim][0], - expr[dim][1].approx if isinstance(expr[dim][1], symbolic.SymExpr) else expr[dim][1], - expr[dim][2].approx if isinstance(expr[dim][2], symbolic.SymExpr) else expr[dim][2])) + expr_dim = expr[dim] + if isinstance(expr_dim, symbolic.SymExpr): + dexprs.append(expr_dim.approx) + elif isinstance(expr_dim, tuple): + dexprs.append((expr_dim[0].approx if isinstance(expr_dim[0], symbolic.SymExpr) else expr_dim[0], + expr_dim[1].approx if isinstance(expr_dim[1], symbolic.SymExpr) else expr_dim[1], + expr_dim[2].approx if isinstance(expr_dim[2], symbolic.SymExpr) else expr_dim[2])) else: - dexprs.append(expr[dim]) + dexprs.append(expr_dim) for pattern_class in SeparableMemletPattern.extensions().keys(): smpattern = pattern_class() diff --git a/dace/transformation/subgraph/expansion.py b/dace/transformation/subgraph/expansion.py index db1e9b59ab..aa182e8c80 100644 --- a/dace/transformation/subgraph/expansion.py +++ b/dace/transformation/subgraph/expansion.py @@ -1,26 +1,21 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ This module contains classes that implement the expansion transformation. """ -from dace import dtypes, registry, symbolic, subsets +from dace import dtypes, symbolic, subsets from dace.sdfg import nodes -from dace.memlet import Memlet from dace.sdfg import replace, SDFG, dynamic_map_inputs from dace.sdfg.graph import SubgraphView from dace.transformation import transformation from dace.properties import make_properties, Property -from dace.sdfg.propagation import propagate_memlets_sdfg from dace.transformation.subgraph import helpers from collections import defaultdict from copy import deepcopy as dcpy -from typing import List, Union import itertools -import dace.libraries.standard as stdlib import warnings -import sys def offset_map(state, map_entry): diff --git a/dace/transformation/subgraph/helpers.py b/dace/transformation/subgraph/helpers.py index b2af49c879..0ea1903522 100644 --- a/dace/transformation/subgraph/helpers.py +++ b/dace/transformation/subgraph/helpers.py @@ -1,20 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
""" Subgraph Transformation Helper API """ -from dace import dtypes, registry, symbolic, subsets -from dace.sdfg import nodes, utils -from dace.memlet import Memlet -from dace.sdfg import replace, SDFG, SDFGState -from dace.properties import make_properties, Property -from dace.sdfg.propagation import propagate_memlets_sdfg +from dace import subsets +from dace.sdfg import nodes from dace.sdfg.graph import SubgraphView -from collections import defaultdict import copy -from typing import List, Union, Dict, Tuple, Set - -import dace.libraries.standard as stdlib - -import itertools +from typing import List, Dict, Set # **************** # Helper functions diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index d0c0e03209..d27683b801 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -545,6 +545,7 @@ def test_nested_sdfg_in_map_branches(): Nested SDFG that overwrites second dimension of array conditionally. --> should approximate write-set of map as empty """ + # No, should be approximated precisely - at least certainly with CF regions..? @dace.program def nested_loop(A: dace.float64[M, N]): From 23af03863c9f0a4196b6978866cf220f5491f4ec Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 2 Oct 2024 12:26:12 +0200 Subject: [PATCH 12/14] Improved loop detection --- .../interstate/loop_detection.py | 242 +++++++++++++----- .../simplification/control_flow_raising.py | 2 +- tests/transformations/loop_detection_test.py | 51 ++-- 3 files changed, 206 insertions(+), 89 deletions(-) diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 95056eb344..daf13599fe 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,10 +1,9 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
""" Loop detection transformation """ -from re import I import sympy as sp import networkx as nx -from typing import AnyStr, Optional, Tuple, List, Set +from typing import AnyStr, Iterable, Optional, Tuple, List, Set from dace import sdfg as sd, symbolic from dace.sdfg import graph as gr, utils as sdutil, InterstateEdge @@ -30,6 +29,9 @@ class DetectLoop(transformation.PatternTransformation): # Available for rotated and self loops entry_state = transformation.PatternNode(sd.SDFGState) + # Available for explicit-latch rotated loops + loop_break = transformation.PatternNode(sd.SDFGState) + @classmethod def expressions(cls): # Case 1: Loop with one state @@ -70,7 +72,32 @@ def expressions(cls): ssdfg.add_edge(cls.loop_begin, cls.loop_begin, sd.InterstateEdge()) ssdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) - return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg] + # Case 6: Rotated multi-state loop with explicit exiting and latch states + mlrmsdfg = gr.OrderedDiGraph() + mlrmsdfg.add_nodes_from([cls.entry_state, cls.loop_break, cls.loop_latch, cls.loop_begin, cls.exit_state]) + mlrmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_break, cls.exit_state, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_break, cls.loop_latch, sd.InterstateEdge()) + + # Case 7: Rotated single-state loop with explicit exiting and latch states + mlrsdfg = gr.OrderedDiGraph() + mlrsdfg.add_nodes_from([cls.entry_state, cls.loop_latch, cls.loop_begin, cls.exit_state]) + mlrsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_begin, cls.loop_latch, sd.InterstateEdge()) + + # Case 8: Guarded rotated multi-state loop with explicit exiting and latch states (modification of case 6) + gmlrmsdfg = gr.OrderedDiGraph() + gmlrmsdfg.add_nodes_from([cls.entry_state, cls.loop_break, cls.loop_latch, cls.loop_begin, cls.exit_state]) + gmlrmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_begin, cls.loop_break, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_break, cls.exit_state, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_break, cls.loop_latch, sd.InterstateEdge()) + + return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg, mlrmsdfg, mlrsdfg, gmlrmsdfg] def can_be_applied(self, graph: ControlFlowRegion, @@ -78,15 +105,21 @@ def can_be_applied(self, sdfg: sd.SDFG, permissive: bool = False) -> bool: if expr_index == 0: - return self.detect_loop(graph, False, permissive) is not None + return self.detect_loop(graph, multistate_loop=False, accept_missing_itvar=permissive) is not None elif expr_index == 1: - return self.detect_loop(graph, True, permissive) is not None + return self.detect_loop(graph, multistate_loop=True, accept_missing_itvar=permissive) is not None elif expr_index == 2: - return self.detect_rotated_loop(graph, False, permissive) is not None + return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive) is not None elif expr_index == 3: - return self.detect_rotated_loop(graph, True, permissive) is not None + return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive) is not None elif expr_index == 4: - return 
self.detect_self_loop(graph, permissive) is not None + return self.detect_self_loop(graph, accept_missing_itvar=permissive) is not None + elif expr_index in (5, 7): + return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive, + separate_latch=True) is not None + elif expr_index == 6: + return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive, + separate_latch=True) is not None raise ValueError(f'Invalid expression index {expr_index}') @@ -173,7 +206,7 @@ def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, return next(iter(itvar)) def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, - accept_missing_itvar: bool = False) -> Optional[str]: + accept_missing_itvar: bool = False, separate_latch: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -189,6 +222,9 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, :return: The loop variable or ``None`` if not detected. """ latch = self.loop_latch + ltest = self.loop_latch + if separate_latch: + ltest = self.loop_break if multistate_loop else self.loop_begin begin = self.loop_begin # A for-loop start has at least two incoming edges (init and increment) @@ -196,7 +232,7 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if len(begin_inedges) < 2: return None # A for-loop latch only has two outgoing edges (loop condition and exit-loop) - latch_outedges = graph.out_edges(latch) + latch_outedges = graph.out_edges(ltest) if len(latch_outedges) != 2: return None @@ -212,8 +248,13 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, # All nodes inside loop must be dominated by loop start dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) - loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != latch)) - loop_nodes += [latch] + if begin is ltest: + loop_nodes = [begin] + else: + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes.append(latch) + if ltest is not latch and ltest is not begin: + loop_nodes.append(ltest) backedge = None for node in loop_nodes: for e in graph.out_edges(node): @@ -235,33 +276,7 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if backedge is None: return None - # The iteration variable must be reassigned on all incoming edges to the latch block. - # If an assignment overlap of exactly one variable is found between the initialization edge and the edges - # going into the latch block, that will be the iteration variable. 
- itvar_edge_set: Set[gr.Edge[InterstateEdge]] = set() - itvar_edge_set.update(begin_inedges) - itvar_edge_set.update(latch_inedges) - itvar = None - for iedge in itvar_edge_set: - if iedge is backedge: - continue - if itvar is None: - itvar = set(iedge.data.assignments.keys()) - else: - itvar &= iedge.data.assignments.keys() - if itvar is None: - return None - if len(itvar) != 1: - if not accept_missing_itvar: - # Either no consistent iteration variable found, or too many consistent iteration variables found - return None - else: - if len(itvar) == 0: - return '' - else: - return None - - return next(iter(itvar)) + return rotated_loop_find_itvar(begin_inedges, latch_inedges, backedge, ltest, accept_missing_itvar)[0] def detect_self_loop(self, graph: ControlFlowRegion, accept_missing_itvar: bool = False) -> Optional[str]: """ @@ -338,9 +353,10 @@ def loop_information( if self.expr_index <= 1: guard = self.loop_guard return find_for_loop(guard.parent_graph, guard, entry, itervar) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch - return find_rotated_for_loop(latch.parent_graph, latch, entry, itervar) + return find_rotated_for_loop(latch.parent_graph, latch, entry, itervar, + separate_latch=(self.expr_index in (5, 6, 7))) elif self.expr_index == 4: return find_rotated_for_loop(entry.parent_graph, entry, entry, itervar) @@ -362,6 +378,14 @@ def loop_body(self) -> List[ControlFlowBlock]: return loop_nodes elif self.expr_index == 4: return [begin] + elif self.expr_index in (5, 7): + ltest = self.loop_break + latch = self.loop_latch + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes += [ltest, latch] + return loop_nodes + elif self.expr_index == 6: + return [begin, self.loop_latch] return [] @@ -371,8 +395,10 @@ def loop_meta_states(self) -> List[ControlFlowBlock]: """ if self.expr_index in (0, 1): return [self.loop_guard] - if self.expr_index in (2, 3): + if self.expr_index in (2, 3, 6): return [self.loop_latch] + if self.expr_index in (5, 7): + return [self.loop_break, self.loop_latch] return [] def loop_init_edge(self) -> gr.Edge[InterstateEdge]: @@ -385,7 +411,7 @@ def loop_init_edge(self) -> gr.Edge[InterstateEdge]: guard = self.loop_guard body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src not in body) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch return next(e for e in graph.in_edges(begin) if e.src is not latch) elif self.expr_index == 4: @@ -405,9 +431,12 @@ def loop_exit_edge(self) -> gr.Edge[InterstateEdge]: elif self.expr_index in (2, 3): latch = self.loop_latch return graph.edges_between(latch, exitstate)[0] - elif self.expr_index == 4: + elif self.expr_index in (4, 6): begin = self.loop_begin return graph.edges_between(begin, exitstate)[0] + elif self.expr_index in (5, 7): + ltest = self.loop_break + return graph.edges_between(ltest, exitstate)[0] raise ValueError(f'Invalid expression index {self.expr_index}') @@ -426,6 +455,10 @@ def loop_condition_edge(self) -> gr.Edge[InterstateEdge]: elif self.expr_index == 4: begin = self.loop_begin return graph.edges_between(begin, begin)[0] + elif self.expr_index in (5, 6, 7): + latch = self.loop_latch + ltest = self.loop_break if self.expr_index in (5, 7) else self.loop_begin + return graph.edges_between(ltest, latch)[0] raise ValueError(f'Invalid expression index {self.expr_index}') @@ -439,7 +472,7 @@ def loop_increment_edge(self) -> 
gr.Edge[InterstateEdge]: guard = self.loop_guard body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src in body) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): body = self.loop_body() return graph.in_edges(self.loop_latch)[0] elif self.expr_index == 4: @@ -448,6 +481,84 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: raise ValueError(f'Invalid expression index {self.expr_index}') +def rotated_loop_find_itvar(begin_inedges: List[gr.Edge[InterstateEdge]], + latch_inedges: List[gr.Edge[InterstateEdge]], + backedge: gr.Edge[InterstateEdge], latch: ControlFlowBlock, + accept_missing_itvar: bool = False) -> Tuple[Optional[str], + Optional[gr.Edge[InterstateEdge]]]: + # The iteration variable must be assigned (initialized) on all edges leading into the beginning block, which + # are not the backedge. Gather all variabes for which that holds - they are all candidates for the iteration + # variable (Phase 1). Said iteration variable must then be incremented: + # EITHER: On the backedge, in which case the increment is only executed if the loop does not exit. This + # corresponds to a while(true) loop that checks the condition at the end of the loop body and breaks + # if it does not hold before incrementing. (Scenario 1) + # OR: On the edge(s) leading into the latch, in which case the increment is executed BEFORE the condition is + # checked - which corresponds to a do-while loop. (Scenario 2) + # For either case, the iteration variable may only be incremented on one of these places. Filter the candidates + # down to each variable for which this condition holds (Phase 2). If there is exactly one candidate remaining, + # that is the iteration variable. Otherwise it cannot be determined. + + # Phase 1: Gather iteration variable candidates. + itvar_candidates = None + for e in begin_inedges: + if e is backedge: + continue + if itvar_candidates is None: + itvar_candidates = set(e.data.assignments.keys()) + else: + itvar_candidates &= set(e.data.assignments.keys()) + + # Phase 2: Filter down the candidates according to incrementation edges. + step_edge = None + filtered_candidates = set() + backedge_incremented = set(backedge.data.assignments.keys()) + latch_incremented = None + if backedge.src is not backedge.dst: + # If this is a self loop, there are no edges going into the latch to be considered. The only incoming edges are + # from outside the loop. + for e in latch_inedges: + if e is backedge: + continue + if latch_incremented is None: + latch_incremented = set(e.data.assignments.keys()) + else: + latch_incremented &= set(e.data.assignments.keys()) + if latch_incremented is None: + latch_incremented = set() + for cand in itvar_candidates: + if cand in backedge_incremented: + # Scenario 1. + + # TODO: Not sure if the condition below is a necessary prerequisite. + # Note, only allow this scenario if the backedge leads directly from the latch to the entry, i.e., there is + # no intermediate block on the backedge path. + if backedge.src is not latch: + continue + + if cand not in latch_incremented: + filtered_candidates.add(cand) + elif cand in latch_incremented: + # Scenario 2. 
+ if cand not in backedge_incremented: + filtered_candidates.add(cand) + if len(filtered_candidates) != 1: + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None, None + else: + if len(filtered_candidates) == 0: + return '', None + else: + return None, None + else: + itvar = next(iter(filtered_candidates)) + if itvar in backedge_incremented: + step_edge = backedge + elif len(latch_inedges) == 1: + step_edge = latch_inedges[0] + return itvar, step_edge + + def find_for_loop( graph: ControlFlowRegion, guard: sd.SDFGState, @@ -548,6 +659,10 @@ def find_for_loop( match = condition.match(itersym >= a) if match: end = match[a] + if end is None: + match = condition.match(sp.Ne(itersym + stride, a)) + if match: + end = match[a] - stride if end is None: # No match found return None @@ -559,7 +674,8 @@ def find_rotated_for_loop( graph: ControlFlowRegion, latch: sd.SDFGState, entry: sd.SDFGState, - itervar: Optional[str] = None + itervar: Optional[str] = None, + separate_latch: bool = False, ) -> Optional[Tuple[AnyStr, Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType], Tuple[ List[sd.SDFGState], sd.SDFGState]]]: """ @@ -574,29 +690,19 @@ def find_rotated_for_loop( """ # Extract state transition edge information entry_inedges = graph.in_edges(entry) - condition_edge = graph.edges_between(latch, entry)[0] + if separate_latch: + condition_edge = graph.in_edges(latch)[0] + backedge = graph.edges_between(latch, entry)[0] + else: + condition_edge = graph.edges_between(latch, entry)[0] + backedge = condition_edge latch_inedges = graph.in_edges(latch) self_loop = latch is entry + step_edge = None if itervar is None: - # The iteration variable must be reassigned on all incoming edges to the latch block. - # If an assignment overlap of exactly one variable is found between the initialization edge and the edges - # going into the latch block, that will be the iteration variable. - itvar_edge_set: Set[gr.Edge[InterstateEdge]] = set() - itvar_edge_set.update(entry_inedges) - itvar_edge_set.update(latch_inedges) - itervars = None - for iedge in itvar_edge_set: - if iedge is condition_edge and not self_loop: - continue - if itervars is None: - itervars = set(iedge.data.assignments.keys()) - else: - itervars &= iedge.data.assignments.keys() - if itervars and len(itervars) == 1: - itervar = next(iter(itervars)) - else: - # Ambiguous or no iteration variable + itervar, step_edge = rotated_loop_find_itvar(entry_inedges, latch_inedges, backedge, latch) + if itervar is None: return None condition = condition_edge.data.condition_sympy() @@ -628,9 +734,7 @@ def find_rotated_for_loop( if self_loop: step_edge = condition_edge else: - step_edge = None if len(latch_inedges) != 1 else latch_inedges[0] if step_edge is None: - # No explicit step edge found. return None # Get the init expression and the stride. 
@@ -664,6 +768,10 @@ def find_rotated_for_loop( match = condition.match(itersym >= a) if match: end = match[a] + if end is None: + match = condition.match(sp.Ne(itersym + stride, a)) + if match: + end = match[a] - stride if end is None: # No match found return None diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 5cad716176..2f92ab4e86 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -7,7 +7,7 @@ @properties.make_properties @transformation.experimental_cfg_block_compatible -class ControlFlowRaising(ppl.Pass): +class ControlFlowLifting(ppl.Pass): CATEGORY: str = 'Simplification' diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py index 891d520f41..323a27787a 100644 --- a/tests/transformations/loop_detection_test.py +++ b/tests/transformations/loop_detection_test.py @@ -27,7 +27,8 @@ def tester(a: dace.float64[20]): assert rng == (1, 19, 1) -def test_loop_rotated(): +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_loop_rotated(increment_before_condition): sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -37,8 +38,12 @@ def test_loop_rotated(): exitstate = sdfg.add_state('exitstate') sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 2'))) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + if increment_before_condition: + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 2'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + else: + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2'))) sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) xform = CountLoops() @@ -48,8 +53,9 @@ def test_loop_rotated(): assert rng == (0, dace.symbol('N') - 1, 2) -@pytest.mark.skip('Extra incrementation states should not be supported by loop detection') def test_loop_rotated_extra_increment(): + # Extra incrementation states (i.e., something more than a single edge between the latch and the body) should not + # be allowed and consequently not be detected as loops. 
sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -60,15 +66,13 @@ def test_loop_rotated_extra_increment(): exitstate = sdfg.add_state('exitstate') sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) + sdfg.add_edge(body, latch, dace.InterstateEdge()) sdfg.add_edge(latch, increment, dace.InterstateEdge('i < N')) sdfg.add_edge(increment, body, dace.InterstateEdge(assignments=dict(i='i + 1'))) sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) xform = CountLoops() - assert sdfg.apply_transformations(xform) == 1 - itvar, rng, _ = xform.loop_information() - assert itvar == 'i' - assert rng == (0, dace.symbol('N') - 1, 1) + assert sdfg.apply_transformations(xform) == 0 def test_self_loop(): @@ -91,7 +95,8 @@ def test_self_loop(): assert rng == (2, dace.symbol('N') - 1, 3) -def test_loop_llvm_canonical(): +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_loop_llvm_canonical(increment_before_condition): sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -106,8 +111,12 @@ def test_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + if increment_before_condition: + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + else: + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) @@ -118,9 +127,10 @@ def test_loop_llvm_canonical(): assert rng == (0, dace.symbol('N') - 1, 1) -@pytest.mark.skip('Extra incrementation states should not be supported by loop detection') @pytest.mark.parametrize('with_bounds_check', (False, True)) def test_loop_llvm_canonical_with_extras(with_bounds_check): + # Extra incrementation states (i.e., something more than a single edge between the latch and the body) should not + # be allowed and consequently not be detected as loops. 
sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -148,17 +158,16 @@ def test_loop_llvm_canonical_with_extras(with_bounds_check): sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) xform = CountLoops() - assert sdfg.apply_transformations(xform) == 1 - itvar, rng, _ = xform.loop_information() - assert itvar == 'i' - assert rng == (0, dace.symbol('N') - 1, 1) + assert sdfg.apply_transformations(xform) == 0 if __name__ == '__main__': test_pyloop() - test_loop_rotated() - # test_loop_rotated_extra_increment() + test_loop_rotated(True) + test_loop_rotated(False) + test_loop_rotated_extra_increment() test_self_loop() - test_loop_llvm_canonical() - # test_loop_llvm_canonical_with_extras(False) - # test_loop_llvm_canonical_with_extras(True) + test_loop_llvm_canonical(True) + test_loop_llvm_canonical(False) + test_loop_llvm_canonical_with_extras(False) + test_loop_llvm_canonical_with_extras(True) From 3fbe26bbf5bf629b21c7b8f8b0616856818958f7 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 2 Oct 2024 15:32:54 +0200 Subject: [PATCH 13/14] Loop detection and lifting fixes --- dace/codegen/control_flow.py | 29 ++++++++++--------- dace/codegen/targets/framecode.py | 13 ++++++++- dace/sdfg/state.py | 10 ++++++- .../interstate/loop_detection.py | 5 ++-- .../transformation/interstate/loop_lifting.py | 17 +++++++---- .../interstate/loop_lifting_test.py | 15 +++++++--- 6 files changed, 62 insertions(+), 27 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index cfa5c8d41d..d0cd3da8b4 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -539,26 +539,27 @@ def as_cpp(self, codegen, symbols) -> str: expr = '' if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable: - # Initialize to either "int i = 0" or "i = 0" depending on whether the type has been defined. - defined_vars = codegen.dispatcher.defined_vars - if not defined_vars.has(self.loop.loop_variable): - try: - init = f'{symbols[self.loop.loop_variable]} ' - except KeyError: - init = 'auto ' - symbols[self.loop.loop_variable] = None - init += unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) + init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) init = init.strip(';') update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols) update = update.strip(';') if self.loop.inverted: - expr += f'{init};\n' - expr += 'do {\n' - expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) - expr += f'{update};\n' - expr += f'\n}} while({cond});\n' + if self.loop.update_before_condition: + expr += f'{init};\n' + expr += 'do {\n' + expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) + expr += f'{update};\n' + expr += f'}} while({cond});\n' + else: + expr += f'{init};\n' + expr += 'while (1) {\n' + expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) + expr += f'if (!({cond}))\n' + expr += 'break;\n' + expr += f'{update};\n' + expr += '}\n' else: expr += f'for ({init}; {cond}; {update}) {{\n' expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 488c1c7fbd..2d3c524771 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -1,4 +1,5 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. 
+import ast import collections import copy import re @@ -15,11 +16,12 @@ from dace.codegen.prettycode import CodeIOStream from dace.codegen.common import codeblock_to_cpp, sym2cpp from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.tools.type_inference import infer_expr_type from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.state import ControlFlowRegion, LoopRegion from dace.transformation.passes.analysis import StateReachability @@ -916,6 +918,15 @@ def generate_code(self, interstate_symbols.update(symbols) global_symbols.update(symbols) + if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: + init_assignment = cfr.init_statement.code[0] + if isinstance(init_assignment, ast.Assign): + init_assignment = init_assignment.value + if not cfr.loop_variable in interstate_symbols: + interstate_symbols[cfr.loop_variable] = infer_expr_type(ast.unparse(init_assignment)) + if not cfr.loop_variable in global_symbols: + global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] + for isvarName, isvarType in interstate_symbols.items(): if isvarType is None: raise TypeError(f'Type inference failed for symbol {isvarName}') diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 22ac601da1..4dc93a8d9d 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2987,6 +2987,12 @@ class LoopRegion(ControlFlowRegion): inverted = Property(dtype=bool, default=False, desc='If True, the loop condition is checked after the first iteration.') + update_before_condition = Property(dtype=bool, + default=True, + desc='If False, the loop condition is checked before the update statement is' + + ' executed. 
This only applies to inverted loops, turning them from a typical ' + + 'do-while style into a while(true) with a break before the update (at the end ' + + 'of an iteration)if the condition no longer holds.') loop_variable = Property(dtype=str, default='', desc='The loop variable, if given') def __init__(self, @@ -2996,7 +3002,8 @@ def __init__(self, initialize_expr: Optional[Union[str, CodeBlock]] = None, update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, - sdfg: Optional['SDFG'] = None): + sdfg: Optional['SDFG'] = None, + update_before_condition = True): super(LoopRegion, self).__init__(label, sdfg) if initialize_expr is not None: @@ -3025,6 +3032,7 @@ def __init__(self, self.loop_variable = loop_var or '' self.inverted = inverted + self.update_before_condition = update_before_condition def inline(self) -> Tuple[bool, Any]: """ diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index daf13599fe..bd65cec290 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -473,8 +473,9 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src in body) elif self.expr_index in (2, 3, 5, 6, 7): - body = self.loop_body() - return graph.in_edges(self.loop_latch)[0] + _, step_edge = rotated_loop_find_itvar(graph.in_edges(begin), graph.in_edges(self.loop_latch), + graph.edges_between(self.loop_latch, begin)[0], self.loop_latch) + return step_edge elif self.expr_index == 4: return graph.edges_between(begin, begin)[0] diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py index 54363dd8e2..604aa74d16 100644 --- a/dace/transformation/interstate/loop_lifting.py +++ b/dace/transformation/interstate/loop_lifting.py @@ -36,7 +36,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): full_body.update(meta) cond_edge = self.loop_condition_edge() incr_edge = self.loop_increment_edge() - inverted = self.expr_index in (2, 3) + inverted = self.expr_index in (2, 3, 5, 6, 7) init_edge = self.loop_init_edge() exit_edge = self.loop_exit_edge() @@ -55,12 +55,19 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): if k != itvar: left_over_assignments[k] = init_edge.data.assignments[k] left_over_incr_assignments = {} - for k in incr_edge.data.assignments.keys(): - if k != itvar: - left_over_incr_assignments[k] = incr_edge.data.assignments[k] + if incr_edge is not None: + for k in incr_edge.data.assignments.keys(): + if k != itvar: + left_over_incr_assignments[k] = incr_edge.data.assignments[k] + + if inverted and incr_edge is cond_edge: + update_before_condition = False + else: + update_before_condition = True loop = LoopRegion(label, condition_expr=cond_edge.data.condition, loop_var=itvar, initialize_expr=init_expr, - update_expr=incr_expr, inverted=inverted, sdfg=sdfg) + update_expr=incr_expr, inverted=inverted, sdfg=sdfg, + update_before_condition=update_before_condition) graph.add_node(loop) graph.add_edge(init_edge.src, loop, diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py index e3098d4e5c..843209794f 100644 --- a/tests/transformations/interstate/loop_lifting_test.py +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -2,6 +2,7 @@ """ Tests loop raising trainsformations. 
""" import numpy as np +import pytest import dace from dace.memlet import Memlet from dace.sdfg.sdfg import SDFG, InterstateEdge @@ -52,7 +53,8 @@ def test_lift_regular_for_loop(): assert np.allclose(A_valid, A) -def test_lift_loop_llvm_canonical(): +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_lift_loop_llvm_canonical(increment_before_condition): sdfg = dace.SDFG('llvm_canonical') N = dace.symbol('N') sdfg.add_symbol('i', dace.int32) @@ -72,8 +74,12 @@ def test_lift_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) sdfg.add_edge(preheader, body, InterstateEdge(assignments={'i': 0, 'k': 0})) - sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) - sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) + if increment_before_condition: + sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) + else: + sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N', assignments={'i': 'i + 2'})) sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) sdfg.add_edge(loopexit, exitstate, InterstateEdge()) @@ -160,5 +166,6 @@ def test_lift_loop_llvm_canonical_while(): if __name__ == '__main__': test_lift_regular_for_loop() - test_lift_loop_llvm_canonical() + test_lift_loop_llvm_canonical(True) + test_lift_loop_llvm_canonical(False) test_lift_loop_llvm_canonical_while() From 2bd5d007b82cdaddc818a9cacada5820b36d1150 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 2 Oct 2024 17:38:09 +0200 Subject: [PATCH 14/14] Work on conditional lifting --- dace/codegen/control_flow.py | 7 +- dace/codegen/targets/framecode.py | 6 +- .../simplification/control_flow_raising.py | 77 ++++++++++++++- .../control_flow_raising_test.py | 98 +++++++++++++++++++ 4 files changed, 177 insertions(+), 11 deletions(-) create mode 100644 tests/passes/simplification/control_flow_raising_test.py diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index d0cd3da8b4..f5559984e7 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -274,14 +274,11 @@ def as_cpp(self, codegen, symbols) -> str: for i, elem in enumerate(self.elements): expr += elem.as_cpp(codegen, symbols) # In a general block, emit transitions and assignments after each individual block or region. - if (isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region) or - (isinstance(elem, GeneralLoopScope) and elem.loop)): + if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region): if isinstance(elem, BasicCFBlock): g_elem = elem.state - elif isinstance(elem, GeneralBlock): - g_elem = elem.region else: - g_elem = elem.loop + g_elem = elem.region cfg = g_elem.parent_graph sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg out_edges = cfg.out_edges(g_elem) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 2d3c524771..f7b8338269 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -1,5 +1,4 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. 
-import ast import collections import copy import re @@ -17,6 +16,7 @@ from dace.codegen.common import codeblock_to_cpp, sym2cpp from dace.codegen.targets.target import TargetCodeGenerator from dace.codegen.tools.type_inference import infer_expr_type +from dace.frontend.python import astutils from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils @@ -920,10 +920,10 @@ def generate_code(self, if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: init_assignment = cfr.init_statement.code[0] - if isinstance(init_assignment, ast.Assign): + if isinstance(init_assignment, astutils.ast.Assign): init_assignment = init_assignment.value if not cfr.loop_variable in interstate_symbols: - interstate_symbols[cfr.loop_variable] = infer_expr_type(ast.unparse(init_assignment)) + interstate_symbols[cfr.loop_variable] = infer_expr_type(astutils.unparse(init_assignment)) if not cfr.loop_variable in global_symbols: global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 2f92ab4e86..5cfd6ffba6 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -1,13 +1,19 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Optional, Tuple +import networkx as nx from dace import properties +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.utils import dfs_conditional from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.interstate.loop_lifting import LoopLifting @properties.make_properties @transformation.experimental_cfg_block_compatible -class ControlFlowLifting(ppl.Pass): +class ControlFlowRaising(ppl.Pass): CATEGORY: str = 'Simplification' @@ -17,6 +23,71 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG - def apply_pass(self, top_sdfg: ppl.SDFG, _) -> ppl.Any | None: + def _lift_conditionals(self, sdfg: SDFG) -> int: + cfgs = list(sdfg.all_control_flow_regions()) + n_cond_regions_pre = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) + + for region in cfgs: + dummy_exit = region.add_state('__DACE_DUMMY') + for s in region.sink_nodes(): + if s is not dummy_exit: + region.add_edge(s, dummy_exit, InterstateEdge()) + idom = nx.immediate_dominators(region.nx, region.start_block) + alldoms = cfg_analysis.all_dominators(region, idom) + branch_merges = cfg_analysis.branch_merges(region, idom, alldoms) + + for block in region.nodes(): + graph = block.parent_graph + oedges = graph.out_edges(block) + if len(oedges) > 1 and block in branch_merges: + merge_block = branch_merges[block] + + # Construct the branching block. + conditional = ConditionalBlock('conditional_' + block.label, sdfg, graph) + graph.add_node(conditional) + # Connect it. + graph.add_edge(block, conditional, InterstateEdge()) + + # Populate branches. 
+ for i, oe in enumerate(oedges): + branch_name = 'branch_' + str(i) + '_' + block.label + branch = ControlFlowRegion(branch_name, sdfg) + conditional.branches.append([oe.data.condition, branch]) + if oe.dst is merge_block: + # Empty branch. + continue + + branch_nodes = set(dfs_conditional(graph, [oe.dst], lambda _, x: x is not merge_block)) + branch_start = branch.add_state(branch_name + '_start', is_start_block=True) + branch.add_nodes_from(branch_nodes) + branch_end = branch.add_state(branch_name + '_end') + branch.add_edge(branch_start, oe.dst, InterstateEdge(assignments=oe.data.assignments)) + added = set() + for e in graph.all_edges(*branch_nodes): + if not (e in added): + added.add(e) + if e is oe: + continue + elif e.dst is merge_block: + branch.add_edge(e.src, branch_end, e.data) + else: + branch.add_edge(e.src, e.dst, e.data) + graph.remove_nodes_from(branch_nodes) + + # Connect to the end of the branch / what happens after. + if merge_block is not dummy_exit: + graph.add_edge(conditional, merge_block, InterstateEdge()) + region.remove_node(dummy_exit) + + n_cond_regions_post = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) + return n_cond_regions_post - n_cond_regions_pre + + def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]: + lifted_loops = 0 + lifted_branches = 0 for sdfg in top_sdfg.all_sdfgs_recursive(): - sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) + lifted_loops += sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) + lifted_branches += self._lift_conditionals(sdfg) + if lifted_branches == 0 and lifted_loops == 0: + return None + return lifted_loops, lifted_branches diff --git a/tests/passes/simplification/control_flow_raising_test.py b/tests/passes/simplification/control_flow_raising_test.py new file mode 100644 index 0000000000..53e01df12f --- /dev/null +++ b/tests/passes/simplification/control_flow_raising_test.py @@ -0,0 +1,98 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import dace +import numpy as np +from dace.sdfg.state import ConditionalBlock +from dace.transformation.pass_pipeline import FixedPointPipeline, Pipeline +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising + + +def test_dataflow_if_check(): + + @dace.program + def dataflow_if_check(A: dace.int32[10], i: dace.int64): + if A[i] < 10: + return 0 + elif A[i] == 10: + return 10 + return 100 + + sdfg = dataflow_if_check.to_sdfg() + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + ppl = FixedPointPipeline([ControlFlowRaising()]) + ppl.__experimental_cfg_block_compatible__ = True + ppl.apply_pass(sdfg, {}) + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + A = np.zeros((10,), np.int32) + A[4] = 10 + A[5] = 100 + assert sdfg(A, 0)[0] == 0 + assert sdfg(A, 4)[0] == 10 + assert sdfg(A, 5)[0] == 100 + assert sdfg(A, 6)[0] == 0 + + +def test_nested_if_chain(): + + @dace.program + def nested_if_chain(i: dace.int64): + if i < 2: + return 0 + else: + if i < 4: + return 1 + else: + if i < 6: + return 2 + else: + if i < 8: + return 3 + else: + return 4 + + sdfg = nested_if_chain.to_sdfg() + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert nested_if_chain(0)[0] == 0 + assert nested_if_chain(2)[0] == 1 + assert nested_if_chain(4)[0] == 2 + assert nested_if_chain(7)[0] == 3 + assert nested_if_chain(15)[0] == 4 + + +def test_elif_chain(): + + @dace.program + def elif_chain(i: dace.int64): + if i < 2: + return 0 + elif i < 4: + return 1 + elif i < 6: + return 2 + elif i < 8: + return 3 + else: + return 4 + + elif_chain.use_experimental_cfg_blocks = True + sdfg = elif_chain.to_sdfg() + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert elif_chain(0)[0] == 0 + assert elif_chain(2)[0] == 1 + assert elif_chain(4)[0] == 2 + assert elif_chain(7)[0] == 3 + assert elif_chain(15)[0] == 4 + + +if __name__ == '__main__': + test_dataflow_if_check() + test_nested_if_chain() + test_elif_chain()
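
Usage sketch (illustrative, not part of the patches above): the new ControlFlowRaising pass is driven the same way as in test_dataflow_if_check, i.e. through a FixedPointPipeline with the experimental CFG-block compatibility flag set. The toy program `clamp_class` and all names in it are invented for illustration; the structural assertion reflects what the new test checks, not a general guarantee.

# Illustrative sketch only -- mirrors the usage shown in control_flow_raising_test.py above.
import dace
from dace.sdfg.state import ConditionalBlock
from dace.transformation.pass_pipeline import FixedPointPipeline
from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising


@dace.program
def clamp_class(x: dace.int64):  # hypothetical example program
    if x < 0:
        return 0
    elif x < 10:
        return 1
    return 2


sdfg = clamp_class.to_sdfg()

# Repeatedly lift loops and raise branches until a fixed point is reached.
pipeline = FixedPointPipeline([ControlFlowRaising()])
pipeline.__experimental_cfg_block_compatible__ = True
pipeline.apply_pass(sdfg, {})

# As in test_dataflow_if_check, conditional regions should now be present...
assert any(isinstance(b, ConditionalBlock) for b in sdfg.nodes())

# ...and the raised SDFG should still compute the same results.
assert sdfg(-5)[0] == 0
assert sdfg(3)[0] == 1
assert sdfg(42)[0] == 2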