11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import tempfile
14
15
15
16
import numpy as np
16
17
import pytest
26
27
_aggregate_n_repetitions ,
27
28
_check_meas_specs_still_todo ,
28
29
StoppingCriteria ,
30
+ _parse_checkpoint_options ,
29
31
)
30
32
31
33
@@ -155,7 +157,6 @@ def test_params_and_settings():
155
157
156
158
157
159
def test_subdivide_meas_specs ():
158
-
159
160
qubits = cirq .LineQubit .range (2 )
160
161
q0 , q1 = qubits
161
162
setting = cw .InitObsSetting (
@@ -364,8 +365,56 @@ def test_meas_spec_still_todo_lots_of_params(monkeypatch):
364
365
)
365
366
366
367
367
- @pytest .mark .parametrize ('with_circuit_sweep' , (True , False ))
368
- def test_measure_grouped_settings (with_circuit_sweep ):
368
+ def test_checkpoint_options ():
369
+ # There are three ~binary options (the latter two can be either specified or `None`. We
370
+ # test those 2^3 cases.
371
+
372
+ assert _parse_checkpoint_options (False , None , None ) == (None , None )
373
+ with pytest .raises (ValueError ):
374
+ _parse_checkpoint_options (False , 'test' , None )
375
+ with pytest .raises (ValueError ):
376
+ _parse_checkpoint_options (False , None , 'test' )
377
+ with pytest .raises (ValueError ):
378
+ _parse_checkpoint_options (False , 'test1' , 'test2' )
379
+
380
+ chk , chkprev = _parse_checkpoint_options (True , None , None )
381
+ assert chk .startswith (tempfile .gettempdir ())
382
+ assert chk .endswith ('observables.json' )
383
+ assert chkprev .startswith (tempfile .gettempdir ())
384
+ assert chkprev .endswith ('observables.prev.json' )
385
+
386
+ chk , chkprev = _parse_checkpoint_options (True , None , 'prev.json' )
387
+ assert chk .startswith (tempfile .gettempdir ())
388
+ assert chk .endswith ('observables.json' )
389
+ assert chkprev == 'prev.json'
390
+
391
+ chk , chkprev = _parse_checkpoint_options (True , 'my_fancy_observables.json' , None )
392
+ assert chk == 'my_fancy_observables.json'
393
+ assert chkprev == 'my_fancy_observables.prev.json'
394
+
395
+ chk , chkprev = _parse_checkpoint_options (True , 'my_fancy/observables.json' , None )
396
+ assert chk == 'my_fancy/observables.json'
397
+ assert chkprev == 'my_fancy/observables.prev.json'
398
+
399
+ with pytest .raises (ValueError , match = r'Please use a `.json` filename.*' ):
400
+ _parse_checkpoint_options (True , 'my_fancy_observables.obs' , None )
401
+
402
+ with pytest .raises (ValueError , match = r"pattern of 'filename.extension'.*" ):
403
+ _parse_checkpoint_options (True , 'my_fancy_observables' , None )
404
+ with pytest .raises (ValueError , match = r"pattern of 'filename.extension'.*" ):
405
+ _parse_checkpoint_options (True , '.obs' , None )
406
+ with pytest .raises (ValueError , match = r"pattern of 'filename.extension'.*" ):
407
+ _parse_checkpoint_options (True , 'obs.' , None )
408
+ with pytest .raises (ValueError , match = r"pattern of 'filename.extension'.*" ):
409
+ _parse_checkpoint_options (True , '' , None )
410
+
411
+ chk , chkprev = _parse_checkpoint_options (True , 'test1' , 'test2' )
412
+ assert chk == 'test1'
413
+ assert chkprev == 'test2'
414
+
415
+
416
+ @pytest .mark .parametrize (('with_circuit_sweep' , 'checkpoint' ), [(True , True ), (False , False )])
417
+ def test_measure_grouped_settings (with_circuit_sweep , checkpoint , tmpdir ):
369
418
qubits = cirq .LineQubit .range (1 )
370
419
(q ,) = qubits
371
420
tests = [
@@ -381,6 +430,11 @@ def test_measure_grouped_settings(with_circuit_sweep):
381
430
else :
382
431
ss = None
383
432
433
+ if checkpoint :
434
+ checkpoint_fn = f'{ tmpdir } /obs.json'
435
+ else :
436
+ checkpoint_fn = None
437
+
384
438
for init , obs , coef in tests :
385
439
setting = cw .InitObsSetting (
386
440
init_state = init (q ),
@@ -392,8 +446,10 @@ def test_measure_grouped_settings(with_circuit_sweep):
392
446
circuit = circuit ,
393
447
grouped_settings = grouped_settings ,
394
448
sampler = cirq .Simulator (),
395
- stopping_criteria = cw .RepetitionsStoppingCriteria (1_000 ),
449
+ stopping_criteria = cw .RepetitionsStoppingCriteria (1_000 , repetitions_per_chunk = 500 ),
396
450
circuit_sweep = ss ,
451
+ checkpoint = checkpoint ,
452
+ checkpoint_fn = checkpoint_fn ,
397
453
)
398
454
if with_circuit_sweep :
399
455
for result in results :
@@ -430,3 +486,38 @@ def test_measure_grouped_settings_calibration_validation():
430
486
readout_calibrations = dummy_ro_calib ,
431
487
readout_symmetrization = False , # no-no!
432
488
)
489
+
490
+
491
+ def test_measure_grouped_settings_read_checkpoint (tmpdir ):
492
+ qubits = cirq .LineQubit .range (1 )
493
+ (q ,) = qubits
494
+
495
+ setting = cw .InitObsSetting (
496
+ init_state = cirq .KET_ZERO (q ),
497
+ observable = cirq .Z (q ),
498
+ )
499
+ grouped_settings = {setting : [setting ]}
500
+ circuit = cirq .Circuit (cirq .I .on_each (* qubits ))
501
+ with pytest .raises (ValueError , match = r'same filename.*' ):
502
+ _ = cw .measure_grouped_settings (
503
+ circuit = circuit ,
504
+ grouped_settings = grouped_settings ,
505
+ sampler = cirq .Simulator (),
506
+ stopping_criteria = cw .RepetitionsStoppingCriteria (1_000 , repetitions_per_chunk = 500 ),
507
+ checkpoint = True ,
508
+ checkpoint_fn = f'{ tmpdir } /obs.json' ,
509
+ checkpoint_other_fn = f'{ tmpdir } /obs.json' , # Same filename
510
+ )
511
+ _ = cw .measure_grouped_settings (
512
+ circuit = circuit ,
513
+ grouped_settings = grouped_settings ,
514
+ sampler = cirq .Simulator (),
515
+ stopping_criteria = cw .RepetitionsStoppingCriteria (1_000 , repetitions_per_chunk = 500 ),
516
+ checkpoint = True ,
517
+ checkpoint_fn = f'{ tmpdir } /obs.json' ,
518
+ checkpoint_other_fn = f'{ tmpdir } /obs.prev.json' ,
519
+ )
520
+ results = cirq .read_json (f'{ tmpdir } /obs.json' )
521
+ (result ,) = results # one group
522
+ assert result .n_repetitions == 1_000
523
+ assert result .means () == [1.0 ]
0 commit comments