Skip to content

[Obs] 4.5 - High-level API #4392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cirq-core/cirq/work/observable_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Dict, List, TYPE_CHECKING, cast
from typing import Iterable, Dict, List, TYPE_CHECKING, cast, Callable

from cirq import ops, value
from cirq.work.observable_settings import InitObsSetting, _max_weight_state, _max_weight_observable

if TYPE_CHECKING:
pass

GROUPER_T = Callable[[Iterable[InitObsSetting]], Dict[InitObsSetting, List[InitObsSetting]]]


def group_settings_greedy(
settings: Iterable[InitObsSetting],
Expand Down
187 changes: 186 additions & 1 deletion cirq-core/cirq/work/observable_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,31 @@
import os
import tempfile
import warnings
from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence
from typing import (
Optional,
Union,
Iterable,
Dict,
List,
Tuple,
TYPE_CHECKING,
Set,
Sequence,
Type,
Any,
)

import numpy as np
import pandas as pd
import sympy
from cirq import circuits, study, ops, value
from cirq._doc import document
from cirq.protocols import json_serializable_dataclass, to_json
from cirq.work.observable_grouping import group_settings_greedy, GROUPER_T
from cirq.work.observable_measurement_data import BitstringAccumulator
from cirq.work.observable_settings import (
InitObsSetting,
observables_to_settings,
_MeasurementSpec,
)

Expand Down Expand Up @@ -80,6 +95,17 @@ class StoppingCriteria(abc.ABC):
"""An abstract object that queries a BitstringAccumulator to figure out
whether that `meas_spec` is complete."""

def __init__(self, val: Any):
"""Initialize the stopping criteria with one required argument.

We provide RepetitionsStoppingCriteria and VarianceStoppingCriteria which
each have one required argument. For convenience, the high-level `measure_observables`
function takes either an instantiation of a `StoppingCriteria` or a shorthand name+val.
If future criteria need more or fewer required arguments, this abstract method may
need to be factored out and the `_parse_stopping_criteria` logic updated.
"""
raise NotImplementedError()

@abc.abstractmethod
def more_repetitions(self, accumulator: BitstringAccumulator) -> int:
"""Return the number of additional repetitions to take.
Expand Down Expand Up @@ -554,3 +580,162 @@ def measure_grouped_settings(
to_json(list(accumulators.values()), checkpoint_fn)

return list(accumulators.values())


_GROUPING_FUNCS: Dict[str, GROUPER_T] = {
'greedy': group_settings_greedy,
}

_STOPPING_CRITS: Dict[str, Type[StoppingCriteria]] = {
'repetitions': RepetitionsStoppingCriteria,
'variance': VarianceStoppingCriteria,
}


def _parse_stopping_criteria(
stopping_criteria: Union[str, StoppingCriteria], stopping_criteria_val: Optional[float] = None
) -> StoppingCriteria:
"""Logic for turning a named stopping_criteria and value to one of the built-in stopping
criteria in support of the high-level `measure_observables` API.
"""
if isinstance(stopping_criteria, str):
stopping_criteria_cls = _STOPPING_CRITS[stopping_criteria]
stopping_criteria = stopping_criteria_cls(stopping_criteria_val)
return stopping_criteria


def _parse_grouper(grouper: Union[str, GROUPER_T] = group_settings_greedy) -> GROUPER_T:
"""Logic for turning a named grouper into one of the build-in groupers in support of the
high-level `measure_observables` API."""
if isinstance(grouper, str):
try:
grouper = _GROUPING_FUNCS[grouper.lower()]
except KeyError:
raise ValueError(f"Unknown grouping function {grouper}")
return grouper


def _get_all_qubits(
circuit: circuits.Circuit,
observables: Iterable[ops.PauliString],
) -> List['cirq.Qid']:
"""Helper function for `measure_observables` to get all qubits from a circuit and a
collection of observables."""
qubit_set = set()
for obs in observables:
qubit_set |= set(obs.qubits)
qubit_set |= circuit.all_qubits()
return sorted(qubit_set)


def measure_observables(
circuit: circuits.Circuit,
observables: Iterable[ops.PauliString],
sampler: Union['cirq.Simulator', 'cirq.Sampler'],
stopping_criteria: Union[str, StoppingCriteria],
stopping_criteria_val: Optional[float] = None,
*,
readout_symmetrization: bool = True,
circuit_sweep: Optional['cirq.Sweepable'] = None,
grouper: Union[str, GROUPER_T] = group_settings_greedy,
readout_calibrations: Optional[BitstringAccumulator] = None,
checkpoint: bool = False,
checkpoint_fn: Optional[str] = None,
checkpoint_other_fn: Optional[str] = None,
):
"""Measure a collection of PauliString observables for a state prepared by a Circuit.

If you need more control over the process, please see `measure_grouped_settings` for a
lower-level API. If you would like your results returned as a pandas DataFrame,
please see `measure_observables_df`.

Args:
circuit: The circuit. This can contain parameters, in which case
you should also specify `circuit_sweep`.
observables: A collection of PauliString observables to measure.
These will be grouped into simultaneously-measurable groups,
see `grouper` argument.
sampler: A sampler.
stopping_criteria: Either a StoppingCriteria object or one of
'variance', 'repetitions'. In the latter case, you must
also specify `stopping_criteria_val`.
stopping_criteria_val: The value used for named stopping criteria.
If you specified 'repetitions', this is the number of repetitions.
If you specified 'variance', this is the variance.
readout_symmetrization: If set to True, each run will be
split into two: one normal and one where a bit flip is
incorporated prior to measurement. In the latter case, the
measured bit will be flipped back classically and accumulated
together. This causes readout error to appear symmetric,
p(0|0) = p(1|1).
circuit_sweep: Additional parameter sweeps for parameters contained
in `circuit`. The total sweep is the product of the circuit sweep
with parameter settings for the single-qubit basis-change rotations.
grouper: Either "greedy" or a function that groups lists of
`InitObsSetting`. See the documentation for the `grouped_settings`
argument of `measure_grouped_settings` for full details.
readout_calibrations: The result of `calibrate_readout_error`.
checkpoint: If set to True, save cumulative raw results at the end
of each iteration of the sampling loop.
checkpoint_fn: The filename for the checkpoint file. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used.
checkpoint_other_fn: The filename for another checkpoint file, which
contains the previous checkpoint. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used. If `checkpoint` is set to True and
`checkpoint_fn` is specified but this argument is *not* specified,
"{checkpoint_fn}.prev.json" will be used.
"""
qubits = _get_all_qubits(circuit, observables)
settings = list(observables_to_settings(observables, qubits))
actual_grouper = _parse_grouper(grouper)
grouped_settings = actual_grouper(settings)
stopping_criteria = _parse_stopping_criteria(stopping_criteria, stopping_criteria_val)

return measure_grouped_settings(
circuit=circuit,
grouped_settings=grouped_settings,
sampler=sampler,
stopping_criteria=stopping_criteria,
circuit_sweep=circuit_sweep,
readout_symmetrization=readout_symmetrization,
readout_calibrations=readout_calibrations,
checkpoint=checkpoint,
checkpoint_fn=checkpoint_fn,
checkpoint_other_fn=checkpoint_other_fn,
)


def measure_observables_df(
circuit: circuits.Circuit,
observables: Iterable[ops.PauliString],
sampler: Union['cirq.Simulator', 'cirq.Sampler'],
stopping_criteria: Union[str, StoppingCriteria],
stopping_criteria_val: Optional[float] = None,
*,
readout_symmetrization: bool = True,
circuit_sweep: Optional['cirq.Sweepable'] = None,
grouper: Union[str, GROUPER_T] = group_settings_greedy,
readout_calibrations: Optional[BitstringAccumulator] = None,
checkpoint: bool = False,
checkpoint_fn: Optional[str] = None,
checkpoint_other_fn: Optional[str] = None,
):
"""Measure observables and return resulting data as a dataframe."""
accumulators = measure_observables(
circuit=circuit,
observables=observables,
sampler=sampler,
stopping_criteria=stopping_criteria,
stopping_criteria_val=stopping_criteria_val,
readout_symmetrization=readout_symmetrization,
circuit_sweep=circuit_sweep,
grouper=grouper,
readout_calibrations=readout_calibrations,
checkpoint=checkpoint,
checkpoint_fn=checkpoint_fn,
checkpoint_other_fn=checkpoint_other_fn,
)
df = pd.DataFrame(list(itertools.chain.from_iterable(acc.records for acc in accumulators)))
return df
42 changes: 42 additions & 0 deletions cirq-core/cirq/work/observable_measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
_check_meas_specs_still_todo,
StoppingCriteria,
_parse_checkpoint_options,
_parse_stopping_criteria,
measure_observables_df,
)


Expand Down Expand Up @@ -327,6 +329,9 @@ def test_meas_spec_still_todo_bad_spec():
bsa, meas_spec = _set_up_meas_specs_for_testing()

class BadStopping(StoppingCriteria):
def __init__(self):
pass

def more_repetitions(self, accumulator: BitstringAccumulator) -> int:
return -23

Expand Down Expand Up @@ -521,3 +526,40 @@ def test_measure_grouped_settings_read_checkpoint(tmpdir):
(result,) = results # one group
assert result.n_repetitions == 1_000
assert result.means() == [1.0]


def _test_parse_stopping_criteria():
with pytest.raises(ValueError, match='xxx'):
_ = _parse_stopping_criteria('repetitions')

rep = cw.RepetitionsStoppingCriteria(total_repetitions=1_000)
var = cw.VarianceStoppingCriteria(variance_bound=1e-3)
assert _parse_stopping_criteria('repetitions', 1_000) == rep
assert _parse_stopping_criteria('variance', 1e-3) == var
assert _parse_stopping_criteria(rep) == rep
assert _parse_stopping_criteria(var) == var


Q = cirq.NamedQubit('q')


@pytest.mark.parametrize(
['circuit', 'observable'],
[
(cirq.Circuit(cirq.X(Q) ** 0.2), cirq.Z(Q)),
(cirq.Circuit(cirq.X(Q) ** -0.5, cirq.Z(Q) ** 0.2), cirq.Y(Q)),
(cirq.Circuit(cirq.Y(Q) ** 0.5, cirq.Z(Q) ** 0.2), cirq.X(Q)),
],
)
def test_XYZ_point8(circuit, observable):
# each circuit, observable combination should result in the observable value of 0.8
df = measure_observables_df(
circuit,
[observable],
cirq.Simulator(seed=52),
stopping_criteria='variance',
stopping_criteria_val=1e-3 ** 2,
)
assert len(df) == 1, 'one obserbale'
mean = df.loc[0]['mean']
np.testing.assert_allclose(0.8, mean, atol=1e-2)