Skip to content

Commit a2425eb

Browse files
author
Shef
authored
Update ClassicalSimulator to confirm to simulation abstraction (#6432)
1 parent 7780c01 commit a2425eb

File tree

2 files changed

+308
-87
lines changed

2 files changed

+308
-87
lines changed

Diff for: cirq-core/cirq/sim/classical_simulator.py

+211-87
Original file line numberDiff line numberDiff line change
@@ -12,96 +12,220 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict
16-
from collections import defaultdict
17-
from cirq.sim.simulator import SimulatesSamples
18-
from cirq import ops, protocols
19-
from cirq.study.resolver import ParamResolver
20-
from cirq.circuits.circuit import AbstractCircuit
21-
from cirq.ops.raw_types import Qid
15+
16+
from typing import Dict, Generic, Any, Sequence, List, Optional, Union, TYPE_CHECKING
17+
from copy import deepcopy, copy
18+
from cirq import ops, qis
19+
from cirq.value import big_endian_int_to_bits
20+
from cirq import sim
21+
from cirq.sim.simulation_state import TSimulationState, SimulationState
2222
import numpy as np
2323

24+
if TYPE_CHECKING:
25+
import cirq
26+
2427

25-
def _is_identity(op: ops.Operation) -> bool:
26-
if isinstance(op.gate, (ops.XPowGate, ops.CXPowGate, ops.CCXPowGate, ops.SwapPowGate)):
27-
return op.gate.exponent % 2 == 0
28+
def _is_identity(action) -> bool:
29+
"""Check if the given action is equivalent to an identity."""
30+
gate = action.gate if isinstance(action, ops.Operation) else action
31+
if isinstance(gate, (ops.XPowGate, ops.CXPowGate, ops.CCXPowGate, ops.SwapPowGate)):
32+
return gate.exponent % 2 == 0
2833
return False
2934

3035

31-
class ClassicalStateSimulator(SimulatesSamples):
32-
"""A simulator that accepts only gates with classical counterparts.
33-
34-
This simulator evolves a single state, using only gates that output a single state for each
35-
input state. The simulator runs in linear time, at the cost of not supporting superposition.
36-
It can be used to estimate costs and simulate circuits for simple non-quantum algorithms using
37-
many more qubits than fully capable quantum simulators.
38-
39-
The supported gates are:
40-
- cirq.X
41-
- cirq.CNOT
42-
- cirq.SWAP
43-
- cirq.TOFFOLI
44-
- cirq.measure
45-
46-
Args:
47-
circuit: The circuit to simulate.
48-
param_resolver: Parameters to run with the program.
49-
repetitions: Number of times to repeat the run. It is expected that
50-
this is validated greater than zero before calling this method.
51-
52-
Returns:
53-
A dictionary mapping measurement keys to measurement results.
54-
55-
Raises:
56-
ValueError: If
57-
- one of the gates is not an X, CNOT, SWAP, TOFFOLI or a measurement.
58-
- A measurement key is used for measurements on different numbers of qubits.
59-
"""
60-
61-
def _run(
62-
self, circuit: AbstractCircuit, param_resolver: ParamResolver, repetitions: int
63-
) -> Dict[str, np.ndarray]:
64-
results_dict: Dict[str, np.ndarray] = {}
65-
values_dict: Dict[Qid, int] = defaultdict(int)
66-
param_resolver = param_resolver or ParamResolver({})
67-
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
68-
69-
for moment in resolved_circuit:
70-
for op in moment:
71-
if _is_identity(op):
72-
continue
73-
if op.gate == ops.X:
74-
(q,) = op.qubits
75-
values_dict[q] ^= 1
76-
elif op.gate == ops.CNOT:
77-
c, q = op.qubits
78-
values_dict[q] ^= values_dict[c]
79-
elif op.gate == ops.SWAP:
80-
a, b = op.qubits
81-
values_dict[a], values_dict[b] = values_dict[b], values_dict[a]
82-
elif op.gate == ops.TOFFOLI:
83-
c1, c2, q = op.qubits
84-
values_dict[q] ^= values_dict[c1] & values_dict[c2]
85-
elif protocols.is_measurement(op):
86-
measurement_values = np.array(
87-
[[[values_dict[q] for q in op.qubits]]] * repetitions, dtype=np.uint8
88-
)
89-
key = op.gate.key # type: ignore
90-
if key in results_dict:
91-
if op._num_qubits_() != results_dict[key].shape[-1]:
92-
raise ValueError(
93-
f'Measurement shape {len(measurement_values)} does not match '
94-
f'{results_dict[key].shape[-1]} in {key}.'
95-
)
96-
results_dict[key] = np.concatenate(
97-
(results_dict[key], measurement_values), axis=1
98-
)
99-
else:
100-
results_dict[key] = measurement_values
101-
else:
102-
raise ValueError(
103-
f'{op} is not one of cirq.X, cirq.CNOT, cirq.SWAP, '
104-
'cirq.CCNOT, or a measurement'
105-
)
106-
107-
return results_dict
36+
class ClassicalBasisState(qis.QuantumStateRepresentation):
37+
"""Represents a classical basis state for efficient state evolution."""
38+
39+
def __init__(self, initial_state: Union[List[int], np.ndarray]):
40+
"""Initializes the ClassicalBasisState object.
41+
42+
Args:
43+
initial_state: The initial state in the computational basis.
44+
"""
45+
self.basis = initial_state
46+
47+
def copy(self, deep_copy_buffers: bool = True) -> 'ClassicalBasisState':
48+
"""Creates a copy of the ClassicalBasisState object.
49+
50+
Args:
51+
deep_copy_buffers: Whether to deep copy the internal buffers.
52+
Returns:
53+
A copy of the ClassicalBasisState object.
54+
"""
55+
return ClassicalBasisState(
56+
initial_state=deepcopy(self.basis) if deep_copy_buffers else copy(self.basis)
57+
)
58+
59+
def measure(
60+
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
61+
) -> List[int]:
62+
"""Measures the density matrix.
63+
64+
Args:
65+
axes: The axes to measure.
66+
seed: The random number seed to use.
67+
Returns:
68+
The measurements in order.
69+
"""
70+
return [self.basis[i] for i in axes]
71+
72+
73+
class ClassicalBasisSimState(SimulationState[ClassicalBasisState]):
74+
"""Represents the state of a quantum simulation using classical basis states."""
75+
76+
def __init__(
77+
self,
78+
initial_state: Union[int, List[int]] = 0,
79+
qubits: Optional[Sequence['cirq.Qid']] = None,
80+
classical_data: Optional['cirq.ClassicalDataStore'] = None,
81+
):
82+
"""Initializes the ClassicalBasisSimState object.
83+
84+
Args:
85+
qubits: The qubits to simulate.
86+
initial_state: The initial state for the simulation.
87+
classical_data: The classical data container for the simulation.
88+
89+
Raises:
90+
ValueError: If qubits not provided and initial_state is int.
91+
If initial_state is not an int, List[int], or np.ndarray.
92+
93+
An initial_state value of type integer is parsed in big endian order.
94+
"""
95+
if isinstance(initial_state, int):
96+
if qubits is None:
97+
raise ValueError('qubits must be provided if initial_state is not List[int]')
98+
state = ClassicalBasisState(
99+
big_endian_int_to_bits(initial_state, bit_count=len(qubits))
100+
)
101+
elif isinstance(initial_state, (list, np.ndarray)):
102+
state = ClassicalBasisState(initial_state)
103+
else:
104+
raise ValueError('initial_state must be an int or List[int] or np.ndarray')
105+
super().__init__(state=state, qubits=qubits, classical_data=classical_data)
106+
107+
def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True):
108+
"""Acts on the state with a given operation.
109+
110+
Args:
111+
action: The operation to apply.
112+
qubits: The qubits to apply the operation to.
113+
allow_decompose: Whether to allow decomposition of the operation.
114+
115+
Returns:
116+
True if the operation was applied successfully.
117+
118+
Raises:
119+
ValueError: If initial_state shape for type np.ndarray is not equal to 1.
120+
If gate is not one of X, CNOT, SWAP, CCNOT, or a measurement.
121+
"""
122+
if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1:
123+
raise ValueError('initial_state shape for type np.ndarray is not equal to 1')
124+
gate = action.gate if isinstance(action, ops.Operation) else action
125+
mapped_qubits = [self.qubit_map[i] for i in qubits]
126+
if _is_identity(gate):
127+
pass
128+
elif gate == ops.X:
129+
(q,) = mapped_qubits
130+
self._state.basis[q] ^= 1
131+
elif gate == ops.CNOT:
132+
c, q = mapped_qubits
133+
self._state.basis[q] ^= self._state.basis[c]
134+
elif gate == ops.SWAP:
135+
a, b = mapped_qubits
136+
self._state.basis[a], self._state.basis[b] = self._state.basis[b], self._state.basis[a]
137+
elif gate == ops.TOFFOLI:
138+
c1, c2, q = mapped_qubits
139+
self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2]
140+
else:
141+
raise ValueError(f'{gate} is not one of X, CNOT, SWAP, CCNOT, or a measurement')
142+
return True
143+
144+
145+
class ClassicalStateStepResult(
146+
sim.StepResultBase['ClassicalBasisSimState'], Generic[TSimulationState]
147+
):
148+
"""The step result provided by `ClassicalStateSimulator.simulate_moment_steps`."""
149+
150+
151+
class ClassicalStateTrialResult(
152+
sim.SimulationTrialResultBase['ClassicalBasisSimState'], Generic[TSimulationState]
153+
):
154+
"""The trial result provided by `ClassicalStateSimulator.simulate`."""
155+
156+
157+
class ClassicalStateSimulator(
158+
sim.SimulatorBase[
159+
ClassicalStateStepResult['ClassicalBasisSimState'],
160+
ClassicalStateTrialResult['ClassicalBasisSimState'],
161+
'ClassicalBasisSimState',
162+
],
163+
Generic[TSimulationState],
164+
):
165+
"""A simulator that accepts only gates with classical counterparts."""
166+
167+
def __init__(
168+
self, *, noise: 'cirq.NOISE_MODEL_LIKE' = None, split_untangled_states: bool = False
169+
):
170+
"""Initializes a ClassicalStateSimulator.
171+
172+
Args:
173+
noise: The noise model used by the simulator.
174+
split_untangled_states: Whether to run the simulation as a product state.
175+
176+
Raises:
177+
ValueError: If noise_model is not None.
178+
"""
179+
if noise is not None:
180+
raise ValueError(f'{noise=} is not supported')
181+
super().__init__(noise=noise, split_untangled_states=split_untangled_states)
182+
183+
def _create_simulator_trial_result(
184+
self,
185+
params: 'cirq.ParamResolver',
186+
measurements: Dict[str, np.ndarray],
187+
final_simulator_state: 'cirq.SimulationStateBase[ClassicalBasisSimState]',
188+
) -> 'ClassicalStateTrialResult[ClassicalBasisSimState]':
189+
"""Creates a trial result for the simulator.
190+
191+
Args:
192+
params: The parameter resolver for the simulation.
193+
measurements: The measurement results.
194+
final_simulator_state: The final state of the simulator.
195+
Returns:
196+
A trial result for the simulator.
197+
"""
198+
return ClassicalStateTrialResult(
199+
params, measurements, final_simulator_state=final_simulator_state
200+
)
201+
202+
def _create_step_result(
203+
self, sim_state: 'cirq.SimulationStateBase[ClassicalBasisSimState]'
204+
) -> 'ClassicalStateStepResult[ClassicalBasisSimState]':
205+
"""Creates a step result for the simulator.
206+
207+
Args:
208+
sim_state: The current state of the simulator.
209+
Returns:
210+
A step result for the simulator.
211+
"""
212+
return ClassicalStateStepResult(sim_state)
213+
214+
def _create_partial_simulation_state(
215+
self,
216+
initial_state: Any,
217+
qubits: Sequence['cirq.Qid'],
218+
classical_data: 'cirq.ClassicalDataStore',
219+
) -> 'ClassicalBasisSimState':
220+
"""Creates a partial simulation state for the simulator.
221+
222+
Args:
223+
initial_state: The initial state for the simulation.
224+
qubits: The qubits associated with the state.
225+
classical_data: The shared classical data container for this simulation.
226+
Returns:
227+
A partial simulation state.
228+
"""
229+
return ClassicalBasisSimState(
230+
initial_state=initial_state, qubits=qubits, classical_data=classical_data
231+
)

0 commit comments

Comments
 (0)