Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding batching support for symbolic operators #2672

Merged
merged 4 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
* The adjoint transform `adjoint` can now accept either a single instantiated operator or
a quantum function. It returns an entity of the same type/ call signature as what it was given:
[(#2222)](https://github.com/PennyLaneAI/pennylane/pull/2222)
[(#2672)](https://github.com/PennyLaneAI/pennylane/pull/2672)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

```pycon
>>> qml.adjoint(qml.PauliX(0))
Expand Down
12 changes: 12 additions & 0 deletions pennylane/ops/op_math/adjoint_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,18 @@ def _wires(self, new_wires):
def num_wires(self):
return self.base.num_wires

@property
def batch_size(self):
return self.base.batch_size

@property
def ndim_params(self):
return self.base.ndim_params

@property
def is_hermitian(self):
return self.base.is_hermitian

def queue(self, context=QueuingContext):
context.safe_update_info(self.base, owner=self)
context.append(self, owns=self.base)
Expand Down
8 changes: 8 additions & 0 deletions pennylane/ops/op_math/pow_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ def _wires(self, new_wires):
# we should create a better way to set new wires in the future
self.base._wires = new_wires

@property
def batch_size(self):
return self.base.batch_size

@property
def ndim_params(self):
return self.base.ndim_params

@property
def num_wires(self):
return len(self.wires)
Expand Down
27 changes: 27 additions & 0 deletions tests/ops/op_math/test_adjoint_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,33 @@ def test_private_wires(self):
op._wires = wire1
assert op._wires == base._wires == wire1

@pytest.mark.parametrize("value", (True, False))
def test_is_hermitian(self, value):
"""Test `is_hermitian` property mirrors that of the base."""

class DummyOp(qml.operation.Operator):
num_wires = 1
is_hermitian = value

op = Adjoint(DummyOp(0))
assert op.is_hermitian == value

def test_batching_properties(self):
"""Test that Adjoint batching behavior mirrors that of the base."""

class DummyOp(qml.operation.Operator):
ndim_params = (0, 2)
num_wires = 1

param1 = [0.3] * 3
param2 = [[[0.3, 1.2]]] * 3

base = DummyOp(param1, param2, wires=0)
op = Adjoint(base)

assert op.ndim_params == (0, 2)
assert op.batch_size == 3


class TestMiscMethods:
"""Test miscellaneous small methods on the Adjoint class."""
Expand Down
37 changes: 22 additions & 15 deletions tests/ops/op_math/test_pow_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,25 +224,16 @@ def test_has_matrix_false(self):

assert not op.has_matrix

def test_is_hermitian_true(self):
@pytest.mark.parametrize("value", (True, False))
def test_is_hermitian(self, value):
"""Test that if the base is hermitian, then the power is hermitian."""

class HermitianOp(qml.operation.Operator):
class DummyOp(qml.operation.Operator):
num_wires = 1
is_hermitian = True
is_hermitian = value

op = Pow(HermitianOp(1), 2.5)
assert op.is_hermitian is True

def test_is_hermitian_false(self):
"""Test that if the base is not hermitian, then the power is non-hermitian."""

class NonHermitianOp(qml.operation.Operator):
num_wires = 1
is_hermitian = False

op = Pow(NonHermitianOp(1), -2)
assert op.is_hermitian is False
op = Pow(DummyOp(1), 2.5)
assert op.is_hermitian is value

def test_queue_category(self):
"""Test that the queue category `"_ops"` carries over."""
Expand All @@ -254,6 +245,22 @@ def test_queue_category_None(self):
op = Pow(qml.PauliX(0) @ qml.PauliY(1), -1.1)
assert op._queue_category is None

def test_batching_properties(self):
"""Test that Pow batching behavior mirrors that of the base."""

class DummyOp(qml.operation.Operator):
ndim_params = (0, 2)
num_wires = 1

param1 = [0.3] * 3
param2 = [[[0.3, 1.2]]] * 3

base = DummyOp(param1, param2, wires=0)
op = Pow(base, 3)

assert op.ndim_params == (0, 2)
assert op.batch_size == 3


class TestMiscMethods:
"""Test miscellaneous minor Pow methods."""
Expand Down