Skip to content

Commit

Permalink
Fixing kron compatibility numpy @ torch (#5540)
Browse files Browse the repository at this point in the history
Fixes #5542

**Relevant Shortcut Stories:**
[sc-61616]

---------

Co-authored-by: Jay Soni <jbsoni@uwaterloo.ca>
  • Loading branch information
KetpuntoG and Jaybsoni authored Apr 19, 2024
1 parent bb672a3 commit 08d286f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,9 @@
* Fixes a bug in `hamiltonian_expand` that produces incorrect output dimensions when shot vectors are combined with parameter broadcasting.
[(#5494)](https://github.com/PennyLaneAI/pennylane/pull/5494)

* Fixes a bug in `qml.math.kron` that makes torch incompatible with numpy.
[(#5540)](https://github.com/PennyLaneAI/pennylane/pull/5540)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
8 changes: 8 additions & 0 deletions pennylane/math/multi_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ def kron(*args, like=None, **kwargs):
"""The kronecker/tensor product of args."""
if like == "scipy":
return onp.kron(*args, **kwargs) # Dispatch scipy kron to numpy backed specifically.

if like == "torch":
mats = [
ar.numpy.asarray(arg, like="torch") if isinstance(arg, onp.ndarray) else arg
for arg in args
]
return ar.numpy.kron(*mats)

return ar.numpy.kron(*args, like=like, **kwargs)


Expand Down
11 changes: 11 additions & 0 deletions tests/math/test_multi_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,17 @@ def test_dot_autograd():
assert fn.allclose(qml_grad(fn.dot)(x, y), x)


def test_kron():
"""Test the kronecker product function."""
x = torch.tensor([[1, 2], [3, 4]])
y = np.array([[0, 5], [6, 7]])

res = fn.kron(x, y)
expected = torch.tensor([[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]])

assert fn.allclose(res, expected)


class TestMatmul:
@pytest.mark.torch
def test_matmul_torch(self):
Expand Down

0 comments on commit 08d286f

Please sign in to comment.