Create expand_plxpr_transforms function for unwrapping transforms natively in plxpr #6722

Draft: mudit2812 wants to merge 1 commit into master

@mudit2812 mudit2812 commented Dec 13, 2024

[sc-80562]


Context:

Description of the Change:

Benefits:

Possible Drawbacks:

Related GitHub Issues:

@mudit2812
Contributor Author

Example tested locally:

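import jax
import pennylane as qml
from pennylane.capture.transforms import expand_plxpr_transforms

# Enable PennyLane's program capture so that f can be traced into plxpr
qml.capture.enable()

gate_set = [qml.RX, qml.RY, qml.RZ, qml.CNOT]
wire_map = {i: i + 2 for i in range(5)}  # {0: 2, 1: 3, 2: 4, 3: 5, 4: 6}


def f():
    # Applied outside any transform, so the expansion leaves it unchanged
    qml.Rot(1, 2, 3, 4)

    def g():
        qml.Rot(1, 2, 3, 4)
        qml.Toffoli([1, 2, 3])
        qml.U3(1, 2, 3, 4)
        return qml.expval(qml.PauliZ(0))

    # Each transformed call is captured as one transform equation wrapping g
    m1 = qml.transforms.decompose(g, gate_set=gate_set)()
    m2 = qml.map_wires(g, wire_map)()
    return m1, m2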

>>> jax.make_jaxpr(f)()
{ lambda ; . let
    _:AbstractOperator() = Rot[n_wires=1] 1 2 3 4
    a:AbstractMeasurement(n_wires=None) = decompose_transform[
      args_slice=slice(0, 0, None)
      consts_slice=slice(0, 0, None)
      inner_jaxpr={ lambda ; . let
          _:AbstractOperator() = Rot[n_wires=1] 1 2 3 4
          _:AbstractOperator() = Toffoli[n_wires=3] 1 2 3
          _:AbstractOperator() = U3[n_wires=1] 1 2 3 4
          b:AbstractOperator() = PauliZ[n_wires=1] 0
          c:AbstractMeasurement(n_wires=None) = expval_obs b
        in (c,) }
      targs_slice=slice(0, None, None)
      tkwargs={'gate_set': [<class 'pennylane.ops.qubit.parametric_ops_single_qubit.RX'>, <class 'pennylane.ops.qubit.parametric_ops_single_qubit.RY'>, <class 'pennylane.ops.qubit.parametric_ops_single_qubit.RZ'>, <class 'pennylane.ops.op_math.controlled_ops.CNOT'>]}
    ] 
    d:AbstractMeasurement(n_wires=None) = _map_wires_transform_transform[
      args_slice=slice(0, 0, None)
      consts_slice=slice(0, 0, None)
      inner_jaxpr={ lambda ; . let
          _:AbstractOperator() = Rot[n_wires=1] 1 2 3 4
          _:AbstractOperator() = Toffoli[n_wires=3] 1 2 3
          _:AbstractOperator() = U3[n_wires=1] 1 2 3 4
          e:AbstractOperator() = PauliZ[n_wires=1] 0
          f:AbstractMeasurement(n_wires=None) = expval_obs e
        in (f,) }
      targs_slice=slice(0, None, None)
      tkwargs={'wire_map': {0: 2, 1: 3, 2: 4, 3: 5, 4: 6}, 'queue': False}
    ] 
  in (a, d) }
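
In the jaxpr above, each transform shows up as a single equation that carries the untransformed program in its inner_jaxpr. As a rough mental model of what expanding such an equation could mean, here is a toy, pure-Python sketch. It is not PennyLane's implementation: `Op`, `TransformEqn`, `expand`, and `decompose_rot` are hypothetical names introduced only for illustration, and the only decomposition rule it knows is Rot into RZ, RY, RZ.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Op:
    """Toy stand-in for a captured operator equation (hypothetical)."""
    name: str
    params: tuple
    wires: tuple


@dataclass(frozen=True)
class TransformEqn:
    """Toy stand-in for an equation like decompose_transform[inner_jaxpr=...]."""
    rule: Callable  # maps a list of Op to a list of Op
    inner: tuple    # wrapped program; may itself contain TransformEqn entries


def expand(program):
    """Replace every TransformEqn with the result of applying its rule
    to its (recursively expanded) inner program."""
    out = []
    for eqn in program:
        if isinstance(eqn, TransformEqn):
            out.extend(eqn.rule(expand(eqn.inner)))
        else:
            out.append(eqn)
    return out


def decompose_rot(ops):
    """Toy rule: rewrite Rot(phi, theta, omega) as RZ(phi), RY(theta), RZ(omega)."""
    out = []
    for op in ops:
        if op.name == "Rot":
            phi, theta, omega = op.params
            out += [
                Op("RZ", (phi,), op.wires),
                Op("RY", (theta,), op.wires),
                Op("RZ", (omega,), op.wires),
            ]
        else:
            out.append(op)
    return out


rot = Op("Rot", (1, 2, 3), (4,))
program = (rot, TransformEqn(decompose_rot, (rot,)))
expanded = expand(program)
for op in expanded:
    print(op.name, *op.params, "wires:", *op.wires)
```

The first Rot sits outside the transform equation and is passed through unchanged, while the wrapped Rot is replaced by three rotations on wire 4. The recursion in `expand` is an assumption about how nested transforms might be handled, not something this example exercises.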

>>> expanded_f = expand_plxpr_transforms(f)
>>> jax.make_jaxpr(expanded_f)()
{ lambda ; . let
    _:AbstractOperator() = Rot[n_wires=1] 1 2 3 4
    _:AbstractOperator() = RZ[n_wires=1] 1 4
    _:AbstractOperator() = RY[n_wires=1] 2 4
    _:AbstractOperator() = RZ[n_wires=1] 3 4
    _:AbstractOperator() = RZ[n_wires=1] 1.5707963267948966 3
    _:AbstractOperator() = RX[n_wires=1] 1.5707963267948966 3
    _:AbstractOperator() = RZ[n_wires=1] 1.5707963267948966 3
    _:AbstractOperator() = CNOT[n_wires=2] 2 3
    _:AbstractOperator() = RZ[n_wires=1] -0.7853981633974483 3
    _:AbstractOperator() = CNOT[n_wires=2] 1 3
    _:AbstractOperator() = RZ[n_wires=1] 0.7853981633974483 3
    _:AbstractOperator() = CNOT[n_wires=2] 2 3
    _:AbstractOperator() = RZ[n_wires=1] -0.7853981633974483 3
    _:AbstractOperator() = CNOT[n_wires=2] 1 3
    _:AbstractOperator() = RZ[n_wires=1] 0.7853981633974483 3
    _:AbstractOperator() = RZ[n_wires=1] 0.7853981633974483 2
    _:AbstractOperator() = CNOT[n_wires=2] 1 2
    _:AbstractOperator() = RZ[n_wires=1] 1.5707963267948966 3
    _:AbstractOperator() = RX[n_wires=1] 1.5707963267948966 3
    _:AbstractOperator() = RZ[n_wires=1] 1.5707963267948966 3
    _:AbstractOperator() = RZ[n_wires=1] 0.7853981633974483 1
    _:AbstractOperator() = RZ[n_wires=1] -0.7853981633974483 2
    _:AbstractOperator() = CNOT[n_wires=2] 1 2
    _:AbstractOperator() = RZ[n_wires=1] 3 4
    _:AbstractOperator() = RY[n_wires=1] 1 4
    _:AbstractOperator() = RZ[n_wires=1] -3 4
    _:AbstractOperator() = RZ[n_wires=1] 3 4
    _:AbstractOperator() = RZ[n_wires=1] 2 4
    a:AbstractOperator() = PauliZ[n_wires=1] 0
    b:AbstractMeasurement(n_wires=None) = expval_obs a
    _:AbstractOperator() = Rot[n_wires=1] 1 2 3 6
    _:AbstractOperator() = Toffoli[n_wires=3] 3 4 5
    _:AbstractOperator() = U3[n_wires=1] 1 2 3 6
    c:AbstractOperator() = PauliZ[n_wires=1] 2
    d:AbstractMeasurement(n_wires=None) = expval_obs c
  in (b, d) }
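
As a quick sanity check on the map_wires half of the expanded output, the wire relabelling can be reproduced with plain Python. This is only a sketch: `g_ops` and `remap` are hypothetical helpers, and keeping wires that are absent from the map unchanged is an assumption that this example never actually triggers.

```python
wire_map = {i: i + 2 for i in range(5)}

# (name, wires) pairs for the operations inside g, copied from the example
g_ops = [("Rot", [4]), ("Toffoli", [1, 2, 3]), ("U3", [4]), ("PauliZ", [0])]


def remap(wires, wire_map):
    """Relabel each wire via wire_map; unmapped wires are kept (assumption)."""
    return [wire_map.get(w, w) for w in wires]


mapped = [(name, remap(wires, wire_map)) for name, wires in g_ops]
for name, wires in mapped:
    print(name, wires)
```

The resulting wires (6 for Rot and U3, 3 4 5 for Toffoli, 2 for PauliZ) agree with the last five equations of the expanded jaxpr.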


Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.
