Skip to content

Commit

Permalink
enable simulation of controlled gates in classical simulator (#6589)
Browse files Browse the repository at this point in the history
  • Loading branch information
GregDMeyer authored May 23, 2024
1 parent 528b2d2 commit ee4d702
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
20 changes: 18 additions & 2 deletions cirq-core/cirq/sim/classical_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,25 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos
Raises:
ValueError: If initial_state shape for type np.ndarray is not equal to 1.
If gate is not one of X, CNOT, SWAP, CCNOT, or a measurement.
If gate is not one of X, SWAP, a controlled version of X or SWAP,
or a measurement.
"""
if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1:
raise ValueError('initial_state shape for type np.ndarray is not equal to 1')
gate = action.gate if isinstance(action, ops.Operation) else action
mapped_qubits = [self.qubit_map[i] for i in qubits]

if isinstance(gate, ops.ControlledGate):
control_qubits = mapped_qubits[: gate.num_controls()]
mapped_qubits = mapped_qubits[gate.num_controls() :]

controls_state = tuple(self._state.basis[c] for c in control_qubits)
if controls_state not in gate.control_values.expand():
# gate has no effect; controls were off
return True

gate = gate.sub_gate

if _is_identity(gate):
pass
elif gate == ops.X:
Expand All @@ -138,7 +151,10 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos
c1, c2, q = mapped_qubits
self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2]
else:
raise ValueError(f'{gate} is not one of X, CNOT, SWAP, CCNOT, or a measurement')
raise ValueError(
f'{gate} is not one of X, SWAP; a controlled version '
'of X or SWAP; or a measurement'
)
return True


Expand Down
38 changes: 38 additions & 0 deletions cirq-core/cirq/sim/classical_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import product
import numpy as np
import pytest
import cirq
Expand Down Expand Up @@ -78,6 +79,43 @@ def test_CCNOT():
np.testing.assert_equal(results, expected_results)


@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=4)])
def test_CCCX(initial_state):
CCCX = cirq.CCNOT.controlled()
qubits = cirq.LineQubit.range(4)

circuit = cirq.Circuit()
circuit.append(CCCX(*qubits))
circuit.append(cirq.measure(qubits, key='key'))

final_state = initial_state.copy()
final_state[-1] ^= all(final_state[:-1])

sim = cirq.ClassicalStateSimulator()
results = sim.simulate(circuit, initial_state=initial_state).measurements['key']
np.testing.assert_equal(results, final_state)


@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=3)])
def test_CSWAP(initial_state):
CSWAP = cirq.SWAP.controlled()
qubits = cirq.LineQubit.range(3)
circuit = cirq.Circuit()

circuit = cirq.Circuit()
circuit.append(CSWAP(*qubits))
circuit.append(cirq.measure(qubits, key='key'))

a, b, c = initial_state
if a:
b, c = c, b
final_state = [a, b, c]

sim = cirq.ClassicalStateSimulator()
results = sim.simulate(circuit, initial_state=initial_state).measurements['key']
np.testing.assert_equal(results, final_state)


def test_measurement_gate():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit()
Expand Down

0 comments on commit ee4d702

Please sign in to comment.