Skip to content

Commit

Permalink
added weights initialization example (#2735)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qottmann authored Jun 16, 2022
1 parent 1d45127 commit 1bb564d
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions pennylane/qnn/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def qnode(inputs, weights_0, weight_1):
**Initializing weights**
The optional ``init_method`` argument of :class:`~.TorchLayer` allows for the initialization
If ``init_method`` is not specified, weights are randomly initialized from the uniform
distribution on the interval :math:`[0, 2 \pi]`.
Alternative a): The optional ``init_method`` argument of :class:`~.TorchLayer` allows for the initialization
method of the QNode weights to be specified. The function passed to the argument must be
from the `torch.nn.init <https://pytorch.org/docs/stable/nn.init.html>`__ module. For
example, weights can be randomly initialized from the normal distribution by passing:
Expand All @@ -113,8 +116,39 @@ def qnode(inputs, weights_0, weight_1):
init_method = torch.nn.init.normal_
If ``init_method`` is not specified, weights are randomly initialized from the uniform
distribution on the interval :math:`[0, 2 \pi]`.
Alternative b): Two dictionaries ``weight_shapes`` and ``init_method`` are passed, whose ``keys`` match the ``args`` of the qnode.
.. code-block::
@qml.qnode(dev)
def qnode(inputs, weights_0, weights_1, weights_2, weight_3, weight_4):
qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
qml.templates.StronglyEntanglingLayers(weights_0, wires=range(n_qubits))
qml.templates.BasicEntanglerLayers(weights_1, wires=range(n_qubits))
qml.Rot(*weights_2, wires=0)
qml.RY(weight_3, wires=1)
qml.RZ(weight_4, wires=1)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))
weight_shapes = {
"weights_0": (3, n_qubits, 3),
"weights_1": (3, n_qubits),
"weights_2": 3,
"weight_3": 1,
"weight_4": (1,),
}
init_method = {
"weights_0": torch.nn.init.normal_,
"weights_1": torch.nn.init.uniform,
"weights_2": torch.tensor([1., 2., 3.]),
"weight_3": torch.tensor(1.), # scalar when shape is not an iterable and is <= 1
"weight_4": torch.tensor([1.]),
}
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes=weight_shapes, init_method=init_method)
**Full code example**
Expand Down

0 comments on commit 1bb564d

Please sign in to comment.