Skip to content

Commit 6aa9d0d

Browse files
Cleanup classical simulator code and fix a couple of bugs (#6344)
1 parent a55f962 commit 6aa9d0d

File tree

2 files changed

+227
-236
lines changed

2 files changed

+227
-236
lines changed

cirq-core/cirq/sim/classical_simulator.py

+41-61
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,14 @@
2222
import numpy as np
2323

2424

25+
def _is_identity(op: ops.Operation) -> bool:
26+
if isinstance(op.gate, (ops.XPowGate, ops.CXPowGate, ops.CCXPowGate, ops.SwapPowGate)):
27+
return op.gate.exponent % 2 == 0
28+
return False
29+
30+
2531
class ClassicalStateSimulator(SimulatesSamples):
26-
"""A simulator that only accepts only gates with classical counterparts.
32+
"""A simulator that accepts only gates with classical counterparts.
2733
2834
This simulator evolves a single state, using only gates that output a single state for each
2935
input state. The simulator runs in linear time, at the cost of not supporting superposition.
@@ -47,7 +53,9 @@ class ClassicalStateSimulator(SimulatesSamples):
4753
A dictionary mapping measurement keys to measurement results.
4854
4955
Raises:
50-
ValueError: If one of the gates is not an X, CNOT, SWAP, TOFFOLI or a measurement.
56+
ValueError: If
57+
- one of the gates is not an X, CNOT, SWAP, TOFFOLI or a measurement.
58+
- A measurement key is used for measurements on different numbers of qubits.
5159
"""
5260

5361
def _run(
@@ -60,68 +68,40 @@ def _run(
6068

6169
for moment in resolved_circuit:
6270
for op in moment:
63-
gate = op.gate
64-
if gate == ops.X:
65-
values_dict[op.qubits[0]] = 1 - values_dict[op.qubits[0]]
66-
67-
elif (
68-
isinstance(gate, ops.CNotPowGate)
69-
and gate.exponent == 1
70-
and gate.global_shift == 0
71-
):
72-
if values_dict[op.qubits[0]] == 1:
73-
values_dict[op.qubits[1]] = 1 - values_dict[op.qubits[1]]
74-
75-
elif (
76-
isinstance(gate, ops.SwapPowGate)
77-
and gate.exponent == 1
78-
and gate.global_shift == 0
79-
):
80-
hold_qubit = values_dict[op.qubits[1]]
81-
values_dict[op.qubits[1]] = values_dict[op.qubits[0]]
82-
values_dict[op.qubits[0]] = hold_qubit
83-
84-
elif (
85-
isinstance(gate, ops.CCXPowGate)
86-
and gate.exponent == 1
87-
and gate.global_shift == 0
88-
):
89-
if (values_dict[op.qubits[0]] == 1) and (values_dict[op.qubits[1]] == 1):
90-
values_dict[op.qubits[2]] = 1 - values_dict[op.qubits[2]]
91-
92-
elif isinstance(gate, ops.MeasurementGate):
93-
qubits_in_order = op.qubits
94-
# add the new instance of a key to the numpy array in results dictionary
95-
if gate.key in results_dict:
96-
shape = len(qubits_in_order)
97-
current_array = results_dict[gate.key]
98-
new_instance = np.zeros(shape, dtype=np.uint8)
99-
for bits in range(0, len(qubits_in_order)):
100-
new_instance[bits] = values_dict[qubits_in_order[bits]]
101-
results_dict[gate.key] = np.insert(
102-
current_array, len(current_array[0]), new_instance, axis=1
71+
if _is_identity(op):
72+
continue
73+
if op.gate == ops.X:
74+
(q,) = op.qubits
75+
values_dict[q] ^= 1
76+
elif op.gate == ops.CNOT:
77+
c, q = op.qubits
78+
values_dict[q] ^= values_dict[c]
79+
elif op.gate == ops.SWAP:
80+
a, b = op.qubits
81+
values_dict[a], values_dict[b] = values_dict[b], values_dict[a]
82+
elif op.gate == ops.TOFFOLI:
83+
c1, c2, q = op.qubits
84+
values_dict[q] ^= values_dict[c1] & values_dict[c2]
85+
elif protocols.is_measurement(op):
86+
measurement_values = np.array(
87+
[[[values_dict[q] for q in op.qubits]]] * repetitions, dtype=np.uint8
88+
)
89+
key = op.gate.key # type: ignore
90+
if key in results_dict:
91+
if op._num_qubits_() != results_dict[key].shape[-1]:
92+
raise ValueError(
93+
f'Measurement shape {len(measurement_values)} does not match '
94+
f'{results_dict[key].shape[-1]} in {key}.'
10395
)
96+
results_dict[key] = np.concatenate(
97+
(results_dict[key], measurement_values), axis=1
98+
)
10499
else:
105-
# create the array for the results dictionary
106-
new_array_shape = (repetitions, 1, len(qubits_in_order))
107-
new_array = np.zeros(new_array_shape, dtype=np.uint8)
108-
for reps in range(0, repetitions):
109-
for instances in range(1):
110-
for bits in range(0, len(qubits_in_order)):
111-
new_array[reps][instances][bits] = values_dict[
112-
qubits_in_order[bits]
113-
]
114-
results_dict[gate.key] = new_array
115-
116-
elif not (
117-
(isinstance(gate, ops.XPowGate) and gate.exponent == 0)
118-
or (isinstance(gate, ops.CCXPowGate) and gate.exponent == 0)
119-
or (isinstance(gate, ops.SwapPowGate) and gate.exponent == 0)
120-
or (isinstance(gate, ops.CNotPowGate) and gate.exponent == 0)
121-
):
100+
results_dict[key] = measurement_values
101+
else:
122102
raise ValueError(
123-
"Can not simulate gates other than cirq.XGate, "
124-
+ "cirq.CNOT, cirq.SWAP, and cirq.CCNOT"
103+
f'{op} is not one of cirq.X, cirq.CNOT, cirq.SWAP, '
104+
'cirq.CCNOT, or a measurement'
125105
)
126106

127107
return results_dict

0 commit comments

Comments
 (0)