Skip to content

Commit

Permalink
feat(hugr-py): AsCustomOp protocol for user-defined custom op types. (
Browse files Browse the repository at this point in the history
#1290)

Downstream users can implement this protocol to use their own
convenience classes. Typically a use wants to implement:
- `to_custom`
- `__call__` to provide a useful calling signature.
- `from_custom` if different instances of the type correspond to
different ops (non-singleton).

Replace existing custom operation types with this.

Follow ups:
- Similar thing for custom types.
- Optional: allow these types to register themselves with
`serialization.ops.CustomOp` so they can be deserialized directly.
  • Loading branch information
ss2165 authored Jul 10, 2024
1 parent 1baa697 commit 1db43eb
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 71 deletions.
94 changes: 86 additions & 8 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from __future__ import annotations

from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable

from typing_extensions import Self

import hugr.serialization.ops as sops
from hugr import tys, val
from hugr.node_port import Direction, InPort, Node, OutPort, Wire
Expand Down Expand Up @@ -197,8 +200,79 @@ def _set_in_types(self, types: tys.TypeRow) -> None:
self._types = types


@dataclass(frozen=True)
class Custom(DataflowOp):
@runtime_checkable
class AsCustomOp(DataflowOp, Protocol):
"""Abstract interface that types can implement
to behave as a custom dataflow operation.
"""

@dataclass(frozen=True)
class InvalidCustomOp(Exception):
"""Custom operation does not match the expected type."""

msg: str

@cached_property
def custom_op(self) -> Custom:
""":class:`Custom` operation that this type represents.
Computed once using :meth:`to_custom` and cached - should be deterministic.
"""
return self.to_custom()

def to_custom(self) -> Custom:
"""Convert this type to a :class:`Custom` operation.
Used by :attr:`custom_op`, so must be deterministic.
"""
... # pragma: no cover

@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
"""Load from a :class:`Custom` operation.
By default assumes the type of `cls` is a singleton,
and compares the result of :meth:`to_custom` with the given `custom`.
If successful, returns the singleton, else None.
Non-singleton types should override this method.
Raises:
InvalidCustomOp: If the given `custom` does not match the expected one for a
given extension/operation name.
"""
default = cls()
if default.custom_op == custom:
return default
return None

def __eq__(self, other: object) -> bool:
if not isinstance(other, AsCustomOp):
return NotImplemented
slf, other = self.custom_op, other.custom_op
return (
slf.extension == other.extension
and slf.op_name == other.op_name
and slf.signature == other.signature
and slf.args == other.args
)

def outer_signature(self) -> tys.FunctionType:
return self.custom_op.signature

def to_serial(self, parent: Node) -> sops.CustomOp:
return self.custom_op.to_serial(parent)

@property
def num_out(self) -> int:
return len(self.custom_op.signature.output)


@dataclass(frozen=True, eq=False)
class Custom(AsCustomOp):
"""A non-core dataflow operation defined in an extension."""

op_name: str
Expand All @@ -207,10 +281,6 @@ class Custom(DataflowOp):
extension: tys.ExtensionId = ""
args: list[tys.TypeArg] = field(default_factory=list)

@property
def num_out(self) -> int:
return len(self.signature.output)

def to_serial(self, parent: Node) -> sops.CustomOp:
return sops.CustomOp(
parent=parent.idx,
Expand All @@ -221,8 +291,16 @@ def to_serial(self, parent: Node) -> sops.CustomOp:
args=ser_it(self.args),
)

def outer_signature(self) -> tys.FunctionType:
return self.signature
def to_custom(self) -> Custom:
return self

@classmethod
def from_custom(cls, custom: Custom) -> Custom:
return custom

def check_id(self, extension: tys.ExtensionId, op_name: str) -> bool:
"""Check if the operation matches the given extension and operation name."""
return self.extension == extension and self.op_name == op_name


@dataclass()
Expand Down
55 changes: 37 additions & 18 deletions hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@

from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar

from typing_extensions import Self

from hugr import tys, val
from hugr.ops import Custom
from hugr.ops import AsCustomOp, Custom, DataflowOp

if TYPE_CHECKING:
from hugr.ops import Command, ComWire


def int_t(width: int) -> tys.Opaque:
Expand Down Expand Up @@ -44,27 +50,40 @@ def to_value(self) -> val.Extension:
return val.Extension("int", INT_T, self.v)


@dataclass(frozen=True)
class IntOps(Custom):
"""Base class for integer operations."""

extension: tys.ExtensionId = "arithmetic.int"


_ARG_I32 = tys.BoundedNatArg(n=5)
OPS_EXTENSION: tys.ExtensionId = "arithmetic.int"


@dataclass(frozen=True)
class _DivModDef(IntOps):
class _DivModDef(AsCustomOp):
"""DivMod operation, has two inputs and two outputs."""

num_out: int = 2
extension: tys.ExtensionId = "arithmetic.int"
op_name: str = "idivmod_u"
signature: tys.FunctionType = field(
default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2)
)
args: list[tys.TypeArg] = field(default_factory=lambda: [_ARG_I32, _ARG_I32])
op_name: ClassVar[str] = "idivmod_u"
arg1: int = 5
arg2: int = 5

def to_custom(self) -> Custom:
return Custom(
"idivmod_u",
tys.FunctionType(
input=[int_t(self.arg1)] * 2, output=[int_t(self.arg2)] * 2
),
extension=OPS_EXTENSION,
args=[tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)],
)

@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
if not custom.check_id(OPS_EXTENSION, "idivmod_u"):
return None
match custom.args:
case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]:
return cls(arg1=a1, arg2=a2)
case _:
msg = f"Invalid args: {custom.args}"
raise AsCustomOp.InvalidCustomOp(msg)

def __call__(self, a: ComWire, b: ComWire) -> Command:
return DataflowOp.__call__(self, a, b)


#: DivMod operation.
Expand Down
20 changes: 6 additions & 14 deletions hugr-py/src/hugr/std/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,24 @@
from typing import TYPE_CHECKING

from hugr import tys
from hugr.ops import Command, Custom
from hugr.ops import AsCustomOp, Command, Custom, DataflowOp

if TYPE_CHECKING:
from hugr.ops import ComWire


@dataclass(frozen=True)
class LogicOps(Custom):
"""Base class for logic operations."""

extension: tys.ExtensionId = "logic"


_NotSig = tys.FunctionType.endo([tys.Bool])
EXTENSION_ID: tys.ExtensionId = "logic"


@dataclass(frozen=True)
class _NotDef(LogicOps):
class _NotDef(AsCustomOp):
"""Not operation."""

num_out: int = 1
op_name: str = "Not"
signature: tys.FunctionType = _NotSig
def to_custom(self) -> Custom:
return Custom("Not", tys.FunctionType.endo([tys.Bool]), extension=EXTENSION_ID)

def __call__(self, a: ComWire) -> Command:
return super().__call__(a)
return DataflowOp.__call__(self, a)


#: Not operation
Expand Down
95 changes: 64 additions & 31 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,110 @@
import pathlib
import subprocess
from dataclasses import dataclass
from typing import TYPE_CHECKING
from enum import Enum
from typing import TYPE_CHECKING, TypeVar

from typing_extensions import Self

from hugr import tys
from hugr.hugr import Hugr
from hugr.ops import Command, Custom
from hugr.ops import AsCustomOp, Command, Custom, DataflowOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.std.float import FLOAT_T

if TYPE_CHECKING:
from hugr.ops import ComWire


@dataclass(frozen=True)
class QuantumOps(Custom):
extension: tys.ExtensionId = "tket2.quantum"
QUANTUM_EXTENSION_ID: tys.ExtensionId = "quantum.tket2"

E = TypeVar("E", bound=Enum)


_OneQbSig = tys.FunctionType.endo([tys.Qubit])
def _load_enum(enum_cls: type[E], custom: Custom) -> E | None:
if (
custom.extension == QUANTUM_EXTENSION_ID
and custom.op_name in enum_cls.__members__
):
return enum_cls(custom.op_name)
return None


@dataclass(frozen=True)
class OneQbGate(QuantumOps):
op_name: str
num_out: int = 1
signature: tys.FunctionType = _OneQbSig
class OneQbGate(AsCustomOp):
# Have to nest enum to avoid meta class conflict
class _Enum(Enum):
H = "H"

_enum: _Enum

def __call__(self, q: ComWire) -> Command:
return super().__call__(q)
return DataflowOp.__call__(self, q)

def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo([tys.Qubit]),
extension=QUANTUM_EXTENSION_ID,
)

H = OneQbGate("H")
@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
return cls(e) if (e := _load_enum(cls._Enum, custom)) else None


_TwoQbSig = tys.FunctionType.endo([tys.Qubit] * 2)
H = OneQbGate(OneQbGate._Enum.H)


@dataclass(frozen=True)
class TwoQbGate(QuantumOps):
op_name: str
num_out: int = 2
signature: tys.FunctionType = _TwoQbSig
class TwoQbGate(AsCustomOp):
class _Enum(Enum):
CX = "CX"

def __call__(self, q0: ComWire, q1: ComWire) -> Command:
return super().__call__(q0, q1)
_enum: _Enum

def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo([tys.Qubit] * 2),
extension=QUANTUM_EXTENSION_ID,
)

@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
return cls(e) if (e := _load_enum(cls._Enum, custom)) else None

CX = TwoQbGate("CX")
def __call__(self, q0: ComWire, q1: ComWire) -> Command:
return DataflowOp.__call__(self, q0, q1)

_MeasSig = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])

CX = TwoQbGate(TwoQbGate._Enum.CX)


@dataclass(frozen=True)
class MeasureDef(QuantumOps):
op_name: str = "Measure"
num_out: int = 2
signature: tys.FunctionType = _MeasSig
class MeasureDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Measure",
tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]),
extension=QUANTUM_EXTENSION_ID,
)

def __call__(self, q: ComWire) -> Command:
return super().__call__(q)


Measure = MeasureDef()

_RzSig = tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit])


@dataclass(frozen=True)
class RzDef(QuantumOps):
op_name: str = "Rz"
num_out: int = 1
signature: tys.FunctionType = _RzSig
class RzDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Rz",
tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]),
extension=QUANTUM_EXTENSION_ID,
)

def __call__(self, q: ComWire, fl_wire: ComWire) -> Command:
return super().__call__(q, fl_wire)
Expand Down
Loading

0 comments on commit 1db43eb

Please sign in to comment.