Skip to content

Commit

Permalink
feat!: Add functions to quantum module and make quantum_functional in…
Browse files Browse the repository at this point in the history
…dependent (#494)

Closes #378 and closes #465.

BREAKING CHANGE: `quantum_functional` is now its own Guppy module and no
longer implicitly comes with `quantum`.

---------

Co-authored-by: Alan Lawrence <alan.lawrence@cambridgequantum.com>
  • Loading branch information
mark-koch and acl-cqc authored Sep 16, 2024
1 parent 6fdb5d6 commit 0b0b1af
Show file tree
Hide file tree
Showing 18 changed files with 361 additions and 188 deletions.
42 changes: 21 additions & 21 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@
"outputs": [],
"source": [
"from guppylang.prelude.builtins import owned\n",
"from guppylang.prelude.quantum import qubit, measure\n",
"from guppylang.prelude.quantum_functional import h, cx\n",
"from guppylang.prelude.quantum import qubit, measure, h, cx\n",
"\n",
"module.load(qubit, h, cx, measure)\n",
"\n",
Expand All @@ -101,8 +100,9 @@
" # Allocate two fresh qubits\n",
" q1, q2 = qubit(), qubit()\n",
" # Entangle\n",
" q1, q2 = cx(h(q1), q2)\n",
" q2 = h(q2)\n",
" h(q1)\n",
" cx(q1, q2)\n",
" h(q2)\n",
" # Measure\n",
" b1, b2 = measure(q1), measure(q2)\n",
" return b1 ^ b2"
Expand Down Expand Up @@ -298,10 +298,10 @@
"Guppy compilation failed. Error in file <In [9]>:6\n",
"\n",
"4: @guppy(bad_module)\n",
"5: def bad(q: qubit @owned) -> tuple[qubit, qubit]:\n",
"6: return cx(q, q)\n",
" ^\n",
"GuppyError: Variable `q` with linear type `qubit` was already used (at 6:14)\n"
"5: def bad(q: qubit @owned) -> qubit:\n",
"6: cx(q, q)\n",
" ^\n",
"GuppyError: Variable `q` with linear type `qubit` was already used (at 6:7)\n"
]
}
],
Expand All @@ -310,8 +310,9 @@
"bad_module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit @owned) -> tuple[qubit, qubit]:\n",
" return cx(q, q)\n",
"def bad(q: qubit @owned) -> qubit:\n",
" cx(q, q)\n",
" return q\n",
"\n",
"bad_module.compile() # Raises an error"
]
Expand Down Expand Up @@ -341,9 +342,9 @@
"Guppy compilation failed. Error in file <In [10]>:7\n",
"\n",
"5: def bad(q: qubit @owned) -> qubit:\n",
"6: tmp = h(qubit())\n",
"7: tmp, q = cx(tmp, q)\n",
" ^^^\n",
"6: tmp = qubit()\n",
"7: cx(tmp, q)\n",
" ^^^\n",
"GuppyError: Variable `tmp` with linear type `qubit` is not used on all control-flow paths\n"
]
}
Expand All @@ -354,8 +355,8 @@
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit @owned) -> qubit:\n",
" tmp = h(qubit())\n",
" tmp, q = cx(tmp, q)\n",
" tmp = qubit()\n",
" cx(tmp, q)\n",
" #discard(tmp) # Compiler complains if tmp is not explicitly discarded\n",
" return q\n",
"\n",
Expand Down Expand Up @@ -422,15 +423,14 @@
" q2: qubit\n",
"\n",
" @guppy(module)\n",
" def method(self: \"QubitPair @owned\") -> \"QubitPair\":\n",
" self.q1 = h(self.q1)\n",
" self.q1, self.q2 = cx(self.q1, self.q2)\n",
" return self\n",
" def method(self: \"QubitPair\") -> None:\n",
" h(self.q1)\n",
" cx(self.q1, self.q2)\n",
"\n",
"@guppy(module)\n",
"def make_struct() -> QubitPair:\n",
" pair = QubitPair(qubit(), qubit())\n",
" # pair = pair.method() # TODO: Calling methods doesn't work yet\n",
" pair.method()\n",
" return pair\n",
"\n",
"program = module.compile()"
Expand Down Expand Up @@ -559,7 +559,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
ast.BitAnd: ("__and__", "__rand__", "&"),
ast.MatMult: ("__matmul__", "__rmatmul__", "@"),
ast.Eq: ("__eq__", "__eq__", "=="),
ast.NotEq: ("__neq__", "__neq__", "!="),
ast.NotEq: ("__ne__", "__ne__", "!="),
ast.Lt: ("__lt__", "__gt__", "<"),
ast.LtE: ("__le__", "__ge__", "<="),
ast.Gt: ("__gt__", "__lt__", ">"),
Expand Down
4 changes: 2 additions & 2 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from guppylang.definition.const import RawConstDef
from guppylang.definition.custom import (
CustomCallChecker,
CustomCallCompiler,
CustomInoutCallCompiler,
DefaultCallChecker,
NotImplementedCallCompiler,
OpCompiler,
Expand Down Expand Up @@ -197,7 +197,7 @@ def nat_var(self, module: GuppyModule, name: str) -> ConstVarDef:
def custom(
self,
module: GuppyModule,
compiler: CustomCallCompiler | None = None,
compiler: CustomInoutCallCompiler | None = None,
checker: CustomCallChecker | None = None,
higher_order_value: bool = True,
name: str = "",
Expand Down
43 changes: 30 additions & 13 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType, Type
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
Type,
type_to_row,
)


@dataclass(frozen=True)
Expand All @@ -41,7 +48,7 @@ class RawCustomFunctionDef(ParsableDef):

defined_at: ast.FunctionDef
call_checker: "CustomCallChecker"
call_compiler: "CustomCallCompiler"
call_compiler: "CustomInoutCallCompiler"

# Whether the function may be used as a higher-order value. This is only possible
# if a static type for the function is provided.
Expand Down Expand Up @@ -85,15 +92,16 @@ def compile_call(
) -> Sequence[Wire]:
"""Compiles a call to the function."""
# Note: We have _compiled_ globals rather than `Globals` here,
# so we cannot use `self._get_signature()`.
# so we cannot use `self._get_signature()` to build a `CustomFunctionDef`.
self.call_compiler._setup(
type_args,
dfg,
globals,
node,
function_ty,
None,
)
return self.call_compiler.compile(args)
return self.call_compiler.compile_with_inouts(args).regular_returns

def _get_signature(self, globals: Globals) -> FunctionType | None:
"""Returns the type of the function, if known.
Expand Down Expand Up @@ -237,7 +245,7 @@ def compile_call(
)
hugr_ty = concrete_ty.to_hugr()

self.call_compiler._setup(type_args, dfg, globals, node, hugr_ty)
self.call_compiler._setup(type_args, dfg, globals, node, hugr_ty, self)
return self.call_compiler.compile_with_inouts(args)


Expand Down Expand Up @@ -284,6 +292,7 @@ class CustomInoutCallCompiler(ABC):
globals: CompiledGlobals
node: AstNode
ty: ht.FunctionType
func: CustomFunctionDef | None

def _setup(
self,
Expand All @@ -292,12 +301,14 @@ def _setup(
globals: CompiledGlobals,
node: AstNode,
hugr_ty: ht.FunctionType,
func: CustomFunctionDef | None,
) -> None:
self.type_args = type_args
self.dfg = dfg
self.globals = globals
self.node = node
self.ty = hugr_ty
self.func = func

@abstractmethod
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
Expand All @@ -307,6 +318,11 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
passed through the function.
"""

@property
def builder(self) -> DfBase[ops.DfParentOp]:
"""The hugr dataflow builder."""
return self.dfg.builder


class CustomCallCompiler(CustomInoutCallCompiler, ABC):
"""Abstract base class for custom function call compilers with only owned args."""
Expand All @@ -318,11 +334,6 @@ def compile(self, args: list[Wire]) -> list[Wire]:
Use the provided `self.builder` to add nodes to the Hugr graph.
"""

@property
def builder(self) -> DfBase[ops.DfParentOp]:
"""The hugr dataflow builder."""
return self.dfg.builder

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
return CallReturnWires(self.compile(args), inout_returns=[])

Expand Down Expand Up @@ -353,7 +364,7 @@ def compile(self, args: list[Wire]) -> list[Wire]:
raise InternalGuppyError("Function should have been removed during checking")


class OpCompiler(CustomCallCompiler):
class OpCompiler(CustomInoutCallCompiler):
"""Call compiler for functions that are directly implemented via Hugr ops.
args:
Expand All @@ -366,10 +377,16 @@ class OpCompiler(CustomCallCompiler):
def __init__(self, op: Callable[[ht.FunctionType, Inst], ops.DataflowOp]) -> None:
self.op = op

def compile(self, args: list[Wire]) -> list[Wire]:
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
op = self.op(self.ty, self.type_args)
node = self.builder.add_op(op, *args)
return list(node)
num_returns = (
len(type_to_row(self.func.ty.output)) if self.func else len(self.ty.output)
)
return CallReturnWires(
regular_returns=list(node[:num_returns]),
inout_returns=list(node[num_returns:]),
)


class NoopCompiler(CustomCallCompiler):
Expand Down
28 changes: 6 additions & 22 deletions guppylang/prelude/_internal/compiler/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from hugr import Wire
from hugr import tys as ht

from guppylang.definition.custom import CustomCallCompiler
from guppylang.definition.custom import CustomInoutCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.prelude._internal.json_defs import load_extension

# ----------------------------------------------
Expand Down Expand Up @@ -35,32 +36,15 @@
# ------------------------------------------------------


class MeasureCompiler(CustomCallCompiler):
"""Compiler for the `measure` function."""
class MeasureReturnCompiler(CustomInoutCallCompiler):
"""Compiler for the `measure_return` function."""

def compile(self, args: list[Wire]) -> list[Wire]:
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
from guppylang.prelude._internal.util import quantum_op

[q] = args
[q, bit] = self.builder.add_op(
quantum_op("Measure")(ht.FunctionType([ht.Qubit], [ht.Qubit, ht.Bool]), []),
q,
)
self.builder.add_op(quantum_op("QFree")(ht.FunctionType([ht.Qubit], []), []), q)
return [bit]


class QAllocCompiler(CustomCallCompiler):
"""Compiler for the `qubit` function."""

def compile(self, args: list[Wire]) -> list[Wire]:
from guppylang.prelude._internal.util import quantum_op

assert not args, "qubit() does not take any arguments"
q = self.builder.add_op(
quantum_op("QAlloc")(ht.FunctionType([], [ht.Qubit]), [])
)
q = self.builder.add_op(
quantum_op("Reset")(ht.FunctionType([ht.Qubit], [ht.Qubit]), []), q
)
return [q]
return CallReturnWires(regular_returns=[bit], inout_returns=[q])
Loading

0 comments on commit 0b0b1af

Please sign in to comment.