[unitaryHACK] Create a Pytorch simulator #1225 #1360
Changes from 208 commits
@@ -0,0 +1,301 @@
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains a PyTorch implementation of the :class:`~.DefaultQubit`
reference plugin.
"""
import semantic_version

try:
    import torch

    VERSION_SUPPORT = semantic_version.match(">=1.8.1", torch.__version__)
    if not VERSION_SUPPORT:
        raise ImportError("default.qubit.torch device requires Torch>=1.8.1")

except ImportError as e:
    raise ImportError("default.qubit.torch device requires Torch>=1.8.1") from e

import numpy as np
from pennylane.operation import DiagonalOperation
from pennylane.devices import torch_ops
from . import DefaultQubit

class DefaultQubitTorch(DefaultQubit):
    """Simulator plugin based on ``"default.qubit"``, written using PyTorch.

    **Short name:** ``default.qubit.torch``

    This device provides a pure-state qubit simulator written using PyTorch.
    As a result, it supports classical backpropagation as a means to compute the Jacobian. This can
    be faster than the parameter-shift rule for analytic quantum gradients
    when the number of parameters to be optimized is large.

    To use this device, you will need to install PyTorch:

    .. code-block:: console

        pip install torch>=1.8.1

    **Example**

    The ``default.qubit.torch`` device is designed to be used with end-to-end classical backpropagation
    (``diff_method="backprop"``) and the PyTorch interface. This is the default method
    of differentiation when creating a QNode with this device.

    Using this method, the created QNode is a 'white-box', and is
    tightly integrated with your PyTorch computation:

    .. code-block:: python

        dev = qml.device("default.qubit.torch", wires=1)

        @qml.qnode(dev, interface="torch", diff_method="backprop")
        def circuit(x):
            qml.RX(x[1], wires=0)
            qml.Rot(x[0], x[1], x[2], wires=0)
            return qml.expval(qml.PauliZ(0))

    >>> weights = torch.tensor([0.2, 0.5, 0.1], requires_grad=True)
    >>> res = circuit(weights)
    >>> res.backward()
    >>> print(weights.grad)
    tensor([-2.2527e-01, -1.0086e+00,  1.3878e-17])

    Autograd mode will also work when using classical backpropagation:

    >>> def cost(weights):
    ...     return torch.sum(circuit(weights)**3) - 1
    >>> res = cost(weights)
    >>> res.backward()
    >>> print(weights.grad)
    tensor([-4.5053e-01, -2.0173e+00,  5.9837e-17])

    Executing the pipeline in PyTorch also allows the whole computation to be run on the GPU,
    which can provide a significant acceleration. Your parameters need to be instantiated on the
    same device as the backend device:

    .. code-block:: python

        dev = qml.device("default.qubit.torch", wires=1, torch_device='cuda')

        @qml.qnode(dev, interface="torch", diff_method="backprop")
        def circuit(x):
            qml.RX(x[1], wires=0)
            qml.Rot(x[0], x[1], x[2], wires=0)
            return qml.expval(qml.PauliZ(0))

    >>> weights = torch.tensor([0.2, 0.5, 0.1], requires_grad=True, device='cuda')
    >>> res = circuit(weights)
    >>> res.backward()
    >>> print(weights.grad)
    tensor([-2.2527e-01, -1.0086e+00,  1.3878e-17])

    There are a couple of things to keep in mind when using the ``"backprop"``
    differentiation method for QNodes:

    * You must use the ``"torch"`` interface for classical backpropagation, as PyTorch is
      used as the device backend.

    * Only exact expectation values, variances, and probabilities are differentiable.
      When instantiating the device with ``shots!=None``, differentiating QNode
      outputs will result in ``None``.

    If you wish to use a different machine-learning interface, or prefer to calculate quantum
    gradients using the ``parameter-shift`` or ``finite-diff`` differentiation methods,
    consider using the ``default.qubit`` device instead.

    Args:
        wires (int, Iterable): Number of subsystems represented by the device,
            or iterable that contains unique labels for the subsystems. Default 1 if not specified.
        shots (None, int): How many times the circuit should be evaluated (or sampled) to estimate
            the expectation values. Defaults to ``None`` if not specified, which means
            that the device returns analytical results.
            If ``shots > 0`` is used, the ``diff_method="backprop"`` QNode differentiation
            method is not supported, and it is recommended to switch to the
            ``default.qubit`` device with ``diff_method="parameter-shift"``.
        torch_device='cpu' (str): the device on which the computation will be run, ``'cpu'`` or ``'cuda'``
    """

    name = "Default qubit (Torch) PennyLane plugin"
    short_name = "default.qubit.torch"

    parametric_ops = {
        "PhaseShift": torch_ops.PhaseShift,
        "ControlledPhaseShift": torch_ops.ControlledPhaseShift,
        "RX": torch_ops.RX,
        "RY": torch_ops.RY,
        "RZ": torch_ops.RZ,
        "MultiRZ": torch_ops.MultiRZ,
        "Rot": torch_ops.Rot,
        "CRX": torch_ops.CRX,
        "CRY": torch_ops.CRY,
        "CRZ": torch_ops.CRZ,
        "CRot": torch_ops.CRot,
        "IsingXX": torch_ops.IsingXX,
        "IsingYY": torch_ops.IsingYY,
        "IsingZZ": torch_ops.IsingZZ,
        "SingleExcitation": torch_ops.SingleExcitation,
        "SingleExcitationPlus": torch_ops.SingleExcitationPlus,
        "SingleExcitationMinus": torch_ops.SingleExcitationMinus,
        "DoubleExcitation": torch_ops.DoubleExcitation,
        "DoubleExcitationPlus": torch_ops.DoubleExcitationPlus,
        "DoubleExcitationMinus": torch_ops.DoubleExcitationMinus,
    }

    C_DTYPE = torch.complex128
    R_DTYPE = torch.float64

    _abs = staticmethod(torch.abs)
    _einsum = staticmethod(torch.einsum)
    _flatten = staticmethod(torch.flatten)
    _reshape = staticmethod(torch.reshape)
    _roll = staticmethod(torch.roll)
    _stack = staticmethod(lambda arrs, axis=0, out=None: torch.stack(arrs, axis=axis, out=out))
    _tensordot = staticmethod(
        lambda a, b, axes: torch.tensordot(
            a, b, axes if isinstance(axes, int) else tuple(map(list, axes))
        )
    )
    _transpose = staticmethod(lambda a, axes=None: a.permute(*axes))
    _asnumpy = staticmethod(lambda x: x.cpu().numpy())
    _imag = staticmethod(torch.imag)
    _norm = staticmethod(torch.norm)

    def __init__(self, wires, *, shots=None, analytic=None, torch_device="cpu"):
        self._torch_device = torch_device
        super().__init__(wires, shots=shots, cache=0, analytic=analytic)

        # Move state to torch device (e.g. CPU, GPU, XLA, ...)
        self._state.requires_grad = True
        self._state = self._state.to(self._torch_device)
        self._pre_rotated_state = self._state

    @staticmethod
    def _asarray(a, dtype=None):
        if isinstance(a, list):
            # Handle unexpected cases where we don't have a list of tensors
            if not isinstance(a[0], torch.Tensor):
                res = np.asarray(a)
                res = torch.from_numpy(res)
            else:
                # Flatten each tensor and concatenate; the second pass
                # re-flattens the concatenated result into a single 1-D tensor.
                res = torch.cat([torch.reshape(i, (-1,)) for i in a], dim=0)
                res = torch.cat([torch.reshape(i, (-1,)) for i in res], dim=0)
        else:
            res = torch.as_tensor(a, dtype=dtype)
        return res
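The list branch of `_asarray` flattens a list of tensors into a single 1-D tensor. A minimal NumPy sketch of the same flattening semantics (illustration only, not the torch implementation itself):

```python
import numpy as np

def asarray_sketch(a, dtype=None):
    # Mirror of the _asarray list case: each element is flattened
    # to 1-D and the pieces are concatenated in order.
    if isinstance(a, list) and isinstance(a[0], np.ndarray):
        return np.concatenate([np.reshape(i, (-1,)) for i in a])
    # Non-list inputs are converted directly.
    return np.asarray(a, dtype=dtype)

# A 2x2 array and a length-3 array flatten into one length-7 vector.
flat = asarray_sketch([np.ones((2, 2)), np.zeros(3)])
print(flat.shape)  # (7,)
```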

    @staticmethod
    def _dot(x, y):
        # If the tensors live on different devices, move the CPU-resident
        # operand onto the other operand's device before contracting.
        if x.device != y.device:
            if x.device != "cpu":
                return torch.tensordot(x, y.to(x.device), dims=1)
            if y.device != "cpu":
                return torch.tensordot(x.to(y.device), y, dims=1)

        return torch.tensordot(x, y, dims=1)

Review comment: This is a GPU device fix, so we can't test it without having a GPU test.

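`_dot` reduces to a `tensordot` with one contracted axis, which for 2-D operands is ordinary matrix multiplication. A NumPy sketch of that equivalence (the torch-specific device-transfer logic above is omitted):

```python
import numpy as np

x = np.arange(6.0).reshape(2, 3)
y = np.arange(12.0).reshape(3, 4)

# axes=1 contracts the last axis of x with the first axis of y --
# the same contraction torch.tensordot(x, y, dims=1) performs.
out = np.tensordot(x, y, axes=1)

print(out.shape)  # (2, 4)
```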
    def _cast(self, a, dtype=None):
        return torch.as_tensor(self._asarray(a, dtype=dtype), device=self._torch_device)

    @staticmethod
    def _reduce_sum(array, axes):
        if not axes:
            return array
        return torch.sum(array, dim=axes)

    @staticmethod
    def _conj(array):
        if isinstance(array, torch.Tensor):
            return torch.conj(array)
        return np.conj(array)

    @staticmethod
    def _scatter(indices, array, new_dimensions):
        # `array` is now a torch tensor
        tensor = array
        new_tensor = torch.zeros(new_dimensions, dtype=tensor.dtype, device=tensor.device)
        new_tensor[indices] = tensor
        return new_tensor
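`_scatter` embeds a smaller tensor into a zero tensor of the target shape via index assignment. The same semantics in NumPy (an illustrative sketch, not the torch code):

```python
import numpy as np

def scatter_sketch(indices, array, new_dimensions):
    # Allocate zeros of the target shape, then write the input values
    # at the given indices; every other entry stays zero.
    new_array = np.zeros(new_dimensions, dtype=array.dtype)
    new_array[indices] = array
    return new_array

res = scatter_sketch([0, 2], np.array([1.0, 2.0]), (4,))
print(res)  # [1. 0. 2. 0.]
```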

    @classmethod
    def capabilities(cls):
        capabilities = super().capabilities().copy()
        capabilities.update(passthru_interface="torch", supports_reversible_diff=False)
        return capabilities

    def _get_unitary_matrix(self, unitary):
        """Return the matrix representing a unitary operation.

        Args:
            unitary (~.Operation): a PennyLane unitary operation

        Returns:
            torch.Tensor[complex]: Returns a 2D matrix representation of
            the unitary in the computational basis, or, in the case of a diagonal unitary,
            a 1D array representing the matrix diagonal.
        """
        op_name = unitary.base_name
        if op_name in self.parametric_ops:
            if op_name == "MultiRZ":
                mat = self.parametric_ops[op_name](
                    *unitary.parameters, len(unitary.wires), device=self._torch_device
                )
            else:
                mat = self.parametric_ops[op_name](*unitary.parameters, device=self._torch_device)
            if unitary.inverse:
                # The inverse of a diagonal unitary is its conjugate;
                # otherwise, take the conjugate transpose.
                if isinstance(unitary, DiagonalOperation):
                    mat = self._conj(mat)
                else:
                    mat = self._transpose(self._conj(mat), axes=[1, 0])
            return mat

        if isinstance(unitary, DiagonalOperation):
            return self._asarray(unitary.eigvals, dtype=self.C_DTYPE)
        return self._asarray(unitary.matrix, dtype=self.C_DTYPE)

    def sample_basis_states(self, number_of_states, state_probability):
        """Sample from the computational basis states based on the state
        probability.

        This is an auxiliary method to the ``generate_samples`` method.

        Args:
            number_of_states (int): the number of basis states to sample from
            state_probability (torch.Tensor[float]): the computational basis probability vector

        Returns:
            List[int]: the sampled basis states
        """
        return super().sample_basis_states(
            number_of_states, state_probability.cpu().detach().numpy()
        )
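The method above converts the probability tensor to NumPy and defers to the parent implementation, which draws basis-state indices weighted by the probability vector. A sketch of that sampling step, assuming the parent uses something equivalent to a weighted `choice` (the actual DefaultQubit code path is not shown here):

```python
import numpy as np

def sample_basis_states_sketch(number_of_states, state_probability, shots, seed=0):
    # Draw `shots` basis-state indices in [0, number_of_states),
    # weighted by the computational-basis probability vector.
    rng = np.random.default_rng(seed)
    basis_states = np.arange(number_of_states)
    return rng.choice(basis_states, size=shots, p=state_probability)

# A state with all probability on |0>: every sample is index 0.
samples = sample_basis_states_sketch(2, np.array([1.0, 0.0]), shots=5)
print(samples)  # [0 0 0 0 0]
```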

    def _apply_operation(self, state, operation):
        """Applies operations to the input state.

        Args:
            state (torch.Tensor[complex]): input state
            operation (~.Operation): operation to apply on the device

        Returns:
            torch.Tensor[complex]: output state
        """
        if state.device != self._torch_device:
            state = state.to(self._torch_device)
        return super()._apply_operation(state, operation)
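`_apply_operation` moves the state onto the configured torch device and defers to the parent class, which contracts the gate matrix with the state tensor. A NumPy sketch of that contraction for a single-qubit gate acting on wire 0 of a 2-qubit state (an illustration of the underlying math, not the DefaultQubit code path):

```python
import numpy as np

# 2-qubit state |00>, stored as a (2, 2) tensor with one axis per wire.
state = np.zeros((2, 2), dtype=complex)
state[0, 0] = 1.0

# Pauli-X on wire 0: contract the gate's column axis with the state's
# wire-0 axis. For other wires the resulting axis would need to be
# moved back into position with np.moveaxis.
X = np.array([[0, 1], [1, 0]], dtype=complex)
new_state = np.tensordot(X, state, axes=[[1], [0]])

# |00> -> |10>
print(new_state[1, 0])  # (1+0j)
```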
Review comment: Hi @PCesteban, just wanted to make sure that the suggested updates
aligned well with the intentions for this method. We've noticed that previously the
device might output numpy arrays instead of torch tensors (hence the suggestions
from before). Let us know if we've missed a use case here.

Reply: Hi @antalszava, I don't think there is any missed use case to the best of my
knowledge. torch tensors are a better fit for this method. FYI @Slimane33 @arshpreetsingh