Skip to content

Commit

Permalink
Qutrit mixed apply operation (#5032)
Browse files Browse the repository at this point in the history
**Context:**
Currently the qutrit_mixed device is being developed this is a necessary
addition to the overall project as this applies operations to a qutrit
mixed state. This is a prerequisite for other functionality relating to
the qutrit mixed device for noisy qutrit simulation.

**Description of the Change:**
Added functionality for applying operations to a qutrit mixed-state. The
new ``apply_operation`` function can be used to apply gates and Channels
to a qutrit mixed-state.

**Benefits:**
Allows for Channels and operations to be applied to a mixed state, will
be used to add execute functionality to qutrit mixed-state device
allowing for noisy qutrit simulation

**Possible Drawbacks:**
Abstracting for qubits and more generally qutrits may have added
challenges and will require a reforctor or copied code code-smell.

**Related GitHub Issues:**
N/A

---------

Co-authored-by: Gabe PC <bottrill@student.ubc.ca>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Christina Lee <chrissie.c.l@gmail.com>
Co-authored-by: Olivia Di Matteo <2068515+glassnotes@users.noreply.github.com>
  • Loading branch information
6 people committed Jan 29, 2024
1 parent 4da6ba6 commit 4823273
Show file tree
Hide file tree
Showing 10 changed files with 678 additions and 10 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@
* The transform `split_non_commuting` now accepts measurements of type `probs`, `sample` and `counts` which accept both wires and observables.
[(#4972)](https://github.com/PennyLaneAI/pennylane/pull/4972)

* A function called `apply_operation` has been added to the new `qutrit_mixed` module found in `qml.devices` that applies operations to device-compatible states.
[(#5032)](https://github.com/PennyLaneAI/pennylane/pull/5032)

<h3>Breaking changes 💔</h3>

* Pin Black to `v23.12` to prevent unnecessary formatting changes.
Expand Down Expand Up @@ -262,6 +265,7 @@ This release contains contributions from (in alphabetical order):

Abhishek Abhishek,
Utkarsh Azad,
Gabriel Bottrill,
Astral Cai,
Isaac De Vlugt,
Korbinian Kottmann,
Expand Down
3 changes: 3 additions & 0 deletions pennylane/devices/qutrit_mixed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
This submodule is internal and subject to change without a deprecation cycle. Use
at your own discretion.
.. currentmodule:: pennylane.devices.qutrit_mixed
.. autosummary::
:toctree: api
create_initial_state
apply_operation
"""

from .apply_operation import apply_operation
from .initialize_state import create_initial_state
184 changes: 184 additions & 0 deletions pennylane/devices/qutrit_mixed/apply_operation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to apply operations to a qutrit mixed state."""
# pylint: disable=unused-argument

from functools import singledispatch
from string import ascii_letters as alphabet
import pennylane as qml
from pennylane import math
from pennylane import numpy as np
from pennylane.operation import Channel
from .utils import QUDIT_DIM, get_einsum_mapping, get_new_state_einsum_indices

alphabet_array = np.array(list(alphabet))


def _map_indices_apply_channel(**kwargs):
"""Map indices to einsum string
Args:
**kwargs (dict): Stores indices calculated in `get_einsum_mapping`
Returns:
String of einsum indices to complete einsum calculations
"""
op_1_indices = f"{kwargs['kraus_index']}{kwargs['new_row_indices']}{kwargs['row_indices']}"
op_2_indices = f"{kwargs['kraus_index']}{kwargs['col_indices']}{kwargs['new_col_indices']}"

new_state_indices = get_new_state_einsum_indices(
old_indices=kwargs["col_indices"] + kwargs["row_indices"],
new_indices=kwargs["new_col_indices"] + kwargs["new_row_indices"],
state_indices=kwargs["state_indices"],
)
# index mapping for einsum, e.g., '...iga,...abcdef,...idh->...gbchef'
return (
f"...{op_1_indices},...{kwargs['state_indices']},...{op_2_indices}->...{new_state_indices}"
)


def apply_operation_einsum(op: qml.operation.Operator, state, is_state_batched: bool = False):
r"""Apply a quantum channel specified by a list of Kraus operators to subsystems of the
quantum state. For a unitary gate, there is a single Kraus operator.
Args:
op (Operator): Operator to apply to the quantum state
state (array[complex]): Input quantum state
is_state_batched (bool): Boolean representing whether the state is batched or not
Returns:
array[complex]: output_state
"""
einsum_indices = get_einsum_mapping(op, state, _map_indices_apply_channel, is_state_batched)

num_ch_wires = len(op.wires)

# This could be pulled into separate function if tensordot is added
if isinstance(op, Channel):
kraus = op.kraus_matrices()
else:
kraus = [op.matrix()]

# Shape kraus operators
kraus_shape = [len(kraus)] + [QUDIT_DIM] * num_ch_wires * 2
if not isinstance(op, Channel):
mat = op.matrix()
dim = QUDIT_DIM**num_ch_wires
batch_size = math.get_batch_size(mat, (dim, dim), dim**2)
if batch_size is not None:
# Add broadcasting dimension to shape
kraus_shape = [batch_size] + kraus_shape
if op.batch_size is None:
op._batch_size = batch_size # pylint:disable=protected-access

kraus = math.stack(kraus)
kraus_transpose = math.stack(math.moveaxis(kraus, source=-1, destination=-2))
# Torch throws error if math.conj is used before stack
kraus_dagger = math.conj(kraus_transpose)

kraus = math.cast(math.reshape(kraus, kraus_shape), complex)
kraus_dagger = math.cast(math.reshape(kraus_dagger, kraus_shape), complex)

return math.einsum(einsum_indices, kraus, state, kraus_dagger)


@singledispatch
def apply_operation(
op: qml.operation.Operator, state, is_state_batched: bool = False, debugger=None
):
"""Apply an operation to a given state.
Args:
op (Operator): The operation to apply to ``state``
state (TensorLike): The starting state.
is_state_batched (bool): Boolean representing whether the state is batched or not
debugger (_Debugger): The debugger to use
Returns:
ndarray: output state
.. warning::
``apply_operation`` is an internal function, and thus subject to change without a deprecation cycle.
.. warning::
``apply_operation`` applies no validation to its inputs.
This function assumes that the wires of the operator correspond to indices
of the state. See :func:`~.map_wires` to convert operations to integer wire labels.
The shape of state should be ``[QUDIT_DIM]*(num_wires * 2)``, where ``QUDIT_DIM`` is
the dimension of the system.
This is a ``functools.singledispatch`` function, so additional specialized kernels
for specific operations can be registered like:
.. code-block:: python
@apply_operation.register
def _(op: type_op, state):
# custom op application method here
**Example:**
>>> state = np.zeros((3,3))
>>> state[0][0] = 1
>>> state
tensor([[1., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], requires_grad=True)
>>> apply_operation(qml.TShift(0), state)
tensor([[0., 0., 0.],
[0., 1., 0],
[0., 0., 0.],], requires_grad=True)
"""
return _apply_operation_default(op, state, is_state_batched, debugger)


def _apply_operation_default(op, state, is_state_batched, debugger):
"""The default behaviour of apply_operation, accessed through the standard dispatch
of apply_operation, as well as conditionally in other dispatches.
"""

return apply_operation_einsum(op, state, is_state_batched=is_state_batched)
# TODO add tensordot and benchmark for performance


# TODO add diagonal for speed up.


@apply_operation.register
def apply_snapshot(op: qml.Snapshot, state, is_state_batched: bool = False, debugger=None):
"""Take a snapshot of the mixed state"""
if debugger and debugger.active:
measurement = op.hyperparameters["measurement"]
if measurement:
# TODO replace with: measure once added
raise NotImplementedError # TODO
if is_state_batched:
dim = int(math.sqrt(math.size(state[0])))
flat_shape = [math.shape(state)[0], dim, dim]
else:
dim = int(math.sqrt(math.size(state)))
flat_shape = [dim, dim]

snapshot = math.reshape(state, flat_shape)
if op.tag:
debugger.snapshots[op.tag] = snapshot
else:
debugger.snapshots[len(debugger.snapshots)] = snapshot
return state


# TODO add special case speedups
19 changes: 9 additions & 10 deletions pennylane/devices/qutrit_mixed/initialize_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from typing import Iterable, Union
import pennylane as qml
from pennylane.operation import StatePrepBase

qudit_dim = 3 # specifies qudit dimension
from .utils import QUDIT_DIM


def create_initial_state(
Expand Down Expand Up @@ -55,18 +54,18 @@ def _apply_state_vector(state, num_wires): # function is easy to abstract for q
Args:
state (array[complex]): normalized input state of length
``qudit_dim**num_wires``, where ``qudit_dim`` is the dimension of the system.
``QUDIT_DIM**num_wires``, where ``QUDIT_DIM`` is the dimension of the system.
num_wires (int): number of wires that get initialized in the state
Returns:
array[complex]: complex array of shape ``[qudit_dim] * (2 * num_wires)``
representing the density matrix of this state, where ``qudit_dim`` is
array[complex]: complex array of shape ``[QUDIT_DIM] * (2 * num_wires)``
representing the density matrix of this state, where ``QUDIT_DIM`` is
the dimension of the system.
"""

# Initialize the entire set of wires with the state
rho = qml.math.outer(state, qml.math.conj(state))
return qml.math.reshape(rho, [qudit_dim] * 2 * num_wires)
return qml.math.reshape(rho, [QUDIT_DIM] * 2 * num_wires)


def _create_basis_state(num_wires, index): # function is easy to abstract for qudit
Expand All @@ -77,10 +76,10 @@ def _create_basis_state(num_wires, index): # function is easy to abstract for q
index (int): integer representing the computational basis state.
Returns:
array[complex]: complex array of shape ``[qudit_dim] * (2 * num_wires)``
representing the density matrix of the basis state, where ``qudit_dim`` is
array[complex]: complex array of shape ``[QUDIT_DIM] * (2 * num_wires)``
representing the density matrix of the basis state, where ``QUDIT_DIM`` is
the dimension of the system.
"""
rho = qml.math.zeros((qudit_dim**num_wires, qudit_dim**num_wires))
rho = qml.math.zeros((QUDIT_DIM**num_wires, QUDIT_DIM**num_wires))
rho[index, index] = 1
return qml.math.reshape(rho, [qudit_dim] * (2 * num_wires))
return qml.math.reshape(rho, [QUDIT_DIM] * (2 * num_wires))
86 changes: 86 additions & 0 deletions pennylane/devices/qutrit_mixed/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and variables to be utilized by qutrit mixed state simulator."""
import functools
from string import ascii_letters as alphabet
import pennylane as qml
from pennylane import numpy as np

alphabet_array = np.array(list(alphabet))
QUDIT_DIM = 3 # specifies qudit dimension


def get_einsum_mapping(
op: qml.operation.Operator, state, map_indices, is_state_batched: bool = False
):
r"""Finds the indices for einsum to apply kraus operators to a mixed state
Args:
op (Operator): Operator to apply to the quantum state
state (array[complex]): Input quantum state
map_indices (function): Maps the calculated indices to an einsum indices string
is_state_batched (bool): Boolean representing whether the state is batched or not
Returns:
str: indices mapping that defines the einsum
"""
num_ch_wires = len(op.wires)
num_wires = int((len(qml.math.shape(state)) - is_state_batched) / 2)
rho_dim = 2 * num_wires

# Tensor indices of the state. For each qutrit, need an index for rows *and* columns
state_indices = alphabet[:rho_dim]

# row indices of the quantum state affected by this operation
row_wires_list = op.wires.tolist()
row_indices = "".join(alphabet_array[row_wires_list].tolist())

# column indices are shifted by the number of wires
col_wires_list = [w + num_wires for w in row_wires_list]
col_indices = "".join(alphabet_array[col_wires_list].tolist())

# indices in einsum must be replaced with new ones
new_row_indices = alphabet[rho_dim : rho_dim + num_ch_wires]
new_col_indices = alphabet[rho_dim + num_ch_wires : rho_dim + 2 * num_ch_wires]

# index for summation over Kraus operators
kraus_index = alphabet[rho_dim + 2 * num_ch_wires : rho_dim + 2 * num_ch_wires + 1]

# apply mapping function
return map_indices(
state_indices=state_indices,
kraus_index=kraus_index,
row_indices=row_indices,
new_row_indices=new_row_indices,
col_indices=col_indices,
new_col_indices=new_col_indices,
)


def get_new_state_einsum_indices(old_indices, new_indices, state_indices):
"""Retrieves the einsum indices string for the new state
Args:
old_indices (str): indices that are summed
new_indices (str): indices that must be replaced with sums
state_indices (str): indices of the original state
Returns:
str: the einsum indices of the new state
"""
return functools.reduce(
lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]),
zip(old_indices, new_indices),
state_indices,
)
2 changes: 2 additions & 0 deletions tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,8 @@ def test_correctness_jax(self, op_wires, state_wires, batch_dim):
when applying it to a Jax state."""
import jax

jax.config.update("jax_enable_x64", True)

batched = batch_dim is not None
shape = [batch_dim] + [2] * state_wires if batched else [2] * state_wires
# Input state
Expand Down
Loading

0 comments on commit 4823273

Please sign in to comment.