Skip to content

[Obs] 4.5 - High-level API #4392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cirq-core/cirq/work/observable_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Dict, List, TYPE_CHECKING, cast
from typing import Iterable, Dict, List, TYPE_CHECKING, cast, Callable

from cirq import ops, value
from cirq.work.observable_settings import InitObsSetting, _max_weight_state, _max_weight_observable

if TYPE_CHECKING:
pass

GROUPER_T = Callable[[Iterable[InitObsSetting]], Dict[InitObsSetting, List[InitObsSetting]]]


def group_settings_greedy(
settings: Iterable[InitObsSetting],
Expand Down
261 changes: 213 additions & 48 deletions cirq-core/cirq/work/observable_measurement.py

Large diffs are not rendered by default.

39 changes: 27 additions & 12 deletions cirq-core/cirq/work/observable_measurement_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@

import dataclasses
import datetime
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, TYPE_CHECKING, Iterable, Any

import numpy as np

from cirq import protocols, ops
from cirq import ops, protocols
from cirq._compat import proper_repr
from cirq.work.observable_settings import (
InitObsSetting,
Expand Down Expand Up @@ -81,12 +80,12 @@ def _stats_from_measurements(
return obs_mean.item(), obs_err.item()


@protocols.json_serializable_dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class ObservableMeasuredResult:
"""The result of an observable measurement.

Please see `flatten_grouped_results` or `BitstringAccumulator.results` for information on how
to get these from `measure_observables` return values.
A list of these is returned by `measure_observables`, or see `flatten_grouped_results` for
transformation of `measure_grouped_settings` BitstringAccumulators into these objects.

This is a flattened form of the contents of a `BitstringAccumulator` which may group many
simultaneously-observable settings into one object. As such, `BitstringAccumulator` has more
Expand All @@ -110,7 +109,7 @@ class ObservableMeasuredResult:

def __repr__(self):
# I wish we could use the default dataclass __repr__ but
# we need to prefix our class name with `cirq.work.`A
# we need to prefix our class name with `cirq.work.`
return (
f'cirq.work.ObservableMeasuredResult('
f'setting={self.setting!r}, '
Expand All @@ -132,6 +131,25 @@ def observable(self):
def stddev(self):
return np.sqrt(self.variance)

def as_dict(self) -> Dict[str, Any]:
"""Return the contents of this class as a dictionary.

This makes records suitable for construction of a Pandas dataframe. The circuit parameters
are flattened into the top-level of this dictionary.
"""
record = dataclasses.asdict(self)
del record['circuit_params']
del record['setting']
record['init_state'] = self.init_state
record['observable'] = self.observable

circuit_param_dict = {f'param.{k}': v for k, v in self.circuit_params.items()}
record.update(**circuit_param_dict)
return record

def _json_dict_(self):
return protocols.dataclass_json_dict(self)


def _setting_to_z_observable(setting: InitObsSetting):
qubits = setting.observable.qubits
Expand Down Expand Up @@ -271,7 +289,7 @@ def n_repetitions(self):
return len(self.bitstrings)

@property
def results(self):
def results(self) -> Iterable[ObservableMeasuredResult]:
"""Yield individual setting results as `ObservableMeasuredResult`
objects."""
for setting in self._simul_settings:
Expand All @@ -291,10 +309,7 @@ def records(self):
after chaining these results with those from other BitstringAccumulators.
"""
for result in self.results:
record = dataclasses.asdict(result)
del record['circuit_params']
record.update(**self._meas_spec.circuit_params)
yield record
yield result.as_dict()

def _json_dict_(self):
from cirq.study.result import _pack_digits
Expand Down
30 changes: 29 additions & 1 deletion cirq-core/cirq/work/observable_measurement_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import datetime
import time

Expand Down Expand Up @@ -90,14 +91,41 @@ def test_observable_measured_result():
mean=0,
variance=5 ** 2,
repetitions=4,
circuit_params={},
circuit_params={'phi': 52},
)
assert omr.stddev == 5
assert omr.observable == cirq.Y(a) * cirq.Y(b)
assert omr.init_state == cirq.Z(a) * cirq.Z(b)

cirq.testing.assert_equivalent_repr(omr)

assert omr.as_dict() == {
'init_state': cirq.Z(a) * cirq.Z(b),
'observable': cirq.Y(a) * cirq.Y(b),
'mean': 0,
'variance': 25,
'repetitions': 4,
'param.phi': 52,
}
omr2 = dataclasses.replace(
omr,
circuit_params={
'phi': 52,
'observable': 3.14, # this would be a bad but legal parameter name
'param.phi': -1,
},
)
assert omr2.as_dict() == {
'init_state': cirq.Z(a) * cirq.Z(b),
'observable': cirq.Y(a) * cirq.Y(b),
'mean': 0,
'variance': 25,
'repetitions': 4,
'param.phi': 52,
'param.observable': 3.14,
'param.param.phi': -1,
}


@pytest.fixture()
def example_bsa() -> 'cw.BitstringAccumulator':
Expand Down
94 changes: 85 additions & 9 deletions cirq-core/cirq/work/observable_measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from typing import Iterable, Dict, List

import numpy as np
import pytest

import cirq
import cirq.work as cw
from cirq.work import _MeasurementSpec, BitstringAccumulator
from cirq.work import _MeasurementSpec, BitstringAccumulator, group_settings_greedy, InitObsSetting
from cirq.work.observable_measurement import (
_with_parameterized_layers,
_get_params_for_setting,
Expand All @@ -28,6 +29,11 @@
_check_meas_specs_still_todo,
StoppingCriteria,
_parse_checkpoint_options,
measure_observables_df,
CheckpointFileOptions,
VarianceStoppingCriteria,
measure_observables,
RepetitionsStoppingCriteria,
)


Expand Down Expand Up @@ -448,8 +454,7 @@ def test_measure_grouped_settings(with_circuit_sweep, checkpoint, tmpdir):
sampler=cirq.Simulator(),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
circuit_sweep=ss,
checkpoint=checkpoint,
checkpoint_fn=checkpoint_fn,
checkpoint=CheckpointFileOptions(checkpoint=checkpoint, checkpoint_fn=checkpoint_fn),
)
if with_circuit_sweep:
for result in results:
Expand Down Expand Up @@ -504,20 +509,91 @@ def test_measure_grouped_settings_read_checkpoint(tmpdir):
grouped_settings=grouped_settings,
sampler=cirq.Simulator(),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
checkpoint=True,
checkpoint_fn=f'{tmpdir}/obs.json',
checkpoint_other_fn=f'{tmpdir}/obs.json', # Same filename
checkpoint=CheckpointFileOptions(
checkpoint=True,
checkpoint_fn=f'{tmpdir}/obs.json',
checkpoint_other_fn=f'{tmpdir}/obs.json', # Same filename
),
)
_ = cw.measure_grouped_settings(
circuit=circuit,
grouped_settings=grouped_settings,
sampler=cirq.Simulator(),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
checkpoint=True,
checkpoint_fn=f'{tmpdir}/obs.json',
checkpoint_other_fn=f'{tmpdir}/obs.prev.json',
checkpoint=CheckpointFileOptions(
checkpoint=True,
checkpoint_fn=f'{tmpdir}/obs.json',
checkpoint_other_fn=f'{tmpdir}/obs.prev.json',
),
)
results = cirq.read_json(f'{tmpdir}/obs.json')
(result,) = results # one group
assert result.n_repetitions == 1_000
assert result.means() == [1.0]


Q = cirq.NamedQubit('q')


@pytest.mark.parametrize(
['circuit', 'observable'],
[
(cirq.Circuit(cirq.X(Q) ** 0.2), cirq.Z(Q)),
(cirq.Circuit(cirq.X(Q) ** -0.5, cirq.Z(Q) ** 0.2), cirq.Y(Q)),
(cirq.Circuit(cirq.Y(Q) ** 0.5, cirq.Z(Q) ** 0.2), cirq.X(Q)),
],
)
def test_XYZ_point8(circuit, observable):
# each circuit, observable combination should result in the observable value of 0.8
df = measure_observables_df(
circuit,
[observable],
cirq.Simulator(seed=52),
stopping_criteria=VarianceStoppingCriteria(1e-3 ** 2),
)
assert len(df) == 1, 'one observable'
mean = df.loc[0]['mean']
np.testing.assert_allclose(0.8, mean, atol=1e-2)


def _each_in_its_own_group_grouper(
settings: Iterable[InitObsSetting],
) -> Dict[InitObsSetting, List[InitObsSetting]]:
return {setting: [setting] for setting in settings}


@pytest.mark.parametrize(
'grouper', ['greedy', group_settings_greedy, _each_in_its_own_group_grouper]
)
def test_measure_observable_grouper(grouper):
circuit = cirq.Circuit(cirq.X(Q) ** 0.2)
observables = [
cirq.Z(Q),
cirq.Z(cirq.NamedQubit('q2')),
]
results = measure_observables(
circuit,
observables,
cirq.Simulator(seed=52),
stopping_criteria=RepetitionsStoppingCriteria(50_000),
grouper=grouper,
)
assert len(results) == 2, 'two observables'
np.testing.assert_allclose(0.8, results[0].mean, atol=0.05)
np.testing.assert_allclose(1, results[1].mean, atol=1e-9)


def test_measure_observable_bad_grouper():
circuit = cirq.Circuit(cirq.X(Q) ** 0.2)
observables = [
cirq.Z(Q),
cirq.Z(cirq.NamedQubit('q2')),
]
with pytest.raises(ValueError, match=r'Unknown grouping function'):
_ = measure_observables(
circuit,
observables,
cirq.Simulator(seed=52),
stopping_criteria=RepetitionsStoppingCriteria(50_000),
grouper='super fancy grouper',
)
24 changes: 14 additions & 10 deletions cirq-core/cirq/work/observable_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, Iterable, Dict, TYPE_CHECKING, Tuple
import dataclasses
from typing import Union, Iterable, Dict, TYPE_CHECKING, ItemsView, Tuple, FrozenSet

from cirq import ops, value
from cirq import ops, value, protocols

if TYPE_CHECKING:
import cirq
from cirq.value.product_state import _NamedOneQubitState

# Workaround for mypy custom dataclasses
from dataclasses import dataclass as json_serializable_dataclass
else:
from cirq.protocols import json_serializable_dataclass


@json_serializable_dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class InitObsSetting:
"""A pair of initial state and observable.

Expand Down Expand Up @@ -59,6 +55,9 @@ def __repr__(self):
f'observable={self.observable!r})'
)

def _json_dict_(self):
return protocols.dataclass_json_dict(self)


def _max_weight_observable(observables: Iterable[ops.PauliString]) -> Union[None, ops.PauliString]:
"""Create a new observable that is compatible with all input observables
Expand Down Expand Up @@ -135,7 +134,9 @@ def _fix_precision(val: float, precision) -> int:
return int(val * precision)


def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7):
def _hashable_param(
param_tuples: ItemsView[str, float], precision=1e7
) -> FrozenSet[Tuple[str, float]]:
"""Hash circuit parameters using fixed precision.

Circuit parameters can be floats but we also need to use them as
Expand All @@ -144,7 +145,7 @@ def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7):
return frozenset((k, _fix_precision(v, precision)) for k, v in param_tuples)


@json_serializable_dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class _MeasurementSpec:
"""An encapsulation of all the specifications for one run of a
quantum processor.
Expand All @@ -165,3 +166,6 @@ def __repr__(self):
f'cirq.work._MeasurementSpec(max_setting={self.max_setting!r}, '
f'circuit_params={self.circuit_params!r})'
)

def _json_dict_(self):
return protocols.dataclass_json_dict(self)
Loading