diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index ae9351fc43..7701a19ec2 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -62,7 +62,7 @@ import sympy as sp from dace import dtypes from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import (BreakBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, +from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, ReturnBlock, SDFGState) from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.graph import Edge @@ -236,14 +236,18 @@ def first_block(self) -> ReturnBlock: @dataclass -class GeneralBlock(ControlFlow): - """ - General (or unrecognized) control flow block with gotos between blocks. - """ +class RegionBlock(ControlFlow): # The control flow region that this block corresponds to (may be the SDFG in the absence of hierarchical regions). region: Optional[ControlFlowRegion] + +@dataclass +class GeneralBlock(RegionBlock): + """ + General (or unrecognized) control flow block with gotos between blocks. + """ + # List of children control flow blocks elements: List[ControlFlow] @@ -270,7 +274,7 @@ def as_cpp(self, codegen, symbols) -> str: for i, elem in enumerate(self.elements): expr += elem.as_cpp(codegen, symbols) # In a general block, emit transitions and assignments after each individual block or region. 
- if isinstance(elem, BasicCFBlock) or (isinstance(elem, GeneralBlock) and elem.region): + if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region): cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region) @@ -514,10 +518,9 @@ def children(self) -> List[ControlFlow]: @dataclass -class GeneralLoopScope(ControlFlow): +class GeneralLoopScope(RegionBlock): """ General loop block based on a loop control flow region. """ - loop: LoopRegion body: ControlFlow def as_cpp(self, codegen, symbols) -> str: @@ -565,6 +568,10 @@ def as_cpp(self, codegen, symbols) -> str: return expr + @property + def loop(self) -> LoopRegion: + return self.region + @property def first_block(self) -> ControlFlowBlock: return self.loop.start_block @@ -601,6 +608,46 @@ def children(self) -> List[ControlFlow]: return list(self.cases.values()) +@dataclass +class GeneralConditionalScope(RegionBlock): + """ General conditional block based on a conditional control flow region. 
""" + + branch_bodies: List[Tuple[Optional[CodeBlock], ControlFlow]] + + def as_cpp(self, codegen, symbols) -> str: + sdfg = self.conditional.sdfg + expr = '' + for i in range(len(self.branch_bodies)): + branch = self.branch_bodies[i] + if branch[0] is not None: + cond = unparse_interstate_edge(branch[0].code, sdfg, codegen=codegen, symbols=symbols) + cond = cond.strip(';') + if i == 0: + expr += f'if ({cond}) {{\n' + else: + expr += f'}} else if ({cond}) {{\n' + else: + if i < len(self.branch_bodies) - 1 or i == 0: + raise RuntimeError('Missing branch condition for non-final conditional branch') + expr += '} else {\n' + expr += branch[1].as_cpp(codegen, symbols) + if i == len(self.branch_bodies) - 1: + expr += '}\n' + return expr + + @property + def conditional(self) -> ConditionalBlock: + return self.region + + @property + def first_block(self) -> ControlFlowBlock: + return self.conditional + + @property + def children(self) -> List[ControlFlow]: + return [b for _, b in self.branch_bodies] + + def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge], leave_edge: Edge[InterstateEdge], back_edges: List[Edge[InterstateEdge]], dispatch_state: Callable[[SDFGState], @@ -973,7 +1020,6 @@ def _structured_control_flow_traversal_with_regions(cfg: ControlFlowRegion, if branch_merges is None: branch_merges = cfg_analysis.branch_merges(cfg) - if ptree is None: ptree = cfg_analysis.block_parent_tree(cfg, with_loops=False) @@ -1004,6 +1050,14 @@ def make_empty_block(): cfg_block = ContinueCFBlock(dispatch_state, parent_block, True, node) elif isinstance(node, ReturnBlock): cfg_block = ReturnCFBlock(dispatch_state, parent_block, True, node) + elif isinstance(node, ConditionalBlock): + cfg_block = GeneralConditionalScope(dispatch_state, parent_block, False, node, []) + for cond, branch in node.branches: + if branch is not None: + body = make_empty_block() + body.parent = cfg_block + _structured_control_flow_traversal_with_regions(branch, 
dispatch_state, body) + cfg_block.branch_bodies.append((cond, body)) elif isinstance(node, ControlFlowRegion): if isinstance(node, LoopRegion): body = make_empty_block() @@ -1027,69 +1081,8 @@ def make_empty_block(): stack.append(oe[0].dst) parent_block.elements.append(cfg_block) continue - - # Potential branch or loop - if node in branch_merges: - mergeblock = branch_merges[node] - - # Add branching node and ignore outgoing edges - parent_block.elements.append(cfg_block) - parent_block.gotos_to_ignore.extend(oe) # TODO: why? - parent_block.assignments_to_ignore.extend(oe) # TODO: why? - cfg_block.last_block = True - - # Parse all outgoing edges recursively first - cblocks: Dict[Edge[InterstateEdge], GeneralBlock] = {} - for branch in oe: - if branch.dst is mergeblock: - # If we hit the merge state (if without else), defer to end of branch traversal - continue - cblocks[branch] = make_empty_block() - _structured_control_flow_traversal_with_regions(cfg=cfg, - dispatch_state=dispatch_state, - parent_block=cblocks[branch], - start=branch.dst, - stop=mergeblock, - generate_children_of=node, - branch_merges=branch_merges, - ptree=ptree, - visited=visited) - - # Classify branch type: - branch_block = None - # If there are 2 out edges, one negation of the other: - # * if/else in case both branches are not merge state - # * if without else in case one branch is merge state - if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(oe[1].data.condition_sympy())): - if oe[0].dst is mergeblock: - # If without else - branch_block = IfScope(dispatch_state, parent_block, False, node, oe[1].data.condition, - cblocks[oe[1]]) - elif oe[1].dst is mergeblock: - branch_block = IfScope(dispatch_state, parent_block, False, node, oe[0].data.condition, - cblocks[oe[0]]) - else: - branch_block = IfScope(dispatch_state, parent_block, False, node, oe[0].data.condition, - cblocks[oe[0]], cblocks[oe[1]]) - else: - # If there are 2 or more edges (one is not the negation of the - # other): - 
switch = _cases_from_branches(oe, cblocks) - if switch: - # If all edges are of form "x == y" for a single x and - # integer y, it is a switch/case - branch_block = SwitchCaseScope(dispatch_state, parent_block, False, node, switch[0], switch[1]) - else: - # Otherwise, create if/else if/.../else goto exit chain - branch_block = IfElseChain(dispatch_state, parent_block, False, node, - [(e.data.condition, cblocks[e] if e in cblocks else make_empty_block()) - for e in oe]) - # End of branch classification - parent_block.elements.append(branch_block) - if mergeblock != stop: - stack.append(mergeblock) - - else: # No merge state: Unstructured control flow + else: + # Unstructured control flow. parent_block.sequential = False parent_block.elements.append(cfg_block) stack.extend([e.dst for e in oe]) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index da25816f9b..488c1c7fbd 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -483,7 +483,7 @@ def dispatch_state(state: SDFGState) -> str: states_generated.add(state) # For sanity check return stream.getvalue() - if sdfg.root_sdfg.using_experimental_blocks: + if sdfg.root_sdfg.recheck_using_experimental_blocks(): # Use control flow blocks embedded in the SDFG to generate control flow. 
cft = cflow.structured_control_flow_tree_with_regions(sdfg, dispatch_state) elif config.Config.get_bool('optimizer', 'detect_control_flow'): diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index e2cc2be88b..407e9eb91c 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -3,7 +3,9 @@ from functools import reduce from itertools import chain from string import ascii_letters -from typing import Dict, Optional +from typing import Dict, List, Optional + +import numpy as np import dace from dace import dtypes, subsets, symbolic @@ -180,6 +182,19 @@ def create_einsum_sdfg(pv: 'dace.frontend.python.newast.ProgramVisitor', beta=beta)[0] +def _build_einsum_views(tensors: str, dimension_dict: dict) -> List[np.ndarray]: + """ + Function taken and adjusted from opt_einsum package version 3.3.0 following unexpected removal in version 3.4.0. + Reference: https://github.com/dgasmith/opt_einsum/blob/v3.3.0/opt_einsum/helpers.py#L18 + """ + views = [] + terms = tensors.split('->')[0].split(',') + for term in terms: + dims = [dimension_dict[x] for x in term] + views.append(np.random.rand(*dims)) + return views + + def _create_einsum_internal(sdfg: SDFG, state: SDFGState, einsum_string: str, @@ -231,7 +246,7 @@ def _create_einsum_internal(sdfg: SDFG, # Create optimal contraction path # noinspection PyTypeChecker - _, path_info = oe.contract_path(einsum_string, *oe.helpers.build_views(einsum_string, chardict)) + _, path_info = oe.contract_path(einsum_string, *_build_einsum_views(einsum_string, chardict)) input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays} result_node = None diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index c9a400e5f1..425e94cd9f 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -384,6 +384,48 @@ def negate_expr(node): return ast.fix_missing_locations(newexpr) +def and_expr(node_a, node_b): + """ Generates the logical AND 
of two AST expressions. + """ + if type(node_a) is not type(node_b): + raise ValueError('Node types do not match') + + # Support for SymPy expressions + if isinstance(node_a, sympy.Basic): + return sympy.And(node_a, node_b) + # Support for numerical constants + if isinstance(node_a, (numbers.Number, numpy.bool_)): + return str(node_a and node_b) + # Support for strings (most likely dace.Data.Scalar names) + if isinstance(node_a, str): + return f'({node_a}) and ({node_b})' + + from dace.properties import CodeBlock # Avoid import loop + if isinstance(node_a, CodeBlock): + node_a = node_a.code + node_b = node_b.code + + if hasattr(node_a, "__len__"): + if len(node_a) > 1: + raise ValueError("and_expr only expects single expressions, got: {}".format(node_a)) + if len(node_b) > 1: + raise ValueError("and_expr only expects single expressions, got: {}".format(node_b)) + expr_a = node_a[0] + expr_b = node_b[0] + else: + expr_a = node_a + expr_b = node_b + + if isinstance(expr_a, ast.Expr): + expr_a = expr_a.value + if isinstance(expr_b, ast.Expr): + expr_b = expr_b.value + + newexpr = ast.Expr(value=ast.BoolOp(op=ast.And(), values=[copy_tree(expr_a), copy_tree(expr_b)])) + newexpr = ast.copy_location(newexpr, expr_a) + return ast.fix_missing_locations(newexpr) + + def copy_tree(node: ast.AST) -> ast.AST: """ Copies an entire AST without copying the non-AST parts (e.g., constant values). diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index 790f2de506..14164054d3 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -44,6 +44,7 @@ def program(f: F, recompile: bool = True, distributed_compilation: bool = False, constant_functions=False, + use_experimental_cfg_blocks=False, **kwargs) -> Callable[..., parser.DaceProgram]: """ Entry point to a data-centric program. For methods and ``classmethod``s, use @@ -68,6 +69,8 @@ def program(f: F, not depend on internal variables are constant. 
This will hardcode their return values into the resulting program. + :param use_experimental_cfg_blocks: If True, makes use of experimental CFG blocks such as loop and conditional + regions. :note: If arguments are defined with type hints, the program can be compiled ahead-of-time with ``.compile()``. """ @@ -83,7 +86,8 @@ def program(f: F, recreate_sdfg=recreate_sdfg, regenerate_code=regenerate_code, recompile=recompile, - distributed_compilation=distributed_compilation) + distributed_compilation=distributed_compilation, + use_experimental_cfg_blocks=use_experimental_cfg_blocks) function = program diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 60469919f5..0d40e13282 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3,7 +3,6 @@ from collections import OrderedDict import copy import itertools -import inspect import networkx as nx import re import sys @@ -25,14 +24,14 @@ from dace.frontend.python.astutils import ExtNodeVisitor, ExtNodeTransformer from dace.frontend.python.astutils import rname from dace.frontend.python import nested_call, replacements, preprocessing -from dace.frontend.python.memlet_parser import (DaceSyntaxError, parse_memlet, pyexpr_to_symbolic, ParseMemlet, - inner_eval_ast, MemletExpr) -from dace.sdfg import nodes, utils as sdutil +from dace.frontend.python.memlet_parser import DaceSyntaxError, parse_memlet, ParseMemlet, inner_eval_ast, MemletExpr +from dace.sdfg import nodes from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock from dace.sdfg import SDFG, SDFGState -from dace.sdfg.state import BreakBlock, ContinueBlock, ControlFlowBlock, FunctionCallRegion, LoopRegion, ControlFlowRegion, NamedRegion +from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, FunctionCallRegion, + LoopRegion, ControlFlowRegion, NamedRegion) 
from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols @@ -1301,7 +1300,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List return new_nodes # Map view access nodes to their respective data - for state in self.sdfg.states(): + for state in self.sdfg.all_states(): # NOTE: We need to support views of views nodes = list(state.data_nodes()) while nodes: @@ -2371,7 +2370,7 @@ def visit_For(self, node: ast.For): extra_symbols=extra_syms, parent=loop_region, unconnected_last_block=False) loop_region.start_block = loop_region.node_id(first_subblock) - + self._connect_break_blocks(loop_region) # Handle else clause if node.orelse: # Continue visiting body @@ -2509,14 +2508,17 @@ def visit_While(self, node: ast.While): self._generate_orelse(loop_region, postloop_block) self.last_block = loop_region + self._connect_break_blocks(loop_region) + + def _connect_break_blocks(self, loop_region: LoopRegion): + for node, parent in loop_region.all_nodes_recursive(lambda n, _: not isinstance(n, (LoopRegion, SDFGState))): + if isinstance(node, BreakBlock): + for in_edge in parent.in_edges(node): + in_edge.data.assignments['__dace_did_break_' + loop_region.label] = '1' def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowBlock): - did_break_symbol = 'did_break_' + loop_region.label + did_break_symbol = '__dace_did_break_' + loop_region.label self.sdfg.add_symbol(did_break_symbol, dace.int32) - for n in loop_region.nodes(): - if isinstance(n, BreakBlock): - for iedge in loop_region.in_edges(n): - iedge.data.assignments[did_break_symbol] = '1' for iedge in self.cfg_target.in_edges(loop_region): iedge.data.assignments[did_break_symbol] = '0' oedges = self.cfg_target.out_edges(loop_region) @@ -2525,61 +2527,59 @@ def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowB intermediate = self.cfg_target.add_state(f'{loop_region.label}_normal_exit') 
self.cfg_target.add_edge(loop_region, intermediate, - dace.InterstateEdge(condition=f"(not {did_break_symbol} == 1)")) + dace.InterstateEdge(condition=f'(not {did_break_symbol} == 1)')) oedge = oedges[0] self.cfg_target.add_edge(intermediate, oedge.dst, copy.deepcopy(oedge.data)) self.cfg_target.remove_edge(oedge) - self.cfg_target.add_edge(loop_region, postloop_block, dace.InterstateEdge(condition=f"{did_break_symbol} == 1")) + self.cfg_target.add_edge(loop_region, postloop_block, dace.InterstateEdge(condition=f'{did_break_symbol} == 1')) + + def _has_loop_ancestor(self, node: ControlFlowBlock) -> bool: + while node is not None and node is not self.sdfg: + if isinstance(node, LoopRegion): + return True + node = node.parent_graph + return False + def visit_Break(self, node: ast.Break): - if isinstance(self.cfg_target, LoopRegion): - self._on_block_added(self.cfg_target.add_break(f'break_{self.cfg_target.label}_{node.lineno}')) - else: - error_msg = "'break' is only supported inside loops " - if self.nested: - error_msg += ("('break' is not supported in Maps and cannot be used in nested DaCe program calls to " - " break out of loops of outer scopes)") - raise DaceSyntaxError(self, node, error_msg) + if not self._has_loop_ancestor(self.cfg_target): + raise DaceSyntaxError(self, node, "Break block outside loop region") + break_block = BreakBlock(f'break_{node.lineno}') + self.cfg_target.add_node(break_block, ensure_unique_name=True) + self._on_block_added(break_block) def visit_Continue(self, node: ast.Continue): - if isinstance(self.cfg_target, LoopRegion): - self._on_block_added(self.cfg_target.add_continue(f'continue_{self.cfg_target.label}_{node.lineno}')) - else: - error_msg = ("'continue' is only supported inside loops ") - if self.nested: - error_msg += ("('continue' is not supported in Maps and cannot be used in nested DaCe program calls to " - " continue loops of outer scopes)") - raise DaceSyntaxError(self, node, error_msg) + if not 
self._has_loop_ancestor(self.cfg_target): + raise DaceSyntaxError(self, node, 'Continue block outside loop region') + continue_block = ContinueBlock(f'continue_{node.lineno}') + self.cfg_target.add_node(continue_block, ensure_unique_name=True) + self._on_block_added(continue_block) def visit_If(self, node: ast.If): - # Add a guard state - self._add_state('if_guard') - self.last_block.debuginfo = self.current_lineinfo - # Generate conditions - cond, cond_else, _ = self._visit_test(node.test) + cond, _, _ = self._visit_test(node.test) - # Visit recursively - laststate, first_if_state, last_if_state, return_stmt = \ - self._recursive_visit(node.body, 'if', node.lineno, self.cfg_target, True) - end_if_state = self.last_block + # Add conditional region + cond_block = ConditionalBlock(f'if_{node.lineno}') + self.cfg_target.add_node(cond_block) + self._on_block_added(cond_block) - # Connect the states - self.cfg_target.add_edge(laststate, first_if_state, dace.InterstateEdge(cond)) - self.cfg_target.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) + if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg) + cond_block.branches.append((CodeBlock(cond), if_body)) + if_body.parent_graph = self.cfg_target + + # Visit recursively + self._recursive_visit(node.body, 'if', node.lineno, if_body, False) # Process 'else'/'elif' statements if len(node.orelse) > 0: + else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', + sdfg=self.sdfg) + #cond_block.branches.append((CodeBlock(cond_else), else_body)) + cond_block.branches.append((None, else_body)) + else_body.parent_graph = self.cfg_target # Visit recursively - _, first_else_state, last_else_state, return_stmt = \ - self._recursive_visit(node.orelse, 'else', node.lineno, self.cfg_target, False) - - # Connect the states - self.cfg_target.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else)) - self.cfg_target.add_edge(last_else_state, 
end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) - else: - self.cfg_target.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else)) - self.last_block = end_if_state + self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index e55829933c..b0ef56907f 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -494,8 +494,9 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) if not self.use_experimental_cfg_blocks: - sdutils.inline_loop_blocks(sdfg) - sdutils.inline_control_flow_regions(sdfg) + for nsdfg in sdfg.all_sdfgs_recursive(): + sdutils.inline_conditional_blocks(nsdfg) + sdutils.inline_control_flow_regions(nsdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks # Apply simplification pass automatically diff --git a/dace/sdfg/analysis/cfg.py b/dace/sdfg/analysis/cfg.py index 1d5b1e50eb..c96ef5aff0 100644 --- a/dace/sdfg/analysis/cfg.py +++ b/dace/sdfg/analysis/cfg.py @@ -6,7 +6,7 @@ import sympy as sp from typing import Dict, Iterator, List, Optional, Set -from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion def acyclic_dominance_frontier(cfg: ControlFlowRegion, idom=None) -> Dict[ControlFlowBlock, Set[ControlFlowBlock]]: @@ -374,6 +374,13 @@ def blockorder_topological_sort(cfg: ControlFlowRegion, yield block if recursive: yield from blockorder_topological_sort(block, recursive, ignore_nonstate_blocks) + elif isinstance(block, ConditionalBlock): + if not ignore_nonstate_blocks: + yield block + for _, branch in block.branches: + if not ignore_nonstate_blocks: + yield branch + yield from blockorder_topological_sort(branch, recursive, 
ignore_nonstate_blocks) elif isinstance(block, SDFGState): yield block else: diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index 50272167bb..5d2eae7c6f 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -13,7 +13,7 @@ from dace.sdfg import nodes as nd, SDFG, SDFGState, utils as sdutil, InterstateEdge from dace.memlet import Memlet from dace.sdfg.graph import Edge, MultiConnectorEdge -from dace.sdfg.state import StateSubgraphView, SubgraphView +from dace.sdfg.state import ControlFlowBlock, StateSubgraphView, SubgraphView from dace.transformation.transformation import (MultiStateTransformation, PatternTransformation, SubgraphTransformation, @@ -321,7 +321,8 @@ def singlestate_cutout(cls, @classmethod def multistate_cutout(cls, *states: SDFGState, - make_side_effects_global: bool = True) -> Union['SDFGCutout', SDFG]: + make_side_effects_global: bool = True, + override_start_block: Optional[ControlFlowBlock] = None) -> Union['SDFGCutout', SDFG]: """ Cut out a multi-state subgraph from an SDFG to run separately for localized testing or optimization. @@ -336,6 +337,9 @@ def multistate_cutout(cls, :param make_side_effects_global: If True, all transient data containers which are read inside the cutout but may be written to _before_ the cutout, or any data containers which are written to inside the cutout but may be read _after_ the cutout, are made global. + :param override_start_block: If set, explicitly force a given control flow block to be the start block. If left + None (default), the start block is automatically determined based on domination + relationships in the original graph. :return: The created SDFGCutout or the original SDFG where no smaller cutout could be obtained. """ create_element = copy.deepcopy @@ -350,10 +354,13 @@ def multistate_cutout(cls, # Determine the start state and ensure there IS a unique start state. 
If there is no unique start state, keep # adding states from the predecessor frontier in the state machine until a unique start state can be determined. start_state: Optional[SDFGState] = None - for state in cutout_states: - if state == sdfg.start_state: - start_state = state - break + if override_start_block is not None: + start_state = override_start_block + else: + for state in cutout_states: + if state == sdfg.start_state: + start_state = state + break if start_state is None: bfs_queue: Deque[Tuple[Set[SDFGState], Set[Edge[InterstateEdge]]]] = deque() diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 5e5df1b0a2..71b37ea7b7 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -23,7 +23,7 @@ from dace.config import Config from dace.frontend.python import astutils from dace.sdfg import nodes as nd -from dace.sdfg.state import ControlFlowBlock, SDFGState, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState, ControlFlowRegion from dace.distr_types import ProcessGrid, SubArray, RedistrArray from dace.dtypes import validate_name from dace.properties import (DebugInfoProperty, EnumProperty, ListProperty, make_properties, Property, CodeProperty, @@ -1512,6 +1512,17 @@ def shared_transients(self, check_toplevel: bool = True, include_nested_data: bo seen[sym] = interstate_edge shared.append(sym) + # The same goes for the conditions of conditional blocks. 
+ for block in self.all_control_flow_blocks(): + if isinstance(block, ConditionalBlock): + for cond, _ in block.branches: + if cond is not None: + cond_symbols = set(map(str, dace.symbolic.symbols_in_ast(cond.code[0]))) + for sym in cond_symbols: + if sym in self.arrays and self.arrays[sym].transient: + seen[sym] = block + shared.append(sym) + # If transient is accessed in more than one state, it is shared for state in self.states(): for node in state.data_nodes(): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index e8a8161747..8d443e6beb 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -11,7 +11,10 @@ from typing import (TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload) +import sympy + import dace +from dace.frontend.python import astutils import dace.serialize from dace import data as dt from dace import dtypes @@ -22,8 +25,8 @@ from dace.properties import (CodeBlock, DebugInfoProperty, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import (MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge, - generate_element_id) +from dace.sdfg.graph import (MultiConnectorEdge, NodeNotFoundError, OrderedMultiDiConnectorGraph, SubgraphView, + OrderedDiGraph, Edge, generate_element_id) from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset @@ -1140,6 +1143,11 @@ def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): """ self._default_lineinfo = lineinfo + def view(self): + from dace.sdfg.analysis.cutout import SDFGCutout + cutout = SDFGCutout.multistate_cutout(self, make_side_effects_global=False, override_start_block=self) + cutout.view() + def to_json(self, parent=None): tmp = { 'type': self.__class__.__name__, @@ -2561,21 +2569,21 @@ def inline(self) -> Tuple[bool, 
Any]: """ parent = self.parent_graph if parent: - end_state = parent.add_state(self.label + '_end') # Add all region states and make sure to keep track of all the ones that need to be connected in the end. to_connect: Set[SDFGState] = set() block_to_state_map: Dict[ControlFlowBlock, SDFGState] = dict() for node in self.nodes(): node.label = self.label + '_' + node.label - parent.add_node(node, ensure_unique_name=True) if isinstance(node, ReturnBlock) and isinstance(parent, dace.SDFG): # If a return block is being inlined into an SDFG, convert it into a regular state. Otherwise it # remains as-is. newnode = parent.add_state(node.label) block_to_state_map[node] = newnode - elif self.out_degree(node) == 0: - to_connect.add(node) + else: + parent.add_node(node, ensure_unique_name=True) + if self.out_degree(node) == 0 and not isinstance(node, (BreakBlock, ContinueBlock, ReturnBlock)): + to_connect.add(node) # Add all region edges. for edge in self.edges(): @@ -2587,14 +2595,26 @@ def inline(self) -> Tuple[bool, Any]: for b_edge in parent.in_edges(self): parent.add_edge(b_edge.src, self.start_block, b_edge.data) parent.remove_edge(b_edge) - # Redirect all edges exiting the region to instead exit the end state. - for a_edge in parent.out_edges(self): - parent.add_edge(end_state, a_edge.dst, a_edge.data) - parent.remove_edge(a_edge) - - for node in to_connect: - parent.add_edge(node, end_state, dace.InterstateEdge()) - + + end_state = None + if len(to_connect) > 0: + end_state = parent.add_state(self.label + '_end') + # Redirect all edges exiting the region to instead exit the end state. + for a_edge in parent.out_edges(self): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + for node in to_connect: + parent.add_edge(node, end_state, dace.InterstateEdge()) + else: + # TODO: Move this to dead state elimination. 
+ dead_blocks = [succ for succ in parent.successors(self) if parent.in_degree(succ) == 1] + while dead_blocks: + layer = list(dead_blocks) + dead_blocks.clear() + for u in layer: + dead_blocks.extend([succ for succ in parent.successors(u) if parent.in_degree(succ) == 1]) + parent.remove_node(u) # Remove the original control flow region (self) from the parent graph. parent.remove_node(self) @@ -2741,6 +2761,9 @@ def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegi yield from node.sdfg.all_control_flow_regions(recursive=recursive) elif isinstance(block, ControlFlowRegion): yield from block.all_control_flow_regions(recursive=recursive) + elif isinstance(block, ConditionalBlock): + for _, branch in block.branches: + yield from branch.all_control_flow_regions(recursive=recursive) def all_sdfgs_recursive(self) -> Iterator['SDFG']: """ Iterate over this and all nested SDFGs. """ @@ -2755,6 +2778,9 @@ def all_states(self) -> Iterator[SDFGState]: yield block elif isinstance(block, ControlFlowRegion): yield from block.all_states() + elif isinstance(block, ConditionalBlock): + for _, region in block.branches: + yield from region.all_states() def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: """ Iterate over all control flow blocks in this control flow graph. """ @@ -2788,7 +2814,7 @@ def _used_symbols_internal(self, for block in ordered_blocks: state_symbols = set() - if isinstance(block, ControlFlowRegion): + if isinstance(block, (ControlFlowRegion, ConditionalBlock)): b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal(all_symbols, defined_syms, free_syms, @@ -3020,7 +3046,7 @@ def inline(self) -> Tuple[bool, Any]: # and return are inlined correctly. 
@make_properties
class ConditionalBlock(ControlFlowBlock, ControlGraphView):
    """
    Control flow block made of an ordered list of ``(condition, region)`` branches.

    A branch whose condition is ``None`` is treated as the else branch when the block is inlined.
    The block has no edges of its own; its nodes are the branch regions.
    """

    # Ordered (condition, region) pairs. A condition of None marks the else branch.
    _branches: List[Tuple[Optional[CodeBlock], ControlFlowRegion]]

    def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optional['ControlFlowRegion'] = None):
        super().__init__(label, sdfg, parent)
        self._branches = []

    def __str__(self):
        return self._label

    def __repr__(self) -> str:
        return f'ConditionalBlock ({self.label})'

    @property
    def branches(self) -> List[Tuple[Optional[CodeBlock], ControlFlowRegion]]:
        """ The ordered list of ``(condition, region)`` branches (the live list, not a copy). """
        return self._branches

    def nodes(self) -> List['ControlFlowBlock']:
        """ Returns the regions of all branches, skipping branches without a region. """
        return [node for _, node in self._branches if node is not None]

    def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]:
        """ A conditional block has no edges between its branches. """
        return []

    def _used_symbols_internal(self,
                               all_symbols: bool,
                               defined_syms: Optional[Set] = None,
                               free_syms: Optional[Set] = None,
                               used_before_assignment: Optional[Set] = None,
                               keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]:
        """
        Collects free, defined, and used-before-assignment symbols over all branch conditions and regions.

        The given sets (if any) are updated in place and are also what is returned.

        :return: A tuple of (free symbols, defined symbols, symbols used before assignment).
        """
        defined_syms = set() if defined_syms is None else defined_syms
        free_syms = set() if free_syms is None else free_syms
        used_before_assignment = set() if used_before_assignment is None else used_before_assignment

        for condition, region in self._branches:
            if condition is not None:
                free_syms |= condition.get_free_symbols(defined_syms)
            b_free_symbols, b_defined_symbols, b_used_before_assignment = region._used_symbols_internal(
                all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping)
            free_syms |= b_free_symbols
            defined_syms |= b_defined_symbols
            used_before_assignment |= b_used_before_assignment

        defined_syms -= used_before_assignment
        free_syms -= defined_syms

        return free_syms, defined_syms, used_before_assignment

    def replace_dict(self,
                     repl: Dict[str, str],
                     symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None,
                     replace_in_graph: bool = True,
                     replace_keys: bool = True):
        """
        Applies the replacement dictionary to this block's properties (if ``replace_keys`` is set) and
        forwards it to every branch region.
        """
        if replace_keys:
            from dace.sdfg.replace import replace_properties_dict
            replace_properties_dict(self, repl, symrepl)

        for _, region in self._branches:
            region.replace_dict(repl, symrepl, replace_in_graph)

    def to_json(self, parent=None):
        """ Serializes the block; branches are stored as (condition or None, region) JSON pairs. """
        serialized = super().to_json(parent)
        serialized['branches'] = [(condition.to_json() if condition is not None else None, cfg.to_json())
                                  for condition, cfg in self._branches]
        return serialized

    @classmethod
    def from_json(cls, json_obj, context=None):
        """
        Reconstructs a conditional block from its JSON form.

        :raises TypeError: If the serialized type does not match this class.
        """
        context = context or {'sdfg': None, 'parent_graph': None}
        _type = json_obj['type']
        if _type != cls.__name__:
            raise TypeError('Class type mismatch')

        ret = cls(label=json_obj['label'], sdfg=context['sdfg'])

        dace.serialize.set_properties_from_json(ret, json_obj)

        for condition, region in json_obj['branches']:
            parsed_condition = CodeBlock.from_json(condition) if condition is not None else None
            ret._branches.append((parsed_condition, ControlFlowRegion.from_json(region, context)))
        return ret

    def inline(self) -> Tuple[bool, Any]:
        """
        Inlines the conditional block into its parent control flow region.

        A guard state and an end state are added to the parent. Every branch region becomes a node of
        the parent, entered from the guard state and left towards the end state. Branch ``k`` is entered
        only if its own condition holds and none of the conditions of the branches before it do
        (if / elif semantics). The else branch - or, without one, a direct guard-to-end edge - is taken
        when no branch condition holds.

        :return: A tuple of True and the pair (guard state, end state) that replaced this block.
        :raises RuntimeError: If this block has no parent graph.
        """
        import ast
        import copy
        from dace.sdfg.sdfg import InterstateEdge

        parent = self.parent_graph
        if not parent:
            raise RuntimeError('No top-level SDFG present to inline into')

        # Add all boilerplate states necessary for the structure.
        guard_state = parent.add_state(self.label + '_guard')
        end_state = parent.add_state(self.label + '_end')

        # Redirect all edges entering the block to the guard state.
        for b_edge in parent.in_edges(self):
            parent.add_edge(b_edge.src, guard_state, b_edge.data)
            parent.remove_edge(b_edge)
        # Redirect all edges exiting the block to instead exit the end state.
        for a_edge in parent.out_edges(self):
            parent.add_edge(end_state, a_edge.dst, a_edge.data)
            parent.remove_edge(a_edge)

        def as_expression(cond: CodeBlock) -> ast.AST:
            # Assumes the condition holds parsed Python code, as the indexing into ``code`` requires.
            stmt = cond.code[0]
            return stmt.value if isinstance(stmt, ast.Expr) else stmt

        def none_of(exprs: List[ast.AST]) -> ast.AST:
            # not (e1 or e2 or ...); nodes are copied so that no AST is shared between edges.
            copies = [copy.deepcopy(e) for e in exprs]
            disjunction = copies[0] if len(copies) == 1 else ast.BoolOp(op=ast.Or(), values=copies)
            return ast.UnaryOp(op=ast.Not(), operand=disjunction)

        def to_codeblock(expr: ast.AST) -> CodeBlock:
            return CodeBlock([ast.fix_missing_locations(ast.Expr(value=expr))])

        else_branch = None
        earlier_conditions: List[ast.AST] = []
        for condition, region in self._branches:
            if condition is None:
                else_branch = region
                continue

            cond_expr = as_expression(condition)
            if earlier_conditions:
                # Make the edge exclusive of every branch listed before this one.
                edge_condition = to_codeblock(
                    ast.BoolOp(op=ast.And(), values=[none_of(earlier_conditions),
                                                     copy.deepcopy(cond_expr)]))
            else:
                edge_condition = condition
            earlier_conditions.append(cond_expr)

            parent.add_node(region)
            parent.add_edge(guard_state, region, InterstateEdge(condition=edge_condition))
            parent.add_edge(region, end_state, InterstateEdge())

        # The fall-through path is taken when no branch condition holds, i.e. the negated disjunction
        # (not the negated conjunction) of all branch conditions.
        if earlier_conditions:
            negative_cond = to_codeblock(none_of(earlier_conditions))
        else:
            negative_cond = CodeBlock('1')

        if else_branch is not None:
            parent.add_node(else_branch)
            parent.add_edge(guard_state, else_branch, InterstateEdge(condition=negative_cond))
            # Connect the else region itself, not whichever region the loop visited last.
            parent.add_edge(else_branch, end_state, InterstateEdge())
        else:
            parent.add_edge(guard_state, end_state, InterstateEdge(condition=negative_cond))

        parent.remove_node(self)

        sdfg = parent if isinstance(parent, dace.SDFG) else parent.sdfg
        sdfg.reset_cfg_list()

        return True, (guard_state, end_state)
@make_properties
class FunctionCallRegion(NamedRegion):
    """
    Named region that additionally carries a string-to-string ``arguments`` dictionary.
    """

    arguments = DictProperty(str, str)

    def __init__(self,
                 label: str,
                 arguments: Optional[Dict[str, str]] = None,
                 sdfg: 'SDFG' = None,
                 debuginfo: Optional[dtypes.DebugInfo] = None):
        """
        :param label: Label of the region.
        :param arguments: Argument dictionary; a fresh empty dictionary is used when omitted.
        :param sdfg: The SDFG this region belongs to, if any.
        :param debuginfo: Optional debug information forwarded to ``NamedRegion``.
        """
        super().__init__(label, sdfg, debuginfo)
        # A None sentinel avoids sharing one default dictionary object between all instances.
        self.arguments = {} if arguments is None else arguments
progress: bool = None) -> int: + blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)] + count = 0 + + for _block in optional_progressbar(reversed(blocks), title='Inlining conditional blocks', + n=len(blocks), progress=progress): + block: ConditionalBlock = _block + if block.inline()[0]: + count += 1 + + return count + def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int: """ diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 2869743dcb..f305affb80 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -34,7 +34,7 @@ def validate_control_flow_region(sdfg: 'SDFG', symbols: dict, references: Set[int] = None, **context: bool): - from dace.sdfg.state import SDFGState, ControlFlowRegion + from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalBlock from dace.sdfg.scope import is_in_scope if len(region.source_nodes()) > 1 and region.start_block is None: @@ -118,6 +118,10 @@ def validate_control_flow_region(sdfg: 'SDFG', if isinstance(edge.dst, SDFGState): validate_state(edge.dst, region.node_id(edge.dst), sdfg, symbols, initialized_transients, references, **context) + elif isinstance(edge.dst, ConditionalBlock): + for _, r in edge.dst.branches: + if r is not None: + validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(edge.dst, ControlFlowRegion): validate_control_flow_region(sdfg, edge.dst, initialized_transients, symbols, references, **context) # End of block DFS diff --git a/tests/python_frontend/conditional_regions_test.py b/tests/python_frontend/conditional_regions_test.py new file mode 100644 index 0000000000..07e214653c --- /dev/null +++ b/tests/python_frontend/conditional_regions_test.py @@ -0,0 +1,92 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
import dace
import numpy as np
from dace.sdfg.state import ConditionalBlock


def _contains_conditional_block(sdfg) -> bool:
    """ Tells whether any top-level node of the SDFG is a ConditionalBlock. """
    for node in sdfg.nodes():
        if isinstance(node, ConditionalBlock):
            return True
    return False


def test_dataflow_if_check():
    """ An if / elif whose conditions read array data is parsed into a ConditionalBlock and runs correctly. """

    @dace.program
    def dataflow_if_check(A: dace.int32[10], i: dace.int64):
        if A[i] < 10:
            return 0
        elif A[i] == 10:
            return 10
        return 100

    dataflow_if_check.use_experimental_cfg_blocks = True
    sdfg = dataflow_if_check.to_sdfg()

    assert _contains_conditional_block(sdfg)

    A = np.zeros((10,), np.int32)
    A[4] = 10
    A[5] = 100
    # (index passed as i, expected return value)
    for index, expected in ((0, 0), (4, 10), (5, 100), (6, 0)):
        assert sdfg(A, index)[0] == expected


def test_nested_if_chain():
    """ Nested if / else statements are parsed into a ConditionalBlock and select the right branch. """

    @dace.program
    def nested_if_chain(i: dace.int64):
        if i < 2:
            return 0
        else:
            if i < 4:
                return 1
            else:
                if i < 6:
                    return 2
                else:
                    if i < 8:
                        return 3
                    else:
                        return 4

    nested_if_chain.use_experimental_cfg_blocks = True
    sdfg = nested_if_chain.to_sdfg()

    assert _contains_conditional_block(sdfg)

    for value, expected in ((0, 0), (2, 1), (4, 2), (7, 3), (15, 4)):
        assert nested_if_chain(value)[0] == expected


def test_elif_chain():
    """ A flat if / elif / else chain is parsed into a ConditionalBlock and selects the right branch. """

    @dace.program
    def elif_chain(i: dace.int64):
        if i < 2:
            return 0
        elif i < 4:
            return 1
        elif i < 6:
            return 2
        elif i < 8:
            return 3
        else:
            return 4

    elif_chain.use_experimental_cfg_blocks = True
    sdfg = elif_chain.to_sdfg()

    assert _contains_conditional_block(sdfg)

    for value, expected in ((0, 0), (2, 1), (4, 2), (7, 3), (15, 4)):
        assert elif_chain(value)[0] == expected


if __name__ == '__main__':
    test_dataflow_if_check()
    test_nested_if_chain()
    test_elif_chain()
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.

import numpy as np
import dace
from dace.properties import CodeBlock
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.state import ConditionalBlock, ControlFlowRegion
import dace.serialize


def _add_write_to_a(region, tasklet_label, tasklet_code):
    """ Adds a start state labeled "state1" to the region, in which a tasklet writes to A[0]. """
    state = region.add_state("state1", is_start_block=True)
    access = state.add_access('A')
    tasklet = state.add_tasklet(tasklet_label, None, {"a"}, tasklet_code)
    state.add_edge(tasklet, 'a', access, None, dace.Memlet('A[0]'))
    return state


def _run_with_fresh_a(sdfg, i):
    """ Runs the SDFG with symbol i on a one-element array of ones and returns the resulting A[0]. """
    A = np.ones((1,), dtype=np.float32)
    sdfg(i=i, A=A)
    return A[0]


def test_cond_region_if():
    """ A conditional block with a single branch only runs that branch when its condition holds. """
    sdfg = dace.SDFG('regular_if')
    sdfg.add_array("A", (1,), dace.float32)
    sdfg.add_symbol("i", dace.int32)
    state0 = sdfg.add_state('state0', is_start_block=True)

    if1 = ConditionalBlock("if1")
    sdfg.add_node(if1)
    sdfg.add_edge(state0, if1, InterstateEdge())

    if_body = ControlFlowRegion("if_body", sdfg=sdfg)
    if1.branches.append((CodeBlock("i == 1"), if_body))
    _add_write_to_a(if_body, "t1", "a = 100")

    assert sdfg.is_valid()
    assert _run_with_fresh_a(sdfg, 1) == 100
    assert _run_with_fresh_a(sdfg, 0) == 1


def test_serialization():
    """ Branch conditions and region labels survive a JSON round trip of the SDFG. """
    num_branches = 10

    sdfg = SDFG("test_serialization")
    cond_region = ConditionalBlock("cond_region")
    sdfg.add_node(cond_region, is_start_block=True)
    sdfg.add_symbol("i", dace.int32)

    for j in range(num_branches):
        cfg = ControlFlowRegion(f"cfg_{j}", sdfg)
        cond_region.branches.append((CodeBlock(f"i == {j}"), cfg))

    assert sdfg.is_valid()

    new_sdfg = SDFG.from_json(sdfg.to_json())
    assert new_sdfg.is_valid()
    new_cond_region: ConditionalBlock = new_sdfg.nodes()[0]
    for j in range(num_branches):
        condition, cfg = new_cond_region.branches[j]
        assert condition == CodeBlock(f"i == {j}")
        assert cfg.label == f"cfg_{j}"


def test_if_else():
    """
    A conditional block with two branches runs the branch whose condition holds.

    NOTE: the second branch carries the explicit condition "i == 0" rather than None, so this
    exercises a pair of conditioned branches, not a condition-less else branch.
    """
    sdfg = dace.SDFG('regular_if_else')
    sdfg.add_array("A", (1,), dace.float32)
    sdfg.add_symbol("i", dace.int32)
    state0 = sdfg.add_state('state0', is_start_block=True)

    if1 = ConditionalBlock("if1")
    sdfg.add_node(if1)
    sdfg.add_edge(state0, if1, InterstateEdge())

    if_body = ControlFlowRegion("if_body", sdfg=sdfg)
    _add_write_to_a(if_body, "t1", "a = 100")
    if1.branches.append((CodeBlock("i == 1"), if_body))

    else_body = ControlFlowRegion("else_body", sdfg=sdfg)
    _add_write_to_a(else_body, "t2", "a = 200")
    if1.branches.append((CodeBlock("i == 0"), else_body))

    assert sdfg.is_valid()
    assert _run_with_fresh_a(sdfg, 1) == 100
    assert _run_with_fresh_a(sdfg, 0) == 200


if __name__ == '__main__':
    test_cond_region_if()
    test_serialization()
    test_if_else()