Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Add functions to quantum module and make quantum_functional independent #494

Merged
merged 5 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@
"outputs": [],
"source": [
"from guppylang.prelude.builtins import owned\n",
"from guppylang.prelude.quantum import qubit, measure\n",
"from guppylang.prelude.quantum_functional import h, cx\n",
"from guppylang.prelude.quantum import qubit, measure, h, cx\n",
"\n",
"module.load(qubit, h, cx, measure)\n",
"\n",
Expand All @@ -101,8 +100,9 @@
" # Allocate two fresh qubits\n",
" q1, q2 = qubit(), qubit()\n",
" # Entangle\n",
" q1, q2 = cx(h(q1), q2)\n",
" q2 = h(q2)\n",
" h(q1)\n",
" cx(q1, q2)\n",
" h(q2)\n",
" # Measure\n",
" b1, b2 = measure(q1), measure(q2)\n",
" return b1 ^ b2"
Expand Down Expand Up @@ -298,10 +298,10 @@
"Guppy compilation failed. Error in file <In [9]>:6\n",
"\n",
"4: @guppy(bad_module)\n",
"5: def bad(q: qubit @owned) -> tuple[qubit, qubit]:\n",
"6: return cx(q, q)\n",
" ^\n",
"GuppyError: Variable `q` with linear type `qubit` was already used (at 6:14)\n"
"5: def bad(q: qubit @owned) -> qubit:\n",
"6: cx(q, q)\n",
" ^\n",
"GuppyError: Variable `q` with linear type `qubit` was already used (at 6:7)\n"
]
}
],
Expand All @@ -310,8 +310,9 @@
"bad_module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit @owned) -> tuple[qubit, qubit]:\n",
" return cx(q, q)\n",
"def bad(q: qubit @owned) -> qubit:\n",
" cx(q, q)\n",
" return q\n",
"\n",
"bad_module.compile() # Raises an error"
]
Expand Down Expand Up @@ -341,9 +342,9 @@
"Guppy compilation failed. Error in file <In [10]>:7\n",
"\n",
"5: def bad(q: qubit @owned) -> qubit:\n",
"6: tmp = h(qubit())\n",
"7: tmp, q = cx(tmp, q)\n",
" ^^^\n",
"6: tmp = qubit()\n",
"7: cx(tmp, q)\n",
" ^^^\n",
"GuppyError: Variable `tmp` with linear type `qubit` is not used on all control-flow paths\n"
]
}
Expand All @@ -354,8 +355,8 @@
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit @owned) -> qubit:\n",
" tmp = h(qubit())\n",
" tmp, q = cx(tmp, q)\n",
" tmp = qubit()\n",
" cx(tmp, q)\n",
" #discard(tmp) # Compiler complains if tmp is not explicitly discarded\n",
" return q\n",
"\n",
Expand Down Expand Up @@ -422,15 +423,14 @@
" q2: qubit\n",
"\n",
" @guppy(module)\n",
" def method(self: \"QubitPair @owned\") -> \"QubitPair\":\n",
" self.q1 = h(self.q1)\n",
" self.q1, self.q2 = cx(self.q1, self.q2)\n",
" return self\n",
" def method(self: \"QubitPair\") -> None:\n",
" h(self.q1)\n",
" cx(self.q1, self.q2)\n",
"\n",
"@guppy(module)\n",
"def make_struct() -> QubitPair:\n",
" pair = QubitPair(qubit(), qubit())\n",
" # pair = pair.method() # TODO: Calling methods doesn't work yet\n",
" pair.method()\n",
" return pair\n",
"\n",
"program = module.compile()"
Expand Down Expand Up @@ -559,7 +559,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
ast.BitAnd: ("__and__", "__rand__", "&"),
ast.MatMult: ("__matmul__", "__rmatmul__", "@"),
ast.Eq: ("__eq__", "__eq__", "=="),
ast.NotEq: ("__neq__", "__neq__", "!="),
ast.NotEq: ("__ne__", "__ne__", "!="),
ast.Lt: ("__lt__", "__gt__", "<"),
ast.LtE: ("__le__", "__ge__", "<="),
ast.Gt: ("__gt__", "__lt__", ">"),
Expand Down
4 changes: 2 additions & 2 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from guppylang.definition.const import RawConstDef
from guppylang.definition.custom import (
CustomCallChecker,
CustomCallCompiler,
CustomInoutCallCompiler,
DefaultCallChecker,
NotImplementedCallCompiler,
OpCompiler,
Expand Down Expand Up @@ -197,7 +197,7 @@ def nat_var(self, module: GuppyModule, name: str) -> ConstVarDef:
def custom(
self,
module: GuppyModule,
compiler: CustomCallCompiler | None = None,
compiler: CustomInoutCallCompiler | None = None,
checker: CustomCallChecker | None = None,
higher_order_value: bool = True,
name: str = "",
Expand Down
43 changes: 30 additions & 13 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType, Type
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
Type,
type_to_row,
)


@dataclass(frozen=True)
Expand All @@ -41,7 +48,7 @@ class RawCustomFunctionDef(ParsableDef):

defined_at: ast.FunctionDef
call_checker: "CustomCallChecker"
call_compiler: "CustomCallCompiler"
call_compiler: "CustomInoutCallCompiler"

# Whether the function may be used as a higher-order value. This is only possible
# if a static type for the function is provided.
Expand Down Expand Up @@ -85,15 +92,16 @@ def compile_call(
) -> Sequence[Wire]:
"""Compiles a call to the function."""
# Note: We have _compiled_ globals rather than `Globals` here,
# so we cannot use `self._get_signature()`.
# so we cannot use `self._get_signature()` to build a `CustomFunctionDef`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having tweaked this doc slightly, I still have to say it'd be good to have some doc as to why we might ever do this without parsing it into a CustomFunctionDef first...

self.call_compiler._setup(
type_args,
dfg,
globals,
node,
function_ty,
None,
)
return self.call_compiler.compile(args)
return self.call_compiler.compile_with_inouts(args).regular_returns

def _get_signature(self, globals: Globals) -> FunctionType | None:
"""Returns the type of the function, if known.
Expand Down Expand Up @@ -237,7 +245,7 @@ def compile_call(
)
hugr_ty = concrete_ty.to_hugr()

self.call_compiler._setup(type_args, dfg, globals, node, hugr_ty)
self.call_compiler._setup(type_args, dfg, globals, node, hugr_ty, self)
return self.call_compiler.compile_with_inouts(args)


Expand Down Expand Up @@ -284,6 +292,7 @@ class CustomInoutCallCompiler(ABC):
globals: CompiledGlobals
node: AstNode
ty: ht.FunctionType
func: CustomFunctionDef | None

def _setup(
self,
Expand All @@ -292,12 +301,14 @@ def _setup(
globals: CompiledGlobals,
node: AstNode,
hugr_ty: ht.FunctionType,
func: CustomFunctionDef | None,
) -> None:
self.type_args = type_args
self.dfg = dfg
self.globals = globals
self.node = node
self.ty = hugr_ty
self.func = func

@abstractmethod
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
Expand All @@ -307,6 +318,11 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
passed through the function.
"""

@property
def builder(self) -> DfBase[ops.DfParentOp]:
"""The hugr dataflow builder."""
return self.dfg.builder


class CustomCallCompiler(CustomInoutCallCompiler, ABC):
"""Abstract base class for custom function call compilers with only owned args."""
Expand All @@ -318,11 +334,6 @@ def compile(self, args: list[Wire]) -> list[Wire]:
Use the provided `self.builder` to add nodes to the Hugr graph.
"""

@property
def builder(self) -> DfBase[ops.DfParentOp]:
"""The hugr dataflow builder."""
return self.dfg.builder

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
return CallReturnWires(self.compile(args), inout_returns=[])

Expand Down Expand Up @@ -353,7 +364,7 @@ def compile(self, args: list[Wire]) -> list[Wire]:
raise InternalGuppyError("Function should have been removed during checking")


class OpCompiler(CustomCallCompiler):
class OpCompiler(CustomInoutCallCompiler):
"""Call compiler for functions that are directly implemented via Hugr ops.

args:
Expand All @@ -366,10 +377,16 @@ class OpCompiler(CustomCallCompiler):
def __init__(self, op: Callable[[ht.FunctionType, Inst], ops.DataflowOp]) -> None:
self.op = op

def compile(self, args: list[Wire]) -> list[Wire]:
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
op = self.op(self.ty, self.type_args)
node = self.builder.add_op(op, *args)
return list(node)
num_returns = (
len(type_to_row(self.func.ty.output)) if self.func else len(self.ty.output)
)
return CallReturnWires(
regular_returns=list(node[:num_returns]),
inout_returns=list(node[num_returns:]),
)


class NoopCompiler(CustomCallCompiler):
Expand Down
28 changes: 6 additions & 22 deletions guppylang/prelude/_internal/compiler/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from hugr import Wire
from hugr import tys as ht

from guppylang.definition.custom import CustomCallCompiler
from guppylang.definition.custom import CustomInoutCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.prelude._internal.json_defs import load_extension

# ----------------------------------------------
Expand Down Expand Up @@ -35,32 +36,15 @@
# ------------------------------------------------------


class MeasureCompiler(CustomCallCompiler):
"""Compiler for the `measure` function."""
class MeasureReturnCompiler(CustomInoutCallCompiler):
"""Compiler for the `measure_return` function."""

def compile(self, args: list[Wire]) -> list[Wire]:
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
from guppylang.prelude._internal.util import quantum_op

[q] = args
[q, bit] = self.builder.add_op(
quantum_op("Measure")(ht.FunctionType([ht.Qubit], [ht.Qubit, ht.Bool]), []),
q,
)
self.builder.add_op(quantum_op("QFree")(ht.FunctionType([ht.Qubit], []), []), q)
return [bit]


class QAllocCompiler(CustomCallCompiler):
"""Compiler for the `qubit` function."""

def compile(self, args: list[Wire]) -> list[Wire]:
from guppylang.prelude._internal.util import quantum_op

assert not args, "qubit() does not take any arguments"
q = self.builder.add_op(
quantum_op("QAlloc")(ht.FunctionType([], [ht.Qubit]), [])
)
q = self.builder.add_op(
quantum_op("Reset")(ht.FunctionType([ht.Qubit], [ht.Qubit]), []), q
)
return [q]
return CallReturnWires(regular_returns=[bit], inout_returns=[q])
Loading
Loading