Skip to content

Commit 120eb87

Browse files
authored
[Obs] 4.4 - Checkpointing (#4352)
Save "checkpoint" files during observable estimation. With enough observables or enough samples or low enough variables you can construct long running calls to this functionality. These options will (optionally) make sure data is not lost in those scenarios. - It's off by default - If you just toggle it to True, it will save data in a temporary directory. The use case envisaged here is to guard against data loss in an unforseen interruption - You can provide your own filenames. The use case here can be part of the nominal operation where you use that file as the saved results for a given run - We need two filenames so we can do an atomic `mv` so errors during serialization won't result in data loss. The two filenames should be on the same disk or `mv` isn't atomic. We don't enforce that.
1 parent b1dc973 commit 120eb87

File tree

2 files changed

+178
-5
lines changed

2 files changed

+178
-5
lines changed

cirq-core/cirq/work/observable_measurement.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
import abc
1616
import dataclasses
1717
import itertools
18+
import os
19+
import tempfile
1820
import warnings
1921
from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence
2022

2123
import numpy as np
2224
import sympy
2325
from cirq import circuits, study, ops, value
2426
from cirq._doc import document
25-
from cirq.protocols import json_serializable_dataclass
27+
from cirq.protocols import json_serializable_dataclass, to_json
2628
from cirq.work.observable_measurement_data import BitstringAccumulator
2729
from cirq.work.observable_settings import (
2830
InitObsSetting,
@@ -353,6 +355,60 @@ def _to_sweep(param_tuples):
353355
return to_sweep
354356

355357

358+
def _parse_checkpoint_options(
359+
checkpoint: bool, checkpoint_fn: Optional[str], checkpoint_other_fn: Optional[str]
360+
) -> Tuple[Optional[str], Optional[str]]:
361+
"""Parse the checkpoint-oriented options in `measure_grouped_settings`.
362+
363+
This function contains the validation and defaults logic. Please see
364+
`measure_grouped_settings` for documentation on these args.
365+
366+
Returns:
367+
checkpoint_fn, checkpoint_other_fn: Parsed or default filenames for primary and previous
368+
checkpoint files.
369+
"""
370+
if not checkpoint:
371+
if checkpoint_fn is not None or checkpoint_other_fn is not None:
372+
raise ValueError(
373+
"Checkpoint filenames were provided but `checkpoint` was set to False."
374+
)
375+
return None, None
376+
377+
if checkpoint_fn is None:
378+
checkpoint_dir = tempfile.mkdtemp()
379+
chk_basename = 'observables'
380+
checkpoint_fn = f'{checkpoint_dir}/{chk_basename}.json'
381+
382+
if checkpoint_other_fn is None:
383+
checkpoint_dir = os.path.dirname(checkpoint_fn)
384+
chk_basename = os.path.basename(checkpoint_fn)
385+
chk_basename, dot, ext = chk_basename.rpartition('.')
386+
if chk_basename == '' or dot != '.' or ext == '':
387+
raise ValueError(
388+
f"You specified `checkpoint_fn={checkpoint_fn!r}` which does not follow the "
389+
f"pattern of 'filename.extension'. Please follow this pattern or fully specify "
390+
f"`checkpoint_other_fn`."
391+
)
392+
393+
if ext != 'json':
394+
raise ValueError(
395+
"Please use a `.json` filename or fully "
396+
"specify checkpoint_fn and checkpoint_other_fn"
397+
)
398+
if checkpoint_dir == '':
399+
checkpoint_other_fn = f'{chk_basename}.prev.json'
400+
else:
401+
checkpoint_other_fn = f'{checkpoint_dir}/{chk_basename}.prev.json'
402+
403+
if checkpoint_fn == checkpoint_other_fn:
404+
raise ValueError(
405+
f"`checkpoint_fn` and `checkpoint_other_fn` were set to the same "
406+
f"filename: {checkpoint_fn}. Please use two different filenames."
407+
)
408+
409+
return checkpoint_fn, checkpoint_other_fn
410+
411+
356412
def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool:
357413
"""Helper function to go through init_states and determine if any of them need an
358414
initialization layer of single-qubit gates."""
@@ -371,6 +427,9 @@ def measure_grouped_settings(
371427
readout_symmetrization: bool = False,
372428
circuit_sweep: 'cirq.study.sweepable.SweepLike' = None,
373429
readout_calibrations: Optional[BitstringAccumulator] = None,
430+
checkpoint: bool = False,
431+
checkpoint_fn: Optional[str] = None,
432+
checkpoint_other_fn: Optional[str] = None,
374433
) -> List[BitstringAccumulator]:
375434
"""Measure a suite of grouped InitObsSetting settings.
376435
@@ -399,10 +458,26 @@ def measure_grouped_settings(
399458
in `circuit`. The total sweep is the product of the circuit sweep
400459
with parameter settings for the single-qubit basis-change rotations.
401460
readout_calibrations: The result of `calibrate_readout_error`.
461+
checkpoint: If set to True, save cumulative raw results at the end
462+
of each iteration of the sampling loop. Load in these results
463+
with `cirq.read_json`.
464+
checkpoint_fn: The filename for the checkpoint file. If `checkpoint`
465+
is set to True and this is not specified, a file in a temporary
466+
directory will be used.
467+
checkpoint_other_fn: The filename for another checkpoint file, which
468+
contains the previous checkpoint. This lets us avoid losing data if
469+
a failure occurs during checkpoint writing. If `checkpoint`
470+
is set to True and this is not specified, a file in a temporary
471+
directory will be used. If `checkpoint` is set to True and
472+
`checkpoint_fn` is specified but this argument is *not* specified,
473+
"{checkpoint_fn}.prev.json" will be used.
402474
"""
403475
if readout_calibrations is not None and not readout_symmetrization:
404476
raise ValueError("Readout calibration only works if `readout_symmetrization` is enabled.")
405477

478+
checkpoint_fn, checkpoint_other_fn = _parse_checkpoint_options(
479+
checkpoint=checkpoint, checkpoint_fn=checkpoint_fn, checkpoint_other_fn=checkpoint_other_fn
480+
)
406481
qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits})
407482
qubit_to_index = {q: i for i, q in enumerate(qubits)}
408483

@@ -471,4 +546,11 @@ def measure_grouped_settings(
471546
bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z'])
472547
accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe'))
473548

549+
if checkpoint:
550+
assert checkpoint_fn is not None, 'mypy'
551+
assert checkpoint_other_fn is not None, 'mypy'
552+
if os.path.exists(checkpoint_fn):
553+
os.replace(checkpoint_fn, checkpoint_other_fn)
554+
to_json(list(accumulators.values()), checkpoint_fn)
555+
474556
return list(accumulators.values())

cirq-core/cirq/work/observable_measurement_test.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import tempfile
1415

1516
import numpy as np
1617
import pytest
@@ -26,6 +27,7 @@
2627
_aggregate_n_repetitions,
2728
_check_meas_specs_still_todo,
2829
StoppingCriteria,
30+
_parse_checkpoint_options,
2931
)
3032

3133

@@ -155,7 +157,6 @@ def test_params_and_settings():
155157

156158

157159
def test_subdivide_meas_specs():
158-
159160
qubits = cirq.LineQubit.range(2)
160161
q0, q1 = qubits
161162
setting = cw.InitObsSetting(
@@ -364,8 +365,56 @@ def test_meas_spec_still_todo_lots_of_params(monkeypatch):
364365
)
365366

366367

367-
@pytest.mark.parametrize('with_circuit_sweep', (True, False))
368-
def test_measure_grouped_settings(with_circuit_sweep):
368+
def test_checkpoint_options():
369+
# There are three ~binary options (the latter two can be either specified or `None`. We
370+
# test those 2^3 cases.
371+
372+
assert _parse_checkpoint_options(False, None, None) == (None, None)
373+
with pytest.raises(ValueError):
374+
_parse_checkpoint_options(False, 'test', None)
375+
with pytest.raises(ValueError):
376+
_parse_checkpoint_options(False, None, 'test')
377+
with pytest.raises(ValueError):
378+
_parse_checkpoint_options(False, 'test1', 'test2')
379+
380+
chk, chkprev = _parse_checkpoint_options(True, None, None)
381+
assert chk.startswith(tempfile.gettempdir())
382+
assert chk.endswith('observables.json')
383+
assert chkprev.startswith(tempfile.gettempdir())
384+
assert chkprev.endswith('observables.prev.json')
385+
386+
chk, chkprev = _parse_checkpoint_options(True, None, 'prev.json')
387+
assert chk.startswith(tempfile.gettempdir())
388+
assert chk.endswith('observables.json')
389+
assert chkprev == 'prev.json'
390+
391+
chk, chkprev = _parse_checkpoint_options(True, 'my_fancy_observables.json', None)
392+
assert chk == 'my_fancy_observables.json'
393+
assert chkprev == 'my_fancy_observables.prev.json'
394+
395+
chk, chkprev = _parse_checkpoint_options(True, 'my_fancy/observables.json', None)
396+
assert chk == 'my_fancy/observables.json'
397+
assert chkprev == 'my_fancy/observables.prev.json'
398+
399+
with pytest.raises(ValueError, match=r'Please use a `.json` filename.*'):
400+
_parse_checkpoint_options(True, 'my_fancy_observables.obs', None)
401+
402+
with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"):
403+
_parse_checkpoint_options(True, 'my_fancy_observables', None)
404+
with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"):
405+
_parse_checkpoint_options(True, '.obs', None)
406+
with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"):
407+
_parse_checkpoint_options(True, 'obs.', None)
408+
with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"):
409+
_parse_checkpoint_options(True, '', None)
410+
411+
chk, chkprev = _parse_checkpoint_options(True, 'test1', 'test2')
412+
assert chk == 'test1'
413+
assert chkprev == 'test2'
414+
415+
416+
@pytest.mark.parametrize(('with_circuit_sweep', 'checkpoint'), [(True, True), (False, False)])
417+
def test_measure_grouped_settings(with_circuit_sweep, checkpoint, tmpdir):
369418
qubits = cirq.LineQubit.range(1)
370419
(q,) = qubits
371420
tests = [
@@ -381,6 +430,11 @@ def test_measure_grouped_settings(with_circuit_sweep):
381430
else:
382431
ss = None
383432

433+
if checkpoint:
434+
checkpoint_fn = f'{tmpdir}/obs.json'
435+
else:
436+
checkpoint_fn = None
437+
384438
for init, obs, coef in tests:
385439
setting = cw.InitObsSetting(
386440
init_state=init(q),
@@ -392,8 +446,10 @@ def test_measure_grouped_settings(with_circuit_sweep):
392446
circuit=circuit,
393447
grouped_settings=grouped_settings,
394448
sampler=cirq.Simulator(),
395-
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000),
449+
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
396450
circuit_sweep=ss,
451+
checkpoint=checkpoint,
452+
checkpoint_fn=checkpoint_fn,
397453
)
398454
if with_circuit_sweep:
399455
for result in results:
@@ -430,3 +486,38 @@ def test_measure_grouped_settings_calibration_validation():
430486
readout_calibrations=dummy_ro_calib,
431487
readout_symmetrization=False, # no-no!
432488
)
489+
490+
491+
def test_measure_grouped_settings_read_checkpoint(tmpdir):
492+
qubits = cirq.LineQubit.range(1)
493+
(q,) = qubits
494+
495+
setting = cw.InitObsSetting(
496+
init_state=cirq.KET_ZERO(q),
497+
observable=cirq.Z(q),
498+
)
499+
grouped_settings = {setting: [setting]}
500+
circuit = cirq.Circuit(cirq.I.on_each(*qubits))
501+
with pytest.raises(ValueError, match=r'same filename.*'):
502+
_ = cw.measure_grouped_settings(
503+
circuit=circuit,
504+
grouped_settings=grouped_settings,
505+
sampler=cirq.Simulator(),
506+
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
507+
checkpoint=True,
508+
checkpoint_fn=f'{tmpdir}/obs.json',
509+
checkpoint_other_fn=f'{tmpdir}/obs.json', # Same filename
510+
)
511+
_ = cw.measure_grouped_settings(
512+
circuit=circuit,
513+
grouped_settings=grouped_settings,
514+
sampler=cirq.Simulator(),
515+
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
516+
checkpoint=True,
517+
checkpoint_fn=f'{tmpdir}/obs.json',
518+
checkpoint_other_fn=f'{tmpdir}/obs.prev.json',
519+
)
520+
results = cirq.read_json(f'{tmpdir}/obs.json')
521+
(result,) = results # one group
522+
assert result.n_repetitions == 1_000
523+
assert result.means() == [1.0]

0 commit comments

Comments
 (0)