Skip to content

Commit

Permalink
feat(hugr-py): Allow defining functions, consts, and aliases inside D…
Browse files Browse the repository at this point in the history
…FGs (#1394)

Move `define_function` and `add_alias_defn` from `Module` to a common
root for both it and `DFG`.
Move both `add_const` definitions to that common base.

I had to move `Function` to `dfg.py` due to circular deps, but I added a
reexport to avoid breaking changes.
  • Loading branch information
aborgna-q authored Aug 2, 2024
1 parent 77795b9 commit d554072
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 89 deletions.
99 changes: 81 additions & 18 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Generic,
TypeVar,
)

Expand All @@ -22,13 +23,66 @@
from .cfg import Cfg
from .cond_loop import Conditional, If, TailLoop
from .node_port import Node, OutPort, PortOffset, ToNode, Wire
from .tys import Type, TypeParam, TypeRow

OpVar = TypeVar("OpVar", bound=ops.Op)


@dataclass()
class _DefinitionBuilder(Generic[OpVar]):
"""Base class for builders that can define functions, constants, and aliases.
As this class may be a root node, it does not extend `ParentBuilder`.
"""

hugr: Hugr[OpVar]

def define_function(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> Function:
"""Start building a function definition in the graph.
Args:
name: The name of the function.
input_types: The input types for the function.
type_params: The type parameters for the function, if polymorphic.
Returns:
The new function builder.
"""
parent_op = ops.FuncDefn(name, input_types, type_params or [])
return Function.new_nested(parent_op, self.hugr)

def add_const(self, value: val.Value) -> Node:
"""Add a static constant to the graph.
Args:
value: The constant value to add.
Returns:
The node holding the :class:`Const <hugr.ops.Const>` operation.
Example:
>>> dfg = Dfg()
>>> const_n = dfg.add_const(val.TRUE)
>>> dfg.hugr[const_n].op
Const(TRUE)
"""
return self.hugr.add_node(ops.Const(value), self.hugr.root)

def add_alias_defn(self, name: str, ty: Type) -> Node:
"""Add a type alias definition."""
return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.root)


DP = TypeVar("DP", bound=ops.DfParentOp)


@dataclass()
class _DfBase(ParentBuilder[DP], AbstractContextManager):
class _DfBase(ParentBuilder[DP], _DefinitionBuilder, AbstractContextManager):
"""Base class for dataflow graph builders.
Args:
Expand Down Expand Up @@ -428,23 +482,6 @@ def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def add_const(self, val: val.Value) -> Node:
"""Add a static constant to the graph.
Args:
val: The value to add.
Returns:
The node holding the :class:`Const <hugr.ops.Const>` operation.
Example:
>>> dfg = Dfg()
>>> const_n = dfg.add_const(val.TRUE)
>>> dfg.hugr[const_n].op
Const(TRUE)
"""
return self.hugr.add_const(val, self.parent_node)

def load(self, const: ToNode | val.Value) -> Node:
"""Load a constant into the graph as a dataflow value.
Expand Down Expand Up @@ -594,3 +631,29 @@ def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
tgt = tgt_parent

return None


@dataclass
class Function(_DfBase[ops.FuncDefn]):
"""Build a function definition as a HUGR dataflow graph.
Args:
name: The name of the function.
input_types: The input types for the function (output types are
computed by propagating types from input node through the graph).
type_params: The type parameters for the function, if polymorphic.
Examples:
>>> f = Function("f", [tys.Bool])
>>> f.parent_op
FuncDefn(name='f', inputs=[Bool], params=[])
"""

def __init__(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> None:
root_op = ops.FuncDefn(name, input_types, type_params or [])
super().__init__(root_op)
78 changes: 7 additions & 71 deletions hugr-py/src/hugr/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,19 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from hugr import ops, val

from .dfg import _DfBase
from . import ops
from .dfg import Function, _DefinitionBuilder
from .hugr import Hugr

if TYPE_CHECKING:
from hugr.node_port import Node

from .tys import PolyFuncType, Type, TypeBound, TypeParam, TypeRow


@dataclass
class Function(_DfBase[ops.FuncDefn]):
"""Build a function definition as a HUGR dataflow graph.
Args:
name: The name of the function.
input_types: The input types for the function (output types are
computed by propagating types from input node through the graph).
type_params: The type parameters for the function, if polymorphic.
Examples:
>>> f = Function("f", [tys.Bool])
>>> f.parent_op
FuncDefn(name='f', inputs=[Bool], params=[])
"""
from .node_port import Node
from .tys import PolyFuncType, TypeBound, TypeRow

def __init__(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> None:
root_op = ops.FuncDefn(name, input_types, type_params or [])
super().__init__(root_op)
__all__ = ["Function", "Module"]


@dataclass
class Module:
class Module(_DefinitionBuilder[ops.Module]):
"""Build a top-level HUGR module.
Examples:
Expand All @@ -57,25 +31,6 @@ class Module:
def __init__(self) -> None:
self.hugr = Hugr(ops.Module())

def define_function(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> Function:
"""Start building a function definition in the module.
Args:
name: The name of the function.
input_types: The input types for the function.
type_params: The type parameters for the function, if polymorphic.
Returns:
The new function builder.
"""
parent_op = ops.FuncDefn(name, input_types, type_params or [])
return Function.new_nested(parent_op, self.hugr)

def define_main(self, input_types: TypeRow) -> Function:
"""Define the 'main' function in the module. See :meth:`define_function`."""
return self.define_function("main", input_types)
Expand All @@ -91,33 +46,14 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node:
The node representing the function declaration.
Examples:
>>> from hugr.function import Module
>>> m = Module()
>>> sig = tys.PolyFuncType([], tys.FunctionType.empty())
>>> m.declare_function("f", sig)
Node(1)
"""
return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.root)

def add_const(self, value: val.Value) -> Node:
"""Add a static constant to the module.
Args:
value: The constant value to add.
Returns:
The node holding the constant.
Examples:
>>> m = Module()
>>> m.add_const(val.FALSE)
Node(1)
"""
return self.hugr.add_node(ops.Const(value), self.hugr.root)

def add_alias_defn(self, name: str, ty: Type) -> Node:
"""Add a type alias definition."""
return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.root)

def add_alias_decl(self, name: str, bound: TypeBound) -> Node:
"""Add a type alias declaration."""
return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.root)

0 comments on commit d554072

Please sign in to comment.