[unitaryHACK] Create a PyTorch simulator #1225 (#1360)
* add default torch

* install plugin

* basic circuits work

* sampling for torch simulator

* convert array to torch tensor

* enable backprop on expvalues

* rewrite sample_basis_state

* cleaning

* add docstring operations

* add controlphaseshift + multiRZ

* add new operations from tf_ops

* stop check import

* update

* solve RZ gate

* fixing all operations

* check version compatibility

* fix double exc gates

* main docstring

* docstring + cleaning

* draft unit test

* remove torch test

* Device's tests added (test_default_qubit_torch.py)

* correction

* Some checks corrections

* fix tests + autograd

* code factor

* code factor

* code factor

* Torch device test added (To check)

* fix _asarray

* coverage report

* added torch device in passthru_devices

* update test + passthru

* enable inverse operation

* correct inverse operation

* update

* passing 147 test-cases

* suppress tf tests

* changed semantic version

* fix bug

* fix bug

* Version in line 279

* testing for cuda device

* Update with suggested changes

* default_qubit_torch added to the documentation list

* removing torch from autograd

* removing cuda test

* removing cuda test support

* Update

* Update

* rewrite _tensordot

* solving PR reviews

* code reformatting after: black -l 100 pennylane tests

* removed whitespaces

* doc/requirements

* doc/requirements

* conflict solved

* suggested change in docstring

* Suggested change in docstring fixed

* fix docstring default_qubit_torch

* Update tests.yml

* fix cuda

* fix cuda

* fix docs and CI

* fix docs and CI

* fix cuda

* fix cuda in default.qubit instead of _qubit_device

* fixed doc issue

* pushed changes after black

* fix version to 1.8.0 in torch.py

* unremove test

* black

* suggested changes

* fix versions

* rewrite _apply_ops in torch + fix test

* delete comment

* adding default.qubit.torch to conftest.py

* bump pytorch version

* running black

* fix _apply_state_vector + test_tape_torch

* fix test_qnode_torch

* added Ising operations

* updated doc strings

* running black

* Update .github/workflows/tests.yml

Co-authored-by: Christina Lee <chrissie.c.l@gmail.com>

* Update default_qubit_torch.py

* running black

* Update tests/devices/test_default_qubit_torch.py

* Update tests/devices/test_default_qubit_torch.py

Co-authored-by: antalszava <antalszava@gmail.com>

* Conflicts and code factor notes

* Delete qubit.py

* Update tests/devices/test_default_qubit_torch.py

Co-authored-by: Christina Lee <chrissie.c.l@gmail.com>

* Update tests/devices/test_default_qubit_torch.py

* Part 1 of reformatting

* reverting formatting changes pt 2

* reverting formatting changes pt 3

* revert formatting changes pt 4

* revert formatting changes pt last

* fixing mistake

* Apply suggestions from code review

Co-authored-by: antalszava <antalszava@gmail.com>

* Update tests/devices/test_default_qubit_torch.py

* remove some white space

* update docstring example

* Update tests/devices/test_default_qubit_torch.py

* format

* autograd.py from master

* autograd.py fixes

* apply_state_vector fix

* black

* improve tests, fix diagonal gate application

* Update pennylane/devices/default_qubit_torch.py

Co-authored-by: antalszava <antalszava@gmail.com>

* Update pennylane/devices/default_qubit_torch.py

Co-authored-by: antalszava <antalszava@gmail.com>

* finish polishing tests

* formatting, remove print statements

* add diagonal inverse test

* changelog, black

* style change

* Update pennylane/devices/default_qubit_torch.py

* Update pennylane/devices/default_qubit_torch.py

* error if sampling for shots=None

* test error if sampling for shots=None

* torch test in device test properties test case

* Ising ops support and adjust

* Ising ops tests

* remove commented tests

* error import

* format

* Update pennylane/devices/default_qubit_torch.py

* Update pennylane/devices/default_qubit_torch.py

* Update pennylane/devices/default_qubit_torch.py

* Update pennylane/devices/torch_ops.py

* Update pennylane/devices/torch_ops.py

* Update tests/devices/test_default_qubit_torch.py

* Update tests/devices/test_default_qubit_torch.py

* Update tests/devices/test_default_qubit_torch.py

* Update pennylane/devices/torch_ops.py

* Update pennylane/devices/default_qubit_torch.py

* no warnings import

* format

* minor fixes

* docstring, sampling super, and test qchem ops

* remove unused import

* black

* test double excitation gates

* black

* revert sampling change

* qml.quantumfunctionerror

* please!

* final black

* I swear i just did black

* Update pennylane/devices/default_qubit.py

* fix

Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: estebanpc <you@example.com>
Co-authored-by: PCesteban <estebandpc@outlook.com>
Co-authored-by: arshpreetsingh <arsh840@gmail.com>
Co-authored-by: vishnu <vishnuajith@gmail.com>
Co-authored-by: Maria Schuld <mariaschuld@gmail.com>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Nathan Killoran <co9olguy@users.noreply.github.com>
Co-authored-by: antalszava <antalszava@gmail.com>
Co-authored-by: Christina Lee <chrissie.c.l@gmail.com>
11 people authored Aug 27, 2021
1 parent 7ee5990 commit 40aaeb6
Showing 18 changed files with 2,450 additions and 19 deletions.
12 changes: 8 additions & 4 deletions .github/CHANGELOG.md
@@ -2,6 +2,11 @@

<h3>New features since last release</h3>


+* A new PyTorch device, `qml.device('default.qubit.torch', wires=wires)`, supports
+  backpropagation with the torch interface.
+  [(#1225)](https://github.com/PennyLaneAI/pennylane/pull/1360)
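
  For context, a minimal sketch of the new device in use (illustrative; assumes
  PennyLane with this commit and torch>=1.8.1 installed):

      import pennylane as qml
      import torch

      # Simulate the circuit with torch tensors end to end
      dev = qml.device("default.qubit.torch", wires=2)

      @qml.qnode(dev, interface="torch", diff_method="backprop")
      def circuit(theta):
          qml.RY(theta, wires=0)
          qml.CNOT(wires=[0, 1])
          return qml.expval(qml.PauliZ(1))

      theta = torch.tensor(0.4, requires_grad=True)
      circuit(theta).backward()
      print(theta.grad)  # analytic gradient -sin(0.4), computed via torch backprop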

* The ability to define *batch* transforms has been added via the new
`@qml.batch_transform` decorator.
[(#1493)](https://github.com/PennyLaneAI/pennylane/pull/1493)
@@ -364,10 +369,9 @@ and requirements-ci.txt (unpinned). This latter would be used by the CI.

This release contains contributions from (in alphabetical order):


-Vishnu Ajith, Akash Narayanan B, Thomas Bromley, Tanya Garg, Josh Izaac, Prateek Jain, Johannes Jakob Meyer, Pratul Saini, Maria Schuld,
-Ingrid Strandberg, David Wierichs, Vincent Wong.
+Vishnu Ajith, Akash Narayanan B, Thomas Bromley, Tanya Garg, Josh Izaac, Prateek Jain, Christina Lee,
+Johannes Jakob Meyer, Esteban Payares, Pratul Saini, Maria Schuld, Arshpreet Singh, Ingrid Strandberg,
+Slimane Thabet, David Wierichs, Vincent Wong.

# Release 0.17.0 (current release)

2 changes: 2 additions & 0 deletions pennylane/devices/__init__.py
@@ -24,11 +24,13 @@
    default_qubit
    default_qubit_jax
+   default_qubit_torch
    default_qubit_tf
    default_qubit_autograd
    default_gaussian
    default_mixed
    tf_ops
+   torch_ops
    autograd_ops
    tests
"""
5 changes: 3 additions & 2 deletions pennylane/devices/default_qubit.py
@@ -502,7 +502,7 @@ def expval(self, observable, shot_range=None, bin_size=None):
            if isinstance(coeff, qml.numpy.tensor) and not coeff.requires_grad:
                coeff = qml.math.toarray(coeff)

-            res = res + (
+            res = qml.math.convert_like(res, product) + (
                qml.math.cast(qml.math.convert_like(coeff, product), "complex128") * product
            )
        return qml.math.real(res)
@@ -536,6 +536,7 @@ def capabilities(cls):
            returns_state=True,
            passthru_devices={
                "tf": "default.qubit.tf",
+                "torch": "default.qubit.torch",
                "autograd": "default.qubit.autograd",
                "jax": "default.qubit.jax",
            },
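
With this entry in ``passthru_devices``, a QNode built on plain ``default.qubit`` with
``interface="torch"`` and ``diff_method="backprop"`` is transparently switched to
``default.qubit.torch``. A minimal sketch of the behaviour this enables:

    import pennylane as qml
    import torch

    dev = qml.device("default.qubit", wires=1)

    # PennyLane substitutes default.qubit.torch behind the scenes for backprop
    @qml.qnode(dev, interface="torch", diff_method="backprop")
    def circuit(x):
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(0))

    x = torch.tensor(0.3, requires_grad=True)
    circuit(x).backward()
    print(x.grad)  # -sin(0.3) ≈ -0.2955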
@@ -609,7 +610,7 @@ def _apply_state_vector(self, state, device_wires):
        if state.ndim != 1 or n_state_vector != 2 ** len(device_wires):
            raise ValueError("State vector must be of length 2**wires.")

-        if not np.allclose(np.linalg.norm(state, ord=2), 1.0, atol=tolerance):
+        if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
            raise ValueError("Sum of amplitudes-squared does not equal one.")

        if len(device_wires) == self.num_wires and sorted(device_wires) == device_wires:
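
The switch from ``np`` to ``qml.math`` lets the same normalization check accept whatever
array type the active interface supplies. A small illustration of the dispatch (assuming
``qml.math.linalg.norm`` proxies to each backend, as the changed line relies on):

    import pennylane as qml
    import numpy as np
    import torch

    state_np = np.array([1.0, 0.0, 0.0, 0.0])
    state_torch = torch.tensor([0.6, 0.8, 0.0, 0.0])

    # Both calls dispatch to the input's own framework
    print(qml.math.allclose(qml.math.linalg.norm(state_np, ord=2), 1.0))     # True
    print(qml.math.allclose(qml.math.linalg.norm(state_torch, ord=2), 1.0))  # True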
301 changes: 301 additions & 0 deletions pennylane/devices/default_qubit_torch.py
@@ -0,0 +1,301 @@
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains a PyTorch implementation of the :class:`~.DefaultQubit`
reference plugin.
"""
import semantic_version

try:
    import torch

    VERSION_SUPPORT = semantic_version.match(">=1.8.1", torch.__version__)
    if not VERSION_SUPPORT:
        raise ImportError("default.qubit.torch device requires Torch>=1.8.1")

except ImportError as e:
    raise ImportError("default.qubit.torch device requires Torch>=1.8.1") from e

import numpy as np
from pennylane.operation import DiagonalOperation
from pennylane.devices import torch_ops
from . import DefaultQubit


class DefaultQubitTorch(DefaultQubit):
"""Simulator plugin based on ``"default.qubit"``, written using PyTorch.
**Short name:** ``default.qubit.torch``
This device provides a pure-state qubit simulator written using PyTorch.
As a result, it supports classical backpropagation as a means to compute the Jacobian. This can
be faster than the parameter-shift rule for analytic quantum gradients
when the number of parameters to be optimized is large.
To use this device, you will need to install PyTorch:
.. code-block:: console
pip install torch>=1.8.0
**Example**
The ``default.qubit.torch`` is designed to be used with end-to-end classical backpropagation
(``diff_method="backprop"``) and the PyTorch interface. This is the default method
of differentiation when creating a QNode with this device.
Using this method, the created QNode is a 'white-box', and is
tightly integrated with your PyTorch computation:
.. code-block:: python
dev = qml.device("default.qubit.torch", wires=1)
@qml.qnode(dev, interface="torch", diff_method="backprop")
def circuit(x):
qml.RX(x[1], wires=0)
qml.Rot(x[0], x[1], x[2], wires=0)
return qml.expval(qml.PauliZ(0))
>>> weights = torch.tensor([0.2, 0.5, 0.1], requires_grad=True)
>>> res = circuit(weights)
>>> res.backward()
>>> print(weights.grad)
tensor([-2.2527e-01, -1.0086e+00, 1.3878e-17])
Autograd mode will also work when using classical backpropagation:
>>> def cost(weights):
... return torch.sum(circuit(weights)**3) - 1
>>> res = circuit(weights)
>>> res.backward()
>>> print(weights.grad)
tensor([-4.5053e-01, -2.0173e+00, 5.9837e-17])
Executing the pipeline in PyTorch will allow the whole computation to be run on the GPU,
and therefore providing an acceleration. Your parameters need to be instantiated on the same
device as the backend device.
.. code-block:: python
dev = qml.device("default.qubit.torch", wires=1, torch_device='cuda')
@qml.qnode(dev, interface="torch", diff_method="backprop")
def circuit(x):
qml.RX(x[1], wires=0)
qml.Rot(x[0], x[1], x[2], wires=0)
return qml.expval(qml.PauliZ(0))
>>> weights = torch.tensor([0.2, 0.5, 0.1], requires_grad=True, device='cuda')
>>> res = circuit(weights)
>>> res.backward()
>>> print(weights.grad)
tensor([-2.2527e-01, -1.0086e+00, 1.3878e-17])
There are a couple of things to keep in mind when using the ``"backprop"``
differentiation method for QNodes:
* You must use the ``"torch"`` interface for classical backpropagation, as PyTorch is
used as the device backend.
* Only exact expectation values, variances, and probabilities are differentiable.
When instantiating the device with ``shots!=None``, differentiating QNode
outputs will result in ``None``.
If you wish to use a different machine-learning interface, or prefer to calculate quantum
gradients using the ``parameter-shift`` or ``finite-diff`` differentiation methods,
consider using the ``default.qubit`` device instead.
Args:
wires (int, Iterable): Number of subsystems represented by the device,
or iterable that contains unique labels for the subsystems. Default 1 if not specified.
shots (None, int): How many times the circuit should be evaluated (or sampled) to estimate
the expectation values. Defaults to ``None`` if not specified, which means
that the device returns analytical results.
If ``shots > 0`` is used, the ``diff_method="backprop"``
QNode differentiation method is not supported and it is recommended to consider
switching device to ``default.qubit`` and using ``diff_method="parameter-shift"``.
torch_device='cpu' (str): the device on which the computation will be run, ``'cpu'`` or ``'cuda'``
"""

name = "Default qubit (Torch) PennyLane plugin"
short_name = "default.qubit.torch"

parametric_ops = {
"PhaseShift": torch_ops.PhaseShift,
"ControlledPhaseShift": torch_ops.ControlledPhaseShift,
"RX": torch_ops.RX,
"RY": torch_ops.RY,
"RZ": torch_ops.RZ,
"MultiRZ": torch_ops.MultiRZ,
"Rot": torch_ops.Rot,
"CRX": torch_ops.CRX,
"CRY": torch_ops.CRY,
"CRZ": torch_ops.CRZ,
"CRot": torch_ops.CRot,
"IsingXX": torch_ops.IsingXX,
"IsingYY": torch_ops.IsingYY,
"IsingZZ": torch_ops.IsingZZ,
"SingleExcitation": torch_ops.SingleExcitation,
"SingleExcitationPlus": torch_ops.SingleExcitationPlus,
"SingleExcitationMinus": torch_ops.SingleExcitationMinus,
"DoubleExcitation": torch_ops.DoubleExcitation,
"DoubleExcitationPlus": torch_ops.DoubleExcitationPlus,
"DoubleExcitationMinus": torch_ops.DoubleExcitationMinus,
}
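
    # Each entry above maps a gate name to a torch_ops function that builds the gate's
    # matrix from torch tensors, keeping it differentiable and on the right device.
    # A hypothetical sketch of the pattern (the real implementations live in torch_ops.py):
    #
    #     def RX(theta, device=None):
    #         theta = torch.as_tensor(theta, dtype=torch.complex128, device=device)
    #         c, s = torch.cos(theta / 2), torch.sin(theta / 2)
    #         return torch.stack([torch.stack([c, -1j * s]), torch.stack([-1j * s, c])])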

    C_DTYPE = torch.complex128
    R_DTYPE = torch.float64

    _abs = staticmethod(torch.abs)
    _einsum = staticmethod(torch.einsum)
    _flatten = staticmethod(torch.flatten)
    _reshape = staticmethod(torch.reshape)
    _roll = staticmethod(torch.roll)
    _stack = staticmethod(lambda arrs, axis=0, out=None: torch.stack(arrs, axis=axis, out=out))
    _tensordot = staticmethod(
        lambda a, b, axes: torch.tensordot(
            a, b, axes if isinstance(axes, int) else tuple(map(list, axes))
        )
    )
    _transpose = staticmethod(lambda a, axes=None: a.permute(*axes))
    _asnumpy = staticmethod(lambda x: x.cpu().numpy())
    _imag = staticmethod(torch.imag)
    _norm = staticmethod(torch.norm)
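
    # Note that ``_tensordot`` above converts NumPy-style ``axes`` pairs into the
    # list form torch expects, e.g. (illustrative):
    #
    #     a, b = torch.rand(2, 3), torch.rand(3, 4)
    #     torch.tensordot(a, b, ([1], [0]))  # contract a's axis 1 with b's axis 0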

    def __init__(self, wires, *, shots=None, analytic=None, torch_device="cpu"):
        self._torch_device = torch_device
        super().__init__(wires, shots=shots, cache=0, analytic=analytic)

        # Move state to torch device (e.g. CPU, GPU, XLA, ...)
        self._state.requires_grad = True
        self._state = self._state.to(self._torch_device)
        self._pre_rotated_state = self._state

    @staticmethod
    def _asarray(a, dtype=None):
        if isinstance(a, list):
            # Handle unexpected cases where we don't have a list of tensors
            if not isinstance(a[0], torch.Tensor):
                res = np.asarray(a)
                res = torch.from_numpy(res)
            else:
                res = torch.cat([torch.reshape(i, (-1,)) for i in a], dim=0)
            res = torch.cat([torch.reshape(i, (-1,)) for i in res], dim=0)
        else:
            res = torch.as_tensor(a, dtype=dtype)
        return res
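
    # Illustrative behaviour of ``_asarray`` (hypothetical values):
    #
    #     DefaultQubitTorch._asarray([1.0, 2.0])
    #     # -> tensor([1., 2.])
    #     DefaultQubitTorch._asarray([torch.zeros(2), torch.ones(2)])
    #     # -> tensor([0., 0., 1., 1.]) (tensors are flattened and concatenated)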

    @staticmethod
    def _dot(x, y):
        if x.device != y.device:
            # Move the CPU-resident operand onto the other tensor's device first
            if x.device != "cpu":
                return torch.tensordot(x, y.to(x.device), dims=1)
            if y.device != "cpu":
                return torch.tensordot(x.to(y.device), y, dims=1)

        return torch.tensordot(x, y, dims=1)

    def _cast(self, a, dtype=None):
        return torch.as_tensor(self._asarray(a, dtype=dtype), device=self._torch_device)

    @staticmethod
    def _reduce_sum(array, axes):
        if not axes:
            return array
        return torch.sum(array, dim=axes)

    @staticmethod
    def _conj(array):
        if isinstance(array, torch.Tensor):
            return torch.conj(array)
        return np.conj(array)

    @staticmethod
    def _scatter(indices, array, new_dimensions):
        # `array` is now a torch tensor
        tensor = array
        new_tensor = torch.zeros(new_dimensions, dtype=tensor.dtype, device=tensor.device)
        new_tensor[indices] = tensor
        return new_tensor
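
    # Illustrative behaviour of ``_scatter``: embed amplitudes into a larger
    # zero-initialized tensor (hypothetical values):
    #
    #     amps = torch.tensor([0.6, 0.8], dtype=torch.complex128)
    #     DefaultQubitTorch._scatter([0, 3], amps, [4])
    #     # -> tensor([0.6+0.j, 0.0+0.j, 0.0+0.j, 0.8+0.j])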

    @classmethod
    def capabilities(cls):
        capabilities = super().capabilities().copy()
        capabilities.update(passthru_interface="torch", supports_reversible_diff=False)
        return capabilities

    def _get_unitary_matrix(self, unitary):
        """Return the matrix representing a unitary operation.

        Args:
            unitary (~.Operation): a PennyLane unitary operation

        Returns:
            torch.Tensor[complex]: Returns a 2D matrix representation of
            the unitary in the computational basis, or, in the case of a diagonal unitary,
            a 1D array representing the matrix diagonal.
        """
        op_name = unitary.base_name
        if op_name in self.parametric_ops:
            if op_name == "MultiRZ":
                mat = self.parametric_ops[op_name](
                    *unitary.parameters, len(unitary.wires), device=self._torch_device
                )
            else:
                mat = self.parametric_ops[op_name](*unitary.parameters, device=self._torch_device)
            if unitary.inverse:
                if isinstance(unitary, DiagonalOperation):
                    mat = self._conj(mat)
                else:
                    mat = self._transpose(self._conj(mat), axes=[1, 0])
            return mat

        if isinstance(unitary, DiagonalOperation):
            return self._asarray(unitary.eigvals, dtype=self.C_DTYPE)
        return self._asarray(unitary.matrix, dtype=self.C_DTYPE)

    def sample_basis_states(self, number_of_states, state_probability):
        """Sample from the computational basis states based on the state
        probability.

        This is an auxiliary method to the ``generate_samples`` method.

        Args:
            number_of_states (int): the number of basis states to sample from
            state_probability (torch.Tensor[float]): the computational basis probability vector

        Returns:
            List[int]: the sampled basis states
        """
        return super().sample_basis_states(
            number_of_states, state_probability.cpu().detach().numpy()
        )

    def _apply_operation(self, state, operation):
        """Applies operations to the input state.

        Args:
            state (torch.Tensor[complex]): input state
            operation (~.Operation): operation to apply on the device

        Returns:
            torch.Tensor[complex]: output state
        """
        if state.device != self._torch_device:
            state = state.to(self._torch_device)
        return super()._apply_operation(state, operation)
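
Note how ``sample_basis_states`` detaches the probabilities and moves them to host memory
before delegating to the NumPy-based sampler of the parent class; sampling is not
differentiable, so no gradient information is lost. The same pattern in isolation
(a sketch, not the parent implementation):

    import numpy as np
    import torch

    probs = torch.tensor([0.5, 0.0, 0.0, 0.5], requires_grad=True)

    # Detach from the autograd graph and copy to CPU before handing off to NumPy
    probs_np = probs.cpu().detach().numpy()
    rng = np.random.default_rng(42)
    print(rng.choice(4, size=8, p=probs_np))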
7 changes: 6 additions & 1 deletion pennylane/devices/tests/conftest.py
@@ -33,7 +33,12 @@
# Number of shots to call the devices with
N_SHOTS = 1e6
# List of all devices that are included in PennyLane
-LIST_CORE_DEVICES = {"default.qubit", "default.qubit.tf", "default.qubit.autograd"}
+LIST_CORE_DEVICES = {
+    "default.qubit",
+    "default.qubit.torch",
+    "default.qubit.tf",
+    "default.qubit.autograd",
+}


@pytest.fixture(scope="function")