Skip to content

Commit dea1494

Browse files
authored
Add classical simulator (quantumlib#6124)
1 parent bdb646c commit dea1494

File tree

5 files changed

+327
-0
lines changed

5 files changed

+327
-0
lines changed

cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@
441441

442442
from cirq.sim import (
443443
CIRCUIT_LIKE,
444+
ClassicalStateSimulator,
444445
CliffordSimulator,
445446
CliffordState,
446447
CliffordSimulatorStepResult,

cirq/protocols/json_test_data/spec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
'ZerosSampler',
5959
],
6060
should_not_be_serialized=[
61+
'ClassicalStateSimulator',
6162
# Heatmaps
6263
'Heatmap',
6364
'TwoQubitInteractionHeatmap',

cirq/sim/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868

6969
from cirq.sim.state_vector_simulation_state import StateVectorSimulationState
7070

71+
from cirq.sim.classical_simulator import ClassicalStateSimulator
72+
7173
from cirq.sim.state_vector_simulator import (
7274
SimulatesIntermediateStateVector,
7375
StateVectorStepResult,

cirq/sim/classical_simulator.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2023 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict
16+
from collections import defaultdict
17+
from cirq.sim.simulator import SimulatesSamples
18+
from cirq import ops, protocols
19+
from cirq.study.resolver import ParamResolver
20+
from cirq.circuits.circuit import AbstractCircuit
21+
from cirq.ops.raw_types import Qid
22+
import numpy as np
23+
24+
25+
class ClassicalStateSimulator(SimulatesSamples):
26+
"""A simulator that only accepts only gates with classical counterparts.
27+
28+
This simulator evolves a single state, using only gates that output a single state for each
29+
input state. The simulator runs in linear time, at the cost of not supporting superposition.
30+
It can be used to estimate costs and simulate circuits for simple non-quantum algorithms using
31+
many more qubits than fully capable quantum simulators.
32+
33+
The supported gates are:
34+
- cirq.X
35+
- cirq.CNOT
36+
- cirq.SWAP
37+
- cirq.TOFFOLI
38+
- cirq.measure
39+
40+
Args:
41+
circuit: The circuit to simulate.
42+
param_resolver: Parameters to run with the program.
43+
repetitions: Number of times to repeat the run. It is expected that
44+
this is validated greater than zero before calling this method.
45+
46+
Returns:
47+
A dictionary mapping measurement keys to measurement results.
48+
49+
Raises:
50+
ValueError: If one of the gates is not an X, CNOT, SWAP, TOFFOLI or a measurement.
51+
"""
52+
53+
def _run(
54+
self, circuit: AbstractCircuit, param_resolver: ParamResolver, repetitions: int
55+
) -> Dict[str, np.ndarray]:
56+
results_dict: Dict[str, np.ndarray] = {}
57+
values_dict: Dict[Qid, int] = defaultdict(int)
58+
param_resolver = param_resolver or ParamResolver({})
59+
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
60+
61+
for moment in resolved_circuit:
62+
for op in moment:
63+
gate = op.gate
64+
if gate == ops.X:
65+
values_dict[op.qubits[0]] = 1 - values_dict[op.qubits[0]]
66+
67+
elif (
68+
isinstance(gate, ops.CNotPowGate)
69+
and gate.exponent == 1
70+
and gate.global_shift == 0
71+
):
72+
if values_dict[op.qubits[0]] == 1:
73+
values_dict[op.qubits[1]] = 1 - values_dict[op.qubits[1]]
74+
75+
elif (
76+
isinstance(gate, ops.SwapPowGate)
77+
and gate.exponent == 1
78+
and gate.global_shift == 0
79+
):
80+
hold_qubit = values_dict[op.qubits[1]]
81+
values_dict[op.qubits[1]] = values_dict[op.qubits[0]]
82+
values_dict[op.qubits[0]] = hold_qubit
83+
84+
elif (
85+
isinstance(gate, ops.CCXPowGate)
86+
and gate.exponent == 1
87+
and gate.global_shift == 0
88+
):
89+
if (values_dict[op.qubits[0]] == 1) and (values_dict[op.qubits[1]] == 1):
90+
values_dict[op.qubits[2]] = 1 - values_dict[op.qubits[2]]
91+
92+
elif isinstance(gate, ops.MeasurementGate):
93+
qubits_in_order = op.qubits
94+
# add the new instance of a key to the numpy array in results dictionary
95+
if gate.key in results_dict:
96+
shape = len(qubits_in_order)
97+
current_array = results_dict[gate.key]
98+
new_instance = np.zeros(shape, dtype=np.uint8)
99+
for bits in range(0, len(qubits_in_order)):
100+
new_instance[bits] = values_dict[qubits_in_order[bits]]
101+
results_dict[gate.key] = np.insert(
102+
current_array, len(current_array[0]), new_instance, axis=1
103+
)
104+
else:
105+
# create the array for the results dictionary
106+
new_array_shape = (repetitions, 1, len(qubits_in_order))
107+
new_array = np.zeros(new_array_shape, dtype=np.uint8)
108+
for reps in range(0, repetitions):
109+
for instances in range(1):
110+
for bits in range(0, len(qubits_in_order)):
111+
new_array[reps][instances][bits] = values_dict[
112+
qubits_in_order[bits]
113+
]
114+
results_dict[gate.key] = new_array
115+
116+
elif not (
117+
(isinstance(gate, ops.XPowGate) and gate.exponent == 0)
118+
or (isinstance(gate, ops.CCXPowGate) and gate.exponent == 0)
119+
or (isinstance(gate, ops.SwapPowGate) and gate.exponent == 0)
120+
or (isinstance(gate, ops.CNotPowGate) and gate.exponent == 0)
121+
):
122+
raise ValueError(
123+
"Can not simulate gates other than cirq.XGate, "
124+
+ "cirq.CNOT, cirq.SWAP, and cirq.CCNOT"
125+
)
126+
127+
return results_dict

cirq/sim/classical_simulator_test.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2023 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytest
16+
import cirq
17+
import sympy
18+
19+
20+
class TestSimulator:
21+
def test_x_gate(self):
22+
q0, q1 = cirq.LineQubit.range(2)
23+
circuit = cirq.Circuit()
24+
circuit.append(cirq.X(q0))
25+
circuit.append(cirq.X(q1))
26+
circuit.append(cirq.X(q1))
27+
circuit.append(cirq.measure((q0, q1), key='key'))
28+
expected_results = {'key': np.array([[[1, 0]]], dtype=np.uint8)}
29+
sim = cirq.ClassicalStateSimulator()
30+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
31+
np.testing.assert_equal(results, expected_results)
32+
33+
def test_CNOT(self):
34+
q0, q1 = cirq.LineQubit.range(2)
35+
circuit = cirq.Circuit()
36+
circuit.append(cirq.X(q0))
37+
circuit.append(cirq.CNOT(q0, q1))
38+
circuit.append(cirq.measure(q1, key='key'))
39+
expected_results = {'key': np.array([[[1]]], dtype=np.uint8)}
40+
sim = cirq.ClassicalStateSimulator()
41+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
42+
np.testing.assert_equal(results, expected_results)
43+
44+
def test_Swap(self):
45+
q0, q1 = cirq.LineQubit.range(2)
46+
circuit = cirq.Circuit()
47+
circuit.append(cirq.X(q0))
48+
circuit.append(cirq.SWAP(q0, q1))
49+
circuit.append(cirq.measure((q0, q1), key='key'))
50+
expected_results = {'key': np.array([[[0, 1]]], dtype=np.uint8)}
51+
sim = cirq.ClassicalStateSimulator()
52+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
53+
np.testing.assert_equal(results, expected_results)
54+
55+
def test_CCNOT(self):
56+
q0, q1, q2 = cirq.LineQubit.range(3)
57+
circuit = cirq.Circuit()
58+
circuit.append(cirq.CCNOT(q0, q1, q2))
59+
circuit.append(cirq.measure((q0, q1, q2), key='key'))
60+
circuit.append(cirq.X(q0))
61+
circuit.append(cirq.CCNOT(q0, q1, q2))
62+
circuit.append(cirq.measure((q0, q1, q2), key='key'))
63+
circuit.append(cirq.X(q1))
64+
circuit.append(cirq.X(q0))
65+
circuit.append(cirq.CCNOT(q0, q1, q2))
66+
circuit.append(cirq.measure((q0, q1, q2), key='key'))
67+
circuit.append(cirq.X(q0))
68+
circuit.append(cirq.CCNOT(q0, q1, q2))
69+
circuit.append(cirq.measure((q0, q1, q2), key='key'))
70+
expected_results = {
71+
'key': np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 1]]], dtype=np.uint8)
72+
}
73+
sim = cirq.ClassicalStateSimulator()
74+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
75+
np.testing.assert_equal(results, expected_results)
76+
77+
def test_measurement_gate(self):
78+
q0, q1 = cirq.LineQubit.range(2)
79+
circuit = cirq.Circuit()
80+
circuit.append(cirq.measure((q0, q1), key='key'))
81+
expected_results = {'key': np.array([[[0, 0]]], dtype=np.uint8)}
82+
sim = cirq.ClassicalStateSimulator()
83+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
84+
np.testing.assert_equal(results, expected_results)
85+
86+
def test_qubit_order(self):
87+
q0, q1 = cirq.LineQubit.range(2)
88+
circuit = cirq.Circuit()
89+
circuit.append(cirq.CNOT(q0, q1))
90+
circuit.append(cirq.X(q0))
91+
circuit.append(cirq.measure((q0, q1), key='key'))
92+
expected_results = {'key': np.array([[[1, 0]]], dtype=np.uint8)}
93+
sim = cirq.ClassicalStateSimulator()
94+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
95+
np.testing.assert_equal(results, expected_results)
96+
97+
def test_same_key_instances(self):
98+
q0, q1 = cirq.LineQubit.range(2)
99+
circuit = cirq.Circuit()
100+
circuit.append(cirq.measure((q0, q1), key='key'))
101+
circuit.append(cirq.X(q0))
102+
circuit.append(cirq.measure((q0, q1), key='key'))
103+
expected_results = {'key': np.array([[[0, 0], [1, 0]]], dtype=np.uint8)}
104+
sim = cirq.ClassicalStateSimulator()
105+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
106+
np.testing.assert_equal(results, expected_results)
107+
108+
def test_same_key_instances_order(self):
109+
q0, q1 = cirq.LineQubit.range(2)
110+
circuit = cirq.Circuit()
111+
circuit.append(cirq.X(q0))
112+
circuit.append(cirq.measure((q0, q1), key='key'))
113+
circuit.append(cirq.X(q0))
114+
circuit.append(cirq.measure((q1, q0), key='key'))
115+
expected_results = {'key': np.array([[[1, 0], [0, 0]]], dtype=np.uint8)}
116+
sim = cirq.ClassicalStateSimulator()
117+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
118+
np.testing.assert_equal(results, expected_results)
119+
120+
def test_repetitions(self):
121+
q0 = cirq.LineQubit.range(1)
122+
circuit = cirq.Circuit()
123+
circuit.append(cirq.measure(q0, key='key'))
124+
expected_results = {
125+
'key': np.array(
126+
[[[0]], [[0]], [[0]], [[0]], [[0]], [[0]], [[0]], [[0]], [[0]], [[0]]],
127+
dtype=np.uint8,
128+
)
129+
}
130+
sim = cirq.ClassicalStateSimulator()
131+
results = sim.run(circuit, param_resolver=None, repetitions=10).records
132+
np.testing.assert_equal(results, expected_results)
133+
134+
def test_multiple_gates(self):
135+
q0, q1 = cirq.LineQubit.range(2)
136+
circuit = cirq.Circuit()
137+
circuit.append(cirq.X(q0))
138+
circuit.append(cirq.CNOT(q0, q1))
139+
circuit.append(cirq.CNOT(q0, q1))
140+
circuit.append(cirq.CNOT(q0, q1))
141+
circuit.append(cirq.X(q1))
142+
circuit.append(cirq.measure((q0, q1), key='key'))
143+
expected_results = {'key': np.array([[[1, 0]]], dtype=np.uint8)}
144+
sim = cirq.ClassicalStateSimulator()
145+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
146+
np.testing.assert_equal(results, expected_results)
147+
148+
def test_multiple_gates_order(self):
149+
q0, q1 = cirq.LineQubit.range(2)
150+
circuit = cirq.Circuit()
151+
circuit.append(cirq.X(q0))
152+
circuit.append(cirq.CNOT(q0, q1))
153+
circuit.append(cirq.CNOT(q1, q0))
154+
circuit.append(cirq.measure((q0, q1), key='key'))
155+
expected_results = {'key': np.array([[[0, 1]]], dtype=np.uint8)}
156+
sim = cirq.ClassicalStateSimulator()
157+
results = sim.run(circuit, param_resolver=None, repetitions=1).records
158+
np.testing.assert_equal(results, expected_results)
159+
160+
def test_param_resolver(self):
161+
gate = cirq.CNOT ** sympy.Symbol('t')
162+
q0, q1 = cirq.LineQubit.range(2)
163+
circuit = cirq.Circuit()
164+
circuit.append(cirq.X(q0))
165+
circuit.append(gate(q0, q1))
166+
circuit.append(cirq.measure((q1), key='key'))
167+
resolver = cirq.ParamResolver({'t': 0})
168+
sim = cirq.ClassicalStateSimulator()
169+
results_with_parameter_zero = sim.run(
170+
circuit, param_resolver=resolver, repetitions=1
171+
).records
172+
resolver = cirq.ParamResolver({'t': 1})
173+
results_with_parameter_one = sim.run(
174+
circuit, param_resolver=resolver, repetitions=1
175+
).records
176+
np.testing.assert_equal(
177+
results_with_parameter_zero, {'key': np.array([[[0]]], dtype=np.uint8)}
178+
)
179+
np.testing.assert_equal(
180+
results_with_parameter_one, {'key': np.array([[[1]]], dtype=np.uint8)}
181+
)
182+
183+
def test_unknown_gates(self):
184+
gate = cirq.CNOT ** sympy.Symbol('t')
185+
q0, q1 = cirq.LineQubit.range(2)
186+
circuit = cirq.Circuit()
187+
circuit.append(gate(q0, q1))
188+
circuit.append(cirq.measure((q0), key='key'))
189+
resolver = cirq.ParamResolver({'t': 0.5})
190+
sim = cirq.ClassicalStateSimulator()
191+
with pytest.raises(
192+
ValueError,
193+
match="Can not simulate gates other than "
194+
+ "cirq.XGate, cirq.CNOT, cirq.SWAP, and cirq.CCNOT",
195+
):
196+
_ = sim.run(circuit, param_resolver=resolver, repetitions=1).records

0 commit comments

Comments
 (0)