Skip to content

Commit

Permalink
Validate LinearCombination and qml.dot inputs (#5618)
Browse files Browse the repository at this point in the history
**Context:**

Now, we technically allow something like `LinearCombination([1,2],
qml.X(0)@qml.Y(1))`, but that's probabably doesn't do what the user was
actually expecting, and this syntax could lead to some very difficult to
diagnose bugs.

**Description of the Change:**

Explicitly forbid the above syntax.

**Benefits:**

Preventing people from doing weird things accidentally.

**Possible Drawbacks:**

More validation. Maybe someone actually does want the strange behaviour.

**Related GitHub Issues:**
 
[sc-62162]
  • Loading branch information
albi3ro authored May 2, 2024
1 parent 5808548 commit 8fb99cb
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pennylane/ops/functions/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def dot(
Args:
coeffs (Sequence[float, Callable]): sequence containing the coefficients of the linear combination
ops (Sequence[Operator]): sequence containing the operators of the linear combination
ops (Sequence[Operator, PauliWord, PauliSentence]): sequence containing the operators of the linear combination.
Can also be ``PauliWord`` or ``PauliSentence`` instances.
pauli (bool, optional): If ``True``, a :class:`~.PauliSentence`
operator is used to represent the linear combination. If False, a :class:`Sum` operator
is returned. Defaults to ``False``. Note that when ``ops`` consists solely of ``PauliWord``
Expand Down Expand Up @@ -136,6 +137,12 @@ def dot(
if len(coeffs) == 0 and len(ops) == 0:
raise ValueError("Cannot compute the dot product of an empty sequence.")

for t in (Operator, PauliWord, PauliSentence):
if isinstance(ops, t):
raise ValueError(
f"ops must be an Iterable of {t.__name__}'s, not a {t.__name__} itself."
)

if any(callable(c) for c in coeffs):
return ParametrizedHamiltonian(coeffs, ops)

Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/op_math/linear_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ def __init__(
_pauli_rep=None,
id=None,
):
if isinstance(observables, Operator):
raise ValueError(
"observables must be an Iterable of Operator's, and not an Operator itself."
)
if qml.math.shape(coeffs)[0] != len(observables):
raise ValueError(
"Could not create valid LinearCombination; "
Expand Down
19 changes: 19 additions & 0 deletions tests/ops/functions/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@
class TestDotSum:
"""Unittests for the dot function when ``pauli=False``."""

def test_error_if_ops_operator(self):
"""Test that dot raises an error if ops is an operator itself."""
with pytest.raises(ValueError, match=r"ops must be an Iterable of Operator's"):
qml.dot([1, 1], qml.X(0) @ qml.Y(1))

def test_dot_returns_sum(self):
"""Test that the dot function returns a Sum operator when ``pauli=False``."""
c = [1.0, 2.0, 3.0]
Expand Down Expand Up @@ -280,6 +285,20 @@ def test_identities_with_pauli_sentences_pauli_false(self):
class TestDotPauliSentence:
"""Unittest for the dot function when ``pauli=True``"""

def test_error_if_ops_PauliWord(self):
"""Test that dot raises an error if ops is a PauliWord itself."""
_pw = qml.pauli.PauliWord({0: "X", 1: "Y"})
with pytest.raises(ValueError, match=r"ops must be an Iterable of PauliWord's"):
qml.dot([1, 2], _pw)

def test_error_if_ops_PauliSentence(self):
"""Test that dot raises an error if ops is a PauliSentence itself."""
_pw1 = qml.pauli.PauliWord({0: "X", 1: "Y"})
_pw2 = qml.pauli.PauliWord({2: "Z"})
ps = 2 * _pw1 + 3 * _pw2
with pytest.raises(ValueError, match=r"ops must be an Iterable of PauliSentence's"):
qml.dot([1, 2], ps)

def test_dot_returns_pauli_sentence(self):
"""Test that the dot function returns a PauliSentence class."""
ps = qml.dot(coeffs0, ops0, pauli=True)
Expand Down
6 changes: 6 additions & 0 deletions tests/ops/op_math/test_linear_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,12 @@ def circuit2(param):
class TestLinearCombination:
"""Test the LinearCombination class"""

def test_error_if_observables_operator(self):
"""Test thatt an error is raised if an operator is provided to observables."""

with pytest.raises(ValueError, match=r"observables must be an Iterable of Operator's"):
qml.ops.LinearCombination([1, 1], qml.X(0) @ qml.Y(1))

PAULI_REPS = (
([], [], PauliSentence({})),
(
Expand Down

0 comments on commit 8fb99cb

Please sign in to comment.