diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index c40c05aa7f..4844720fa1 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -747,6 +747,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, Determines what data is read and written in this subgraph, returning dictionaries from data containers to all subsets that are read/written. """ + from dace.sdfg import utils # Avoid cyclic import # Ensures that the `{src,dst}_subset` are properly set. # TODO: find where the problems are @@ -755,23 +756,30 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, read_set = collections.defaultdict(list) write_set = collections.defaultdict(list) - from dace.sdfg import utils # Avoid cyclic import - subgraphs = utils.concurrent_subgraphs(self) - for sg in subgraphs: - rs = collections.defaultdict(list) - ws = collections.defaultdict(list) + + for subgraph in utils.concurrent_subgraphs(self): + subgraph_read_set = collections.defaultdict(list) # read and write set of this subgraph. + subgraph_write_set = collections.defaultdict(list) # Traverse in topological order, so data that is written before it # is read is not counted in the read set # TODO: This only works if every data descriptor is only once in a path. - for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()): + for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()): if not isinstance(n, nd.AccessNode): + # Read and writes can only be done through access nodes, + # so ignore every other node. continue + + # Get a list of all incoming (writes) and outgoing (reads) edges of the + # access node, ignore all empty memlets as they do not carry any data. 
+ in_edges = [in_edge for in_edge in subgraph.in_edges(n) if not in_edge.data.is_empty()] + out_edges = [out_edge for out_edge in subgraph.out_edges(n) if not out_edge.data.is_empty()] + + # Extract the subsets that describe where we read and write the data + # and store them for the later filtering. + # NOTE: In certain cases the corresponding subset might be None, in this case + # we assume that the whole array is written, which is the default behaviour. ac_desc = n.desc(self.sdfg) ac_size = ac_desc.total_size - in_edges = [in_edge for in_edge in sg.in_edges(n) if not in_edge.data.is_empty()] - out_edges = [out_edge for out_edge in sg.out_edges(n) if not out_edge.data.is_empty()] - - # In some conditions subsets can be `None`, we will now clean them. in_subsets = dict() for in_edge in in_edges: assert in_edge.data.dst_subset is not None or (in_edge.data.num_elements() == ac_size) @@ -789,12 +797,12 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, else out_edge.data.src_subset ) - # Filter out memlets which go out but the same data is written to the AccessNode by another memlet + # Filter out reads that are also written at the access node by another (single) write. for out_edge in list(out_edges): for in_edge in in_edges: if out_edge.data.data != in_edge.data.data: - # NOTE: This check does not make any sense, and is in my view wrong. - # As it will filter out some accesses but not all, which one solely + # NOTE: This check does not make any sense, and is in my (@philip-paul-mueller) + # view wrong. As it will filter out some accesses but not all, which one solely # depends on how the memelts were created. # See also [issue #1643](https://github.com/spcl/dace/issues/1643). 
continue @@ -803,16 +811,16 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, break if in_edges: - ws[n.data].extend(in_subsets.values()) + subgraph_write_set[n.data].extend(in_subsets.values()) if out_edges: - rs[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) + subgraph_read_set[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) # Union all subgraphs, so an array that was excluded from the read # set because it was written first is still included if it is read # in another subgraph - for data, accesses in rs.items(): + for data, accesses in subgraph_read_set.items(): read_set[data] += accesses - for data, accesses in ws.items(): + for data, accesses in subgraph_write_set.items(): write_set[data] += accesses return copy.deepcopy((read_set, write_set))