Skip to content

Commit ee4d702

Browse files
authored
enable simulation of controlled gates in classical simulator (#6589)
1 parent 528b2d2 commit ee4d702

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

cirq-core/cirq/sim/classical_simulator.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,25 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos
117117
118118
Raises:
119119
ValueError: If initial_state shape for type np.ndarray is not equal to 1.
120-
If gate is not one of X, CNOT, SWAP, CCNOT, or a measurement.
120+
If gate is not one of X, SWAP, a controlled version of X or SWAP,
121+
or a measurement.
121122
"""
122123
if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1:
123124
raise ValueError('initial_state shape for type np.ndarray is not equal to 1')
124125
gate = action.gate if isinstance(action, ops.Operation) else action
125126
mapped_qubits = [self.qubit_map[i] for i in qubits]
127+
128+
if isinstance(gate, ops.ControlledGate):
129+
control_qubits = mapped_qubits[: gate.num_controls()]
130+
mapped_qubits = mapped_qubits[gate.num_controls() :]
131+
132+
controls_state = tuple(self._state.basis[c] for c in control_qubits)
133+
if controls_state not in gate.control_values.expand():
134+
# gate has no effect; controls were off
135+
return True
136+
137+
gate = gate.sub_gate
138+
126139
if _is_identity(gate):
127140
pass
128141
elif gate == ops.X:
@@ -138,7 +151,10 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos
138151
c1, c2, q = mapped_qubits
139152
self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2]
140153
else:
141-
raise ValueError(f'{gate} is not one of X, CNOT, SWAP, CCNOT, or a measurement')
154+
raise ValueError(
155+
f'{gate} is not one of X, SWAP; a controlled version '
156+
'of X or SWAP; or a measurement'
157+
)
142158
return True
143159

144160

cirq-core/cirq/sim/classical_simulator_test.py

+38
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from itertools import product
1516
import numpy as np
1617
import pytest
1718
import cirq
@@ -78,6 +79,43 @@ def test_CCNOT():
7879
np.testing.assert_equal(results, expected_results)
7980

8081

82+
@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=4)])
83+
def test_CCCX(initial_state):
84+
CCCX = cirq.CCNOT.controlled()
85+
qubits = cirq.LineQubit.range(4)
86+
87+
circuit = cirq.Circuit()
88+
circuit.append(CCCX(*qubits))
89+
circuit.append(cirq.measure(qubits, key='key'))
90+
91+
final_state = initial_state.copy()
92+
final_state[-1] ^= all(final_state[:-1])
93+
94+
sim = cirq.ClassicalStateSimulator()
95+
results = sim.simulate(circuit, initial_state=initial_state).measurements['key']
96+
np.testing.assert_equal(results, final_state)
97+
98+
99+
@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=3)])
100+
def test_CSWAP(initial_state):
101+
CSWAP = cirq.SWAP.controlled()
102+
qubits = cirq.LineQubit.range(3)
103+
circuit = cirq.Circuit()
104+
105+
circuit = cirq.Circuit()
106+
circuit.append(CSWAP(*qubits))
107+
circuit.append(cirq.measure(qubits, key='key'))
108+
109+
a, b, c = initial_state
110+
if a:
111+
b, c = c, b
112+
final_state = [a, b, c]
113+
114+
sim = cirq.ClassicalStateSimulator()
115+
results = sim.simulate(circuit, initial_state=initial_state).measurements['key']
116+
np.testing.assert_equal(results, final_state)
117+
118+
81119
def test_measurement_gate():
82120
q0, q1 = cirq.LineQubit.range(2)
83121
circuit = cirq.Circuit()

0 commit comments

Comments
 (0)