29
29
SimulatesExpectationValues ,
30
30
SimulatesFinalState ,
31
31
SimulatesIntermediateState ,
32
+ SimulatesSamples ,
32
33
SimulationTrialResult ,
33
34
TActOnArgs ,
34
35
)
35
36
36
37
38
+ class FakeSimulatesSamples (SimulatesSamples ):
39
+ """A SimulatesSamples that returns specified values from _run."""
40
+
41
+ def __init__ (self , run_output : Dict [str , np .ndarray ]):
42
+ self ._run_output = run_output
43
+
44
+ def _run (self , * args , ** kwargs ) -> Dict [str , np .ndarray ]:
45
+ return self ._run_output
46
+
47
+
48
+ class FakeStepResult (cirq .StepResult ):
49
+ def __init__ (self , * , ones_qubits = None , final_state = None ):
50
+ self ._ones_qubits = set (ones_qubits or [])
51
+ self ._final_state = final_state
52
+
53
+ def _simulator_state (self ):
54
+ return self ._final_state
55
+
56
+ def state_vector (self ):
57
+ pass
58
+
59
+ def __setstate__ (self , state ):
60
+ pass
61
+
62
+ def sample (self , qubits , repetitions = 1 , seed = None ):
63
+ return np .array ([[qubit in self ._ones_qubits for qubit in qubits ]] * repetitions )
64
+
65
+
37
66
class SimulatesIntermediateStateImpl (
38
67
Generic [TStepResult , TSimulatorState , TActOnArgs ],
39
68
SimulatesIntermediateState [TStepResult , 'SimulationTrialResult' , TSimulatorState , TActOnArgs ],
@@ -62,43 +91,48 @@ def _create_simulator_trial_result(
62
91
)
63
92
64
93
65
- @mock .patch .multiple (cirq .SimulatesSamples , __abstractmethods__ = set (), _run = mock .Mock ())
66
94
def test_run_simulator_run ():
67
- simulator = cirq .SimulatesSamples ()
68
- expected_measurements = {'a' : np .array ([[[1 ]]])}
69
- simulator ._run .return_value = expected_measurements
70
- circuit = mock .Mock (cirq .Circuit )
71
- circuit .__iter__ = mock .Mock (return_value = iter ([]))
72
- param_resolver = mock .Mock (cirq .ParamResolver )
73
- param_resolver .param_dict = {}
74
- expected_result = cirq .ResultDict (records = expected_measurements , params = param_resolver )
95
+ expected_records = {'a' : np .array ([[[1 ]]])}
96
+ simulator = FakeSimulatesSamples (expected_records )
97
+ circuit = cirq .Circuit (cirq .measure (cirq .LineQubit (0 ), key = 'k' ))
98
+ param_resolver = cirq .ParamResolver ({})
99
+ expected_result = cirq .ResultDict (records = expected_records , params = param_resolver )
75
100
assert expected_result == simulator .run (
76
101
program = circuit , repetitions = 10 , param_resolver = param_resolver
77
102
)
78
- simulator ._run .assert_called_once_with (
79
- circuit = circuit , repetitions = 10 , param_resolver = param_resolver
80
- )
81
103
82
104
83
- @mock .patch .multiple (cirq .SimulatesSamples , __abstractmethods__ = set (), _run = mock .Mock ())
84
105
def test_run_simulator_sweeps ():
85
- simulator = cirq .SimulatesSamples ()
86
- expected_measurements = {'a' : np .array ([[[1 ]]])}
87
- simulator ._run .return_value = expected_measurements
88
- circuit = mock .Mock (cirq .Circuit )
89
- circuit .__iter__ = mock .Mock (return_value = iter ([]))
90
- param_resolvers = [mock .Mock (cirq .ParamResolver ), mock .Mock (cirq .ParamResolver )]
91
- for resolver in param_resolvers :
92
- resolver .param_dict = {}
106
+ expected_records = {'a' : np .array ([[[1 ]]])}
107
+ simulator = FakeSimulatesSamples (expected_records )
108
+ circuit = cirq .Circuit (cirq .measure (cirq .LineQubit (0 ), key = 'k' ))
109
+ param_resolvers = [cirq .ParamResolver ({}), cirq .ParamResolver ({})]
93
110
expected_results = [
94
- cirq .ResultDict (records = expected_measurements , params = param_resolvers [0 ]),
95
- cirq .ResultDict (records = expected_measurements , params = param_resolvers [1 ]),
111
+ cirq .ResultDict (records = expected_records , params = param_resolvers [0 ]),
112
+ cirq .ResultDict (records = expected_records , params = param_resolvers [1 ]),
96
113
]
97
114
assert expected_results == simulator .run_sweep (
98
115
program = circuit , repetitions = 10 , params = param_resolvers
99
116
)
100
- simulator ._run .assert_called_with (circuit = circuit , repetitions = 10 , param_resolver = mock .ANY )
101
- assert simulator ._run .call_count == 2
117
+
118
+
119
+ def test_run_simulator_sweeps_with_deprecated_run ():
120
+ expected_measurements = {'a' : np .array ([[1 ]])}
121
+ simulator = FakeSimulatesSamples (expected_measurements )
122
+ circuit = cirq .Circuit (cirq .measure (cirq .LineQubit (0 ), key = 'k' ))
123
+ param_resolvers = [cirq .ParamResolver ({}), cirq .ParamResolver ({})]
124
+ expected_records = {'a' : np .array ([[[1 ]]])}
125
+ expected_results = [
126
+ cirq .ResultDict (records = expected_records , params = param_resolvers [0 ]),
127
+ cirq .ResultDict (records = expected_records , params = param_resolvers [1 ]),
128
+ ]
129
+ with cirq .testing .assert_deprecated (
130
+ 'values in the output of simulator._run must be 3D' ,
131
+ deadline = 'v0.15' ,
132
+ ):
133
+ assert expected_results == simulator .run_sweep (
134
+ program = circuit , repetitions = 10 , params = param_resolvers
135
+ )
102
136
103
137
104
138
@mock .patch .multiple (
@@ -157,8 +191,7 @@ def steps(*args, **kwargs):
157
191
program = circuit , params = param_resolvers , qubit_order = qubit_order , initial_state = 2
158
192
)
159
193
160
- final_step_result = mock .Mock ()
161
- final_step_result ._simulator_state .return_value = final_state
194
+ final_step_result = FakeStepResult (final_state = final_state )
162
195
expected_results = [
163
196
cirq .SimulationTrialResult (
164
197
measurements = {'a' : np .array ([True , True ])},
@@ -174,27 +207,10 @@ def steps(*args, **kwargs):
174
207
assert results == expected_results
175
208
176
209
177
- class FakeStepResult (cirq .StepResult ):
178
- def __init__ (self , ones_qubits ):
179
- self ._ones_qubits = set (ones_qubits )
180
-
181
- def _simulator_state (self ):
182
- pass
183
-
184
- def state_vector (self ):
185
- pass
186
-
187
- def __setstate__ (self , state ):
188
- pass
189
-
190
- def sample (self , qubits , repetitions = 1 , seed = None ):
191
- return np .array ([[qubit in self ._ones_qubits for qubit in qubits ]] * repetitions )
192
-
193
-
194
210
def test_step_sample_measurement_ops ():
195
211
q0 , q1 , q2 = cirq .LineQubit .range (3 )
196
212
measurement_ops = [cirq .measure (q0 , q1 ), cirq .measure (q2 )]
197
- step_result = FakeStepResult ([q1 ])
213
+ step_result = FakeStepResult (ones_qubits = [q1 ])
198
214
199
215
measurements = step_result .sample_measurement_ops (measurement_ops )
200
216
np .testing .assert_equal (measurements , {'0,1' : [[False , True ]], '2' : [[False ]]})
@@ -203,7 +219,7 @@ def test_step_sample_measurement_ops():
203
219
def test_step_sample_measurement_ops_repetitions ():
204
220
q0 , q1 , q2 = cirq .LineQubit .range (3 )
205
221
measurement_ops = [cirq .measure (q0 , q1 ), cirq .measure (q2 )]
206
- step_result = FakeStepResult ([q1 ])
222
+ step_result = FakeStepResult (ones_qubits = [q1 ])
207
223
208
224
measurements = step_result .sample_measurement_ops (measurement_ops , repetitions = 3 )
209
225
np .testing .assert_equal (measurements , {'0,1' : [[False , True ]] * 3 , '2' : [[False ]] * 3 })
@@ -215,29 +231,29 @@ def test_step_sample_measurement_ops_invert_mask():
215
231
cirq .measure (q0 , q1 , invert_mask = (True ,)),
216
232
cirq .measure (q2 , invert_mask = (False ,)),
217
233
]
218
- step_result = FakeStepResult ([q1 ])
234
+ step_result = FakeStepResult (ones_qubits = [q1 ])
219
235
220
236
measurements = step_result .sample_measurement_ops (measurement_ops )
221
237
np .testing .assert_equal (measurements , {'0,1' : [[True , True ]], '2' : [[False ]]})
222
238
223
239
224
240
def test_step_sample_measurement_ops_no_measurements ():
225
- step_result = FakeStepResult ([])
241
+ step_result = FakeStepResult (ones_qubits = [])
226
242
227
243
measurements = step_result .sample_measurement_ops ([])
228
244
assert measurements == {}
229
245
230
246
231
247
def test_step_sample_measurement_ops_not_measurement ():
232
248
q0 = cirq .LineQubit (0 )
233
- step_result = FakeStepResult ([q0 ])
249
+ step_result = FakeStepResult (ones_qubits = [q0 ])
234
250
with pytest .raises (ValueError , match = 'MeasurementGate' ):
235
251
step_result .sample_measurement_ops ([cirq .X (q0 )])
236
252
237
253
238
254
def test_step_sample_measurement_ops_repeated_qubit ():
239
255
q0 , q1 , q2 = cirq .LineQubit .range (3 )
240
- step_result = FakeStepResult ([q0 ])
256
+ step_result = FakeStepResult (ones_qubits = [q0 ])
241
257
with pytest .raises (ValueError , match = 'Measurement key 0 repeated' ):
242
258
step_result .sample_measurement_ops (
243
259
[cirq .measure (q0 ), cirq .measure (q1 , q2 ), cirq .measure (q0 )]
@@ -246,8 +262,7 @@ def test_step_sample_measurement_ops_repeated_qubit():
246
262
247
263
def test_simulation_trial_result_equality ():
248
264
eq = cirq .testing .EqualsTester ()
249
- final_step_result = mock .Mock (cirq .StepResult )
250
- final_step_result ._simulator_state .return_value = ()
265
+ final_step_result = FakeStepResult (final_state = ())
251
266
eq .add_equality_group (
252
267
cirq .SimulationTrialResult (
253
268
params = cirq .ParamResolver ({}), measurements = {}, final_step_result = final_step_result
@@ -270,7 +285,7 @@ def test_simulation_trial_result_equality():
270
285
final_step_result = final_step_result ,
271
286
)
272
287
)
273
- final_step_result ._simulator_state . return_value = (0 , 1 )
288
+ final_step_result ._final_state = (0 , 1 )
274
289
eq .add_equality_group (
275
290
cirq .SimulationTrialResult (
276
291
params = cirq .ParamResolver ({'s' : 1 }),
@@ -281,8 +296,7 @@ def test_simulation_trial_result_equality():
281
296
282
297
283
298
def test_simulation_trial_result_repr ():
284
- final_step_result = mock .Mock (cirq .StepResult )
285
- final_step_result ._simulator_state .return_value = (0 , 1 )
299
+ final_step_result = FakeStepResult (final_state = (0 , 1 ))
286
300
assert repr (
287
301
cirq .SimulationTrialResult (
288
302
params = cirq .ParamResolver ({'s' : 1 }),
@@ -298,8 +312,7 @@ def test_simulation_trial_result_repr():
298
312
299
313
300
314
def test_simulation_trial_result_str ():
301
- final_step_result = mock .Mock (cirq .StepResult )
302
- final_step_result ._simulator_state .return_value = (0 , 1 )
315
+ final_step_result = FakeStepResult (final_state = (0 , 1 ))
303
316
assert (
304
317
str (
305
318
cirq .SimulationTrialResult (
@@ -369,13 +382,10 @@ def text(self, to_print):
369
382
@duet .sync
370
383
async def test_async_sample ():
371
384
m = {'mock' : np .array ([[[0 ]], [[1 ]]])}
372
-
373
- class MockSimulator (cirq .SimulatesSamples ):
374
- def _run (self , circuit , param_resolver , repetitions ):
375
- return m
385
+ simulator = FakeSimulatesSamples (m )
376
386
377
387
q = cirq .LineQubit (0 )
378
- f = MockSimulator () .run_async (cirq .Circuit (cirq .measure (q )), repetitions = 10 )
388
+ f = simulator .run_async (cirq .Circuit (cirq .measure (q )), repetitions = 10 )
379
389
result = await f
380
390
np .testing .assert_equal (result .records , m )
381
391
@@ -458,10 +468,8 @@ def _kraus_(self):
458
468
459
469
460
470
def test_iter_definitions ():
461
- final_step_result = mock .Mock (cirq .StepResult )
462
- final_step_result ._simulator_state .return_value = []
463
471
dummy_trial_result = SimulationTrialResult (
464
- params = {}, measurements = {}, final_step_result = final_step_result
472
+ params = {}, measurements = {}, final_step_result = FakeStepResult ( final_state = [])
465
473
)
466
474
467
475
class FakeNonIterSimulatorImpl (
0 commit comments