Skip to content

Commit a0d0b9b

Browse files
Permit 2D _run output for backwards compatibility. (quantumlib#5014)
Fixes quantumlib#5000. This PR reinstates support for 2D measurement data from `_run`, but logs a warning if that path is used. External simulators will need to modify their `_run` implementation before v0.15.
1 parent b86f75a commit a0d0b9b

File tree

2 files changed

+88
-65
lines changed

2 files changed

+88
-65
lines changed

cirq/sim/simulator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
TypeVar,
4545
Union,
4646
)
47+
import warnings
4748

4849
import numpy as np
4950

@@ -110,6 +111,20 @@ def run_sweep_iter(
110111
records = self._run(
111112
circuit=program, param_resolver=param_resolver, repetitions=repetitions
112113
)
114+
flat_records = False
115+
for k, v in records.items():
116+
if v.ndim == 2:
117+
flat_records = True
118+
records[k] = v.reshape((v.shape[0], 1, v.shape[1]))
119+
if flat_records:
120+
warnings.warn(
121+
(
122+
'Starting in Cirq v0.15, values in the output of simulator._run must '
123+
'be 3D instead of 2D, with a new dimension between the existing two '
124+
'to capture "instances" of a key.'
125+
),
126+
DeprecationWarning,
127+
)
113128
yield study.ResultDict(params=param_resolver, records=records)
114129

115130
@abc.abstractmethod

cirq/sim/simulator_test.py

Lines changed: 73 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,40 @@
2929
SimulatesExpectationValues,
3030
SimulatesFinalState,
3131
SimulatesIntermediateState,
32+
SimulatesSamples,
3233
SimulationTrialResult,
3334
TActOnArgs,
3435
)
3536

3637

38+
class FakeSimulatesSamples(SimulatesSamples):
39+
"""A SimulatesSamples that returns specified values from _run."""
40+
41+
def __init__(self, run_output: Dict[str, np.ndarray]):
42+
self._run_output = run_output
43+
44+
def _run(self, *args, **kwargs) -> Dict[str, np.ndarray]:
45+
return self._run_output
46+
47+
48+
class FakeStepResult(cirq.StepResult):
49+
def __init__(self, *, ones_qubits=None, final_state=None):
50+
self._ones_qubits = set(ones_qubits or [])
51+
self._final_state = final_state
52+
53+
def _simulator_state(self):
54+
return self._final_state
55+
56+
def state_vector(self):
57+
pass
58+
59+
def __setstate__(self, state):
60+
pass
61+
62+
def sample(self, qubits, repetitions=1, seed=None):
63+
return np.array([[qubit in self._ones_qubits for qubit in qubits]] * repetitions)
64+
65+
3766
class SimulatesIntermediateStateImpl(
3867
Generic[TStepResult, TSimulatorState, TActOnArgs],
3968
SimulatesIntermediateState[TStepResult, 'SimulationTrialResult', TSimulatorState, TActOnArgs],
@@ -62,43 +91,48 @@ def _create_simulator_trial_result(
6291
)
6392

6493

65-
@mock.patch.multiple(cirq.SimulatesSamples, __abstractmethods__=set(), _run=mock.Mock())
6694
def test_run_simulator_run():
67-
simulator = cirq.SimulatesSamples()
68-
expected_measurements = {'a': np.array([[[1]]])}
69-
simulator._run.return_value = expected_measurements
70-
circuit = mock.Mock(cirq.Circuit)
71-
circuit.__iter__ = mock.Mock(return_value=iter([]))
72-
param_resolver = mock.Mock(cirq.ParamResolver)
73-
param_resolver.param_dict = {}
74-
expected_result = cirq.ResultDict(records=expected_measurements, params=param_resolver)
95+
expected_records = {'a': np.array([[[1]]])}
96+
simulator = FakeSimulatesSamples(expected_records)
97+
circuit = cirq.Circuit(cirq.measure(cirq.LineQubit(0), key='k'))
98+
param_resolver = cirq.ParamResolver({})
99+
expected_result = cirq.ResultDict(records=expected_records, params=param_resolver)
75100
assert expected_result == simulator.run(
76101
program=circuit, repetitions=10, param_resolver=param_resolver
77102
)
78-
simulator._run.assert_called_once_with(
79-
circuit=circuit, repetitions=10, param_resolver=param_resolver
80-
)
81103

82104

83-
@mock.patch.multiple(cirq.SimulatesSamples, __abstractmethods__=set(), _run=mock.Mock())
84105
def test_run_simulator_sweeps():
85-
simulator = cirq.SimulatesSamples()
86-
expected_measurements = {'a': np.array([[[1]]])}
87-
simulator._run.return_value = expected_measurements
88-
circuit = mock.Mock(cirq.Circuit)
89-
circuit.__iter__ = mock.Mock(return_value=iter([]))
90-
param_resolvers = [mock.Mock(cirq.ParamResolver), mock.Mock(cirq.ParamResolver)]
91-
for resolver in param_resolvers:
92-
resolver.param_dict = {}
106+
expected_records = {'a': np.array([[[1]]])}
107+
simulator = FakeSimulatesSamples(expected_records)
108+
circuit = cirq.Circuit(cirq.measure(cirq.LineQubit(0), key='k'))
109+
param_resolvers = [cirq.ParamResolver({}), cirq.ParamResolver({})]
93110
expected_results = [
94-
cirq.ResultDict(records=expected_measurements, params=param_resolvers[0]),
95-
cirq.ResultDict(records=expected_measurements, params=param_resolvers[1]),
111+
cirq.ResultDict(records=expected_records, params=param_resolvers[0]),
112+
cirq.ResultDict(records=expected_records, params=param_resolvers[1]),
96113
]
97114
assert expected_results == simulator.run_sweep(
98115
program=circuit, repetitions=10, params=param_resolvers
99116
)
100-
simulator._run.assert_called_with(circuit=circuit, repetitions=10, param_resolver=mock.ANY)
101-
assert simulator._run.call_count == 2
117+
118+
119+
def test_run_simulator_sweeps_with_deprecated_run():
120+
expected_measurements = {'a': np.array([[1]])}
121+
simulator = FakeSimulatesSamples(expected_measurements)
122+
circuit = cirq.Circuit(cirq.measure(cirq.LineQubit(0), key='k'))
123+
param_resolvers = [cirq.ParamResolver({}), cirq.ParamResolver({})]
124+
expected_records = {'a': np.array([[[1]]])}
125+
expected_results = [
126+
cirq.ResultDict(records=expected_records, params=param_resolvers[0]),
127+
cirq.ResultDict(records=expected_records, params=param_resolvers[1]),
128+
]
129+
with cirq.testing.assert_deprecated(
130+
'values in the output of simulator._run must be 3D',
131+
deadline='v0.15',
132+
):
133+
assert expected_results == simulator.run_sweep(
134+
program=circuit, repetitions=10, params=param_resolvers
135+
)
102136

103137

104138
@mock.patch.multiple(
@@ -157,8 +191,7 @@ def steps(*args, **kwargs):
157191
program=circuit, params=param_resolvers, qubit_order=qubit_order, initial_state=2
158192
)
159193

160-
final_step_result = mock.Mock()
161-
final_step_result._simulator_state.return_value = final_state
194+
final_step_result = FakeStepResult(final_state=final_state)
162195
expected_results = [
163196
cirq.SimulationTrialResult(
164197
measurements={'a': np.array([True, True])},
@@ -174,27 +207,10 @@ def steps(*args, **kwargs):
174207
assert results == expected_results
175208

176209

177-
class FakeStepResult(cirq.StepResult):
178-
def __init__(self, ones_qubits):
179-
self._ones_qubits = set(ones_qubits)
180-
181-
def _simulator_state(self):
182-
pass
183-
184-
def state_vector(self):
185-
pass
186-
187-
def __setstate__(self, state):
188-
pass
189-
190-
def sample(self, qubits, repetitions=1, seed=None):
191-
return np.array([[qubit in self._ones_qubits for qubit in qubits]] * repetitions)
192-
193-
194210
def test_step_sample_measurement_ops():
195211
q0, q1, q2 = cirq.LineQubit.range(3)
196212
measurement_ops = [cirq.measure(q0, q1), cirq.measure(q2)]
197-
step_result = FakeStepResult([q1])
213+
step_result = FakeStepResult(ones_qubits=[q1])
198214

199215
measurements = step_result.sample_measurement_ops(measurement_ops)
200216
np.testing.assert_equal(measurements, {'0,1': [[False, True]], '2': [[False]]})
@@ -203,7 +219,7 @@ def test_step_sample_measurement_ops():
203219
def test_step_sample_measurement_ops_repetitions():
204220
q0, q1, q2 = cirq.LineQubit.range(3)
205221
measurement_ops = [cirq.measure(q0, q1), cirq.measure(q2)]
206-
step_result = FakeStepResult([q1])
222+
step_result = FakeStepResult(ones_qubits=[q1])
207223

208224
measurements = step_result.sample_measurement_ops(measurement_ops, repetitions=3)
209225
np.testing.assert_equal(measurements, {'0,1': [[False, True]] * 3, '2': [[False]] * 3})
@@ -215,29 +231,29 @@ def test_step_sample_measurement_ops_invert_mask():
215231
cirq.measure(q0, q1, invert_mask=(True,)),
216232
cirq.measure(q2, invert_mask=(False,)),
217233
]
218-
step_result = FakeStepResult([q1])
234+
step_result = FakeStepResult(ones_qubits=[q1])
219235

220236
measurements = step_result.sample_measurement_ops(measurement_ops)
221237
np.testing.assert_equal(measurements, {'0,1': [[True, True]], '2': [[False]]})
222238

223239

224240
def test_step_sample_measurement_ops_no_measurements():
225-
step_result = FakeStepResult([])
241+
step_result = FakeStepResult(ones_qubits=[])
226242

227243
measurements = step_result.sample_measurement_ops([])
228244
assert measurements == {}
229245

230246

231247
def test_step_sample_measurement_ops_not_measurement():
232248
q0 = cirq.LineQubit(0)
233-
step_result = FakeStepResult([q0])
249+
step_result = FakeStepResult(ones_qubits=[q0])
234250
with pytest.raises(ValueError, match='MeasurementGate'):
235251
step_result.sample_measurement_ops([cirq.X(q0)])
236252

237253

238254
def test_step_sample_measurement_ops_repeated_qubit():
239255
q0, q1, q2 = cirq.LineQubit.range(3)
240-
step_result = FakeStepResult([q0])
256+
step_result = FakeStepResult(ones_qubits=[q0])
241257
with pytest.raises(ValueError, match='Measurement key 0 repeated'):
242258
step_result.sample_measurement_ops(
243259
[cirq.measure(q0), cirq.measure(q1, q2), cirq.measure(q0)]
@@ -246,8 +262,7 @@ def test_step_sample_measurement_ops_repeated_qubit():
246262

247263
def test_simulation_trial_result_equality():
248264
eq = cirq.testing.EqualsTester()
249-
final_step_result = mock.Mock(cirq.StepResult)
250-
final_step_result._simulator_state.return_value = ()
265+
final_step_result = FakeStepResult(final_state=())
251266
eq.add_equality_group(
252267
cirq.SimulationTrialResult(
253268
params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result
@@ -270,7 +285,7 @@ def test_simulation_trial_result_equality():
270285
final_step_result=final_step_result,
271286
)
272287
)
273-
final_step_result._simulator_state.return_value = (0, 1)
288+
final_step_result._final_state = (0, 1)
274289
eq.add_equality_group(
275290
cirq.SimulationTrialResult(
276291
params=cirq.ParamResolver({'s': 1}),
@@ -281,8 +296,7 @@ def test_simulation_trial_result_equality():
281296

282297

283298
def test_simulation_trial_result_repr():
284-
final_step_result = mock.Mock(cirq.StepResult)
285-
final_step_result._simulator_state.return_value = (0, 1)
299+
final_step_result = FakeStepResult(final_state=(0, 1))
286300
assert repr(
287301
cirq.SimulationTrialResult(
288302
params=cirq.ParamResolver({'s': 1}),
@@ -298,8 +312,7 @@ def test_simulation_trial_result_repr():
298312

299313

300314
def test_simulation_trial_result_str():
301-
final_step_result = mock.Mock(cirq.StepResult)
302-
final_step_result._simulator_state.return_value = (0, 1)
315+
final_step_result = FakeStepResult(final_state=(0, 1))
303316
assert (
304317
str(
305318
cirq.SimulationTrialResult(
@@ -369,13 +382,10 @@ def text(self, to_print):
369382
@duet.sync
370383
async def test_async_sample():
371384
m = {'mock': np.array([[[0]], [[1]]])}
372-
373-
class MockSimulator(cirq.SimulatesSamples):
374-
def _run(self, circuit, param_resolver, repetitions):
375-
return m
385+
simulator = FakeSimulatesSamples(m)
376386

377387
q = cirq.LineQubit(0)
378-
f = MockSimulator().run_async(cirq.Circuit(cirq.measure(q)), repetitions=10)
388+
f = simulator.run_async(cirq.Circuit(cirq.measure(q)), repetitions=10)
379389
result = await f
380390
np.testing.assert_equal(result.records, m)
381391

@@ -458,10 +468,8 @@ def _kraus_(self):
458468

459469

460470
def test_iter_definitions():
461-
final_step_result = mock.Mock(cirq.StepResult)
462-
final_step_result._simulator_state.return_value = []
463471
dummy_trial_result = SimulationTrialResult(
464-
params={}, measurements={}, final_step_result=final_step_result
472+
params={}, measurements={}, final_step_result=FakeStepResult(final_state=[])
465473
)
466474

467475
class FakeNonIterSimulatorImpl(

0 commit comments

Comments
 (0)