
[BUG] Sampling an observable with more than one wire is not compatible with JAX-JIT #5372

Closed
1 task done
lillian542 opened this issue Mar 12, 2024 · 0 comments · Fixed by #5422
Labels
bug 🐛 Something isn't working

Comments

@lillian542
Contributor

Expected behavior

I expect to be able to JIT-compile a circuit that returns `qml.sample(obs)`, where `obs` is an observable acting on more than one wire.

Actual behavior

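Executing the circuit with `jax.jit` fails with a shape mismatch: the callback expects an output of shape `(100, 2)` but receives one of shape `(100,)` (see the traceback below).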
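The `CpuCallback` error in the traceback suggests that, under `jax.jit`, the QNode result is produced by a host callback whose output shape must be declared before the callback runs. The pure-JAX sketch below assumes a `jax.pure_callback`-style mechanism; `host_sampler`, `run` and `declared_shape_ok` are hypothetical stand-ins, not PennyLane code. It shows that a callback returning `(100,)` samples succeeds only when `(100,)` is declared.

```python
import numpy as np
import jax
import jax.numpy as jnp

SHOTS = 100


def host_sampler(x):
    # Hypothetical stand-in for device execution on the host: returns one
    # +/-1 eigenvalue per shot, so the real output shape is (SHOTS,).
    rng = np.random.default_rng(0)
    values = np.array([-1.0, 1.0], dtype=np.float32)
    return rng.choice(values, size=SHOTS).astype(np.float32)


def run(x, declared_shape):
    # Under jit, the callback's output shape and dtype must be declared up front.
    spec = jax.ShapeDtypeStruct(declared_shape, jnp.float32)
    return jax.pure_callback(host_sampler, spec, x)


# declared_shape is a hashable tuple, so it can be a static argument.
jitted_run = jax.jit(run, static_argnums=1)


def declared_shape_ok(shape):
    # True if the jitted call completes with this declared output shape.
    try:
        jitted_run(jnp.float32(0.123), shape).block_until_ready()
        return True
    except Exception:
        return False


print(declared_shape_ok((SHOTS,)))    # declared shape matches the host output
print(declared_shape_ok((SHOTS, 2)))  # mismatch, analogous to this issue
```

Declaring `(SHOTS,)` is presumably the same correction the fix makes to `SampleMP.shape` for observable samples.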

Additional information

No response

Source code

```python
import jax
import pennylane as qml

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=2, shots=100)

@qml.qnode(dev, interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.sample(qml.PauliX(0) @ qml.PauliY(1))

# Raises XlaRuntimeError (shape mismatch) when run under jax.jit
results = jax.jit(circuit)(jax.numpy.array(0.123, dtype=jax.numpy.float64))
```

Tracebacks

```
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[131], line 12
      9     qml.RX(x, wires=0)
     10     return qml.sample(qml.PauliX(0) @ qml.PauliY(1))
---> 12 results = jax.jit(circuit)(jax.numpy.array(0.123, dtype=jax.numpy.float64))

    [... skipping hidden 10 frame]

File /opt/homebrew/Caskroom/miniforge/base/envs/pennylane/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1229, in ExecuteReplicated.__call__(self, *args)
   1224   self._handle_token_bufs(
   1225       results.disassemble_prefix_into_single_device_arrays(
   1226           len(self.ordered_effects)),
   1227       results.consume_token())
   1228 else:
-> 1229   results = self.xla_executable.execute_sharded(input_bufs)
   1230 if dispatch.needs_check_special():
   1231   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: Incorrect output shape for return value 0: Expected: (100, 2), Actual: (100,)
```
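For context, a tensor-product observable such as `qml.PauliX(0) @ qml.PauliY(1)` yields one eigenvalue per shot, whereas raw basis-state samples carry one bit per wire. The NumPy sketch below is illustrative only: random bits stand in for device samples, and it assumes the state has already been rotated into the observable's eigenbasis. It reproduces the two shapes in the error message.

```python
import numpy as np

shots, num_wires = 100, 2
rng = np.random.default_rng(seed=42)

# Stand-in for raw basis-state samples after rotating into the eigenbasis
# of X(0) @ Y(1): one bit per wire per shot, shape (shots, num_wires).
bits = rng.integers(0, 2, size=(shots, num_wires))

# Each Pauli factor maps bit b to eigenvalue 1 - 2b (0 -> +1, 1 -> -1);
# the tensor product's eigenvalue is the product across wires.
eigvals = np.prod(1 - 2 * bits, axis=1)

print(bits.shape)     # (100, 2)
print(eigvals.shape)  # (100,)
```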

System information

pl-dev

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
lillian542 added the bug 🐛 Something isn't working label on Mar 12, 2024
albi3ro added a commit that referenced this issue Mar 21, 2024
**Context:**

Jitting was failing with `qml.sample` with an observable on more than
one wire.

```python
import jax
import pennylane as qml

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=2, shots=100)

@qml.qnode(dev, interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.sample(qml.PauliX(0) @ qml.PauliY(1))

results = jax.jit(circuit)(jax.numpy.array(0.123, dtype=jax.numpy.float64))
```

**Description of the Change:**

Update `SampleMP.shape` so that, when an observable is given, the output
has a single dimension (one eigenvalue per shot) instead of one column per wire.
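The shape rule described above can be sketched as a small standalone function. `sample_shape` is hypothetical and only mirrors the behaviour stated in this change (a single dimension when an observable is given) for a plain integer shot count; it is not the actual `SampleMP.shape` signature.

```python
from typing import Tuple


def sample_shape(shots: int, num_wires: int, has_obs: bool) -> Tuple[int, ...]:
    """Hypothetical shape rule for qml.sample output with one shot count."""
    if has_obs:
        # Observable samples: one eigenvalue per shot, whatever the wire count.
        return (shots,)
    # Raw samples: one bit per wire per shot.
    return (shots, num_wires)


print(sample_shape(100, 2, has_obs=True))   # (100,)
print(sample_shape(100, 2, has_obs=False))  # (100, 2)
```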

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

Fixes #5372 [sc-58779]

---------

Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>