Skip to content

Add Sampler API for expectation values #3910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 171 additions & 2 deletions cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,59 @@
# limitations under the License.
"""Abstract base class for things sampling quantum circuits."""

from typing import List, Optional, TYPE_CHECKING, Union
from typing import Dict, List, Optional, TYPE_CHECKING, Union
import abc

import numpy as np
import pandas as pd

from cirq import study
from cirq import ops, study
from cirq.work import group_settings_greedy, observables_to_settings
from cirq.work.observable_measurement import (
_get_params_for_setting,
_with_parameterized_layers,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is bad practice to import private functions from other modules. It breaks naming assumptions. We should either change the name to indicate that it is no longer private or remove the import (or create a new method that is public in those modules).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, although this behavior is consistent with the observable measurements imports (all files are under cirq/work):

from cirq.work.observable_settings import (
InitObsSetting,
_max_weight_observable,
_max_weight_state,
_MeasurementSpec,
zeros_state,
)

Should I send a "prequel" PR that fixes this before moving forward with this one?

)
from cirq.work.observable_measurement_data import (
_stats_from_measurements,
)

if TYPE_CHECKING:
import cirq


def get_samples_mask_for_parameterization(
samples: pd.DataFrame,
param_resolver: 'cirq.ParamResolverOrSimilarType',
mask: Optional[pd.Series] = None,
) -> pd.Series:
"""Generates a 'mask' for a given parameterization of a sample.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is a mask for a parameterization? I think this should be explained.


Results for the given parameterization can be extracted with:

samples.loc[series, measure_key]

where 'series' is the output of this function and 'measure_key' is the
measurement key to extract results for.

Args:
samples: a pandas DataFrame containing sampling results and their
associated parameterizations.
param_resolver: the specific parameterization of 'samples' to create a
mask for.
mask: a base mask to generate from. Any indices excluded by this mask
will be excluded by the final mask. If this is left unspecified,
it defaults to a blank mask (include-all).
Using a previous result of this method as the 'mask' for
subsequent calls allows construction of a complete mask from
multiple partial parameterizations of the sampling results.
"""
series = mask if mask is not None else np.ones(len(samples), dtype=bool)
pr = study.ParamResolver(param_resolver)
for k in pr:
series = series & (samples[k] == pr.value_of(k))
return series


class Sampler(metaclass=abc.ABCMeta):
"""Something capable of sampling quantum circuits. Simulator or hardware."""

Expand Down Expand Up @@ -133,6 +175,133 @@ def sample(

return pd.concat(results)

def sample_expectation_values(
self,
program: 'cirq.Circuit',
observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']],
*,
num_samples: int,
params: 'cirq.Sweepable' = None,
permit_terminal_measurements: bool = False,
) -> List[List[float]]:
"""Calculates estimated expectation values from samples of a circuit.

This is a minimal implementation for measuring observables, and is best
reserved for simple use cases. For more complex use cases, consider
upgrading to `cirq/work/observable_measurement.py`. Additional features
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change to the user visible name instead of the source code name, such as cirq.work.measure_grouped_settings

provided by that toolkit include:
- Chunking of submissions to support more than (max_shots) from
Quantum Engine
- Checkpointing so you don't lose your work halfway through a job
- Measuring to a variance tolerance rather than a pre-specified
number of repetitions
- Readout error symmetrization and mitigation

This method can be run on any device or simulator that supports circuit
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider reordering or clarifying this. It's unclear which "this method" you are referring to (observable measurement or sample_expectation_values). Maybe consider putting the pointing to observable measurement at the end. I think that makes more sense to explain what this function does before pointing out alternatives.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarified which method this refers to. I prefer to keep this paragraph at the top of the docstring, since all but the simplest use cases can benefit from using the Observable Measurement toolkit instead.

To that end, there's an argument to be made for dropping this API altogether: it doesn't provide functionality outside what will be possible with the Observable Measurement toolkit, and potentially distracts users from the "correct" tools for getting expectation values. As it stands, it primarily serves as a simple reference implementation, as pointed out above.

sampling. Compare with `simulate_expectation_values` in simulator.py,
which is limited to simulators but provides exact results.

Args:
program: The circuit to simulate.
observables: A list of observables for which to calculate
expectation values.
num_samples: The number of samples to take. Increasing this value
increases the accuracy of the estimate.
params: Parameters to run with the program.
permit_terminal_measurements: If the provided circuit ends in a
measurement, this method will generate an error unless this
is set to True. This is meant to prevent measurements from
ruining expectation value calculations.

Returns:
A list of expectation-value lists. The outer index determines the
sweep, and the inner index determines the observable. For instance,
results[1][3] would select the fourth observable measured in the
second sweep.
"""
if num_samples <= 0:
raise ValueError(
f'Expectation values require at least one sample. Received: {num_samples}.'
)
if not observables:
raise ValueError('At least one observable must be provided.')
if not permit_terminal_measurements and program.are_any_measurements_terminal():
raise ValueError(
'Provided circuit has terminal measurements, which may '
'skew expectation values. If this is intentional, set '
'permit_terminal_measurements=True.'
)
qubits = ops.QubitOrder.DEFAULT.order_for(program.all_qubits())
qmap = {q: i for i, q in enumerate(qubits)}
num_qubits = len(qubits)
psums = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: consider avoiding abbreviations and use pauli_sums to be more clear. Optional since this does make the pauli_string_aggregate variable name kind of long.

[ops.PauliSum.wrap(o) for o in observables]
if isinstance(observables, List)
else [ops.PauliSum.wrap(observables)]
)

obs_pstrings: List[ops.PauliString] = []
for psum in psums:
# This dict will be the PAULI_STRING_LIKE for a new PauliString
# representing this PauliSum. The PauliString is necessary for
# compatibility with `observables_to_settings`.
pstring_aggregate: Dict['cirq.Qid', 'cirq.PAULI_GATE_LIKE'] = {}
for pstring in psum:
pstring_aggregate.update(pstring.items())
obs_pstrings.append(ops.PauliString(pstring_aggregate))

# List of observable settings with the same indexing as 'observables'.
obs_settings = list(observables_to_settings(obs_pstrings, qubits))
# Map of setting-groups to the settings they cover.
sampling_groups = group_settings_greedy(obs_settings)
sampling_params = {
max_setting: _get_params_for_setting(
setting=max_setting,
flips=[False] * num_qubits,
qubits=qubits,
needs_init_layer=False,
)
for max_setting in sampling_groups
}

# Parameterized circuit for observable measurement.
mod_program = _with_parameterized_layers(program, qubits, needs_init_layer=False)

input_params = list(params) if params else {}
num_input_param_values = len(list(study.to_resolvers(input_params)))
# Pairing of input sweeps with each required observable rotation.
sweeps = study.to_sweeps(input_params)
mod_sweep = study.ListSweep(sampling_params.values())
all_sweeps = [study.Product(sweep, mod_sweep) for sweep in sweeps]

# Results sampled from the modified circuit. Parameterization ensures
# that all 'z' results map directly to observables.
samples = self.sample(mod_program, repetitions=num_samples, params=all_sweeps)
results: List[List[float]] = [[0] * len(psums) for _ in range(num_input_param_values)]

for max_setting, grouped_settings in sampling_groups.items():
# Filter 'samples' down to results matching each 'max_setting'.
series = get_samples_mask_for_parameterization(samples, sampling_params[max_setting])

for sweep_idx, pr in enumerate(study.to_resolvers(input_params)):
# Filter 'series' down to results matching each sweep.
subseries = get_samples_mask_for_parameterization(samples, pr, series)
bitstrings = np.asarray(
[
list(np.binary_repr(intval).zfill(num_qubits))
for intval in samples.loc[subseries, 'z']
],
dtype=np.uint8,
)
for setting in grouped_settings:
obs_idx = obs_settings.index(setting)
results[sweep_idx][obs_idx] = sum(
_stats_from_measurements(bitstrings, qmap, pstr, atol=1e-8)[0]
for pstr in psums[obs_idx]
)

return results

@abc.abstractmethod
def run_sweep(
self,
Expand Down
123 changes: 123 additions & 0 deletions cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cirq.Sampler."""
from typing import List
import pytest

import numpy as np
Expand Down Expand Up @@ -132,6 +133,128 @@ def test_sampler_sample_inconsistent_keys():
)


def test_sampler_simple_sample_expectation_values():
a = cirq.LineQubit(0)
sampler = cirq.Simulator()
circuit = cirq.Circuit(cirq.H(a))
obs = cirq.X(a)
results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000)

assert np.allclose(results, [[1]])


def test_sampler_sample_expectation_values_calculation():
class DeterministicImbalancedStateSampler(cirq.Sampler):
"""A simple, deterministic mock sampler.

Pretends to sample from a state vector with a 3:1 balance between the
probabilities of the |0) and |1) state.
"""

def run_sweep(
self,
program: 'cirq.Circuit',
params: 'cirq.Sweepable',
repetitions: int = 1,
) -> List['cirq.Result']:
results = np.zeros((repetitions, 1), dtype=bool)
for idx in range(repetitions // 4):
results[idx][0] = 1
return [
cirq.Result(params=pr, measurements={'z': results})
for pr in cirq.study.to_resolvers(params)
]

a = cirq.LineQubit(0)
sampler = DeterministicImbalancedStateSampler()
# This circuit is not actually sampled, but the mock sampler above gives
# a reasonable approximation of it.
circuit = cirq.Circuit(cirq.X(a) ** (1 / 3))
obs = cirq.Z(a)
results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000)

# (0.75 * 1) + (0.25 * -1) = 0.5
assert np.allclose(results, [[0.5]])


def test_sampler_sample_expectation_values_multi_param():
a = cirq.LineQubit(0)
t = sympy.Symbol('t')
sampler = cirq.Simulator(seed=1)
circuit = cirq.Circuit(cirq.X(a) ** t)
obs = cirq.Z(a)
results = sampler.sample_expectation_values(
circuit, [obs], num_samples=5, params=cirq.Linspace('t', 0, 2, 3)
)

assert np.allclose(results, [[1], [-1], [1]])


def test_sampler_sample_expectation_values_multi_qubit():
q = cirq.LineQubit.range(3)
sampler = cirq.Simulator(seed=1)
circuit = cirq.Circuit(cirq.X(q[0]), cirq.X(q[1]), cirq.X(q[2]))
obs = cirq.Z(q[0]) + cirq.Z(q[1]) + cirq.Z(q[2])
results = sampler.sample_expectation_values(circuit, [obs], num_samples=5)

assert np.allclose(results, [[-3]])


def test_sampler_sample_expectation_values_composite():
# Tests multi-{param,qubit} sampling together in one circuit.
q = cirq.LineQubit.range(3)
t = [sympy.Symbol(f't{x}') for x in range(3)]

sampler = cirq.Simulator(seed=1)
circuit = cirq.Circuit(
cirq.X(q[0]) ** t[0],
cirq.X(q[1]) ** t[1],
cirq.X(q[2]) ** t[2],
)

obs = [cirq.Z(q[x]) for x in range(3)]
# t0 is in the inner loop to make bit-ordering easier below.
params = ([{'t0': t0, 't1': t1, 't2': t2} for t2 in [0, 1] for t1 in [0, 1] for t0 in [0, 1]],)
results = sampler.sample_expectation_values(
circuit,
obs,
num_samples=5,
params=params,
)
print('\n'.join(str(r) for r in results))

assert len(results) == 8
assert np.allclose(
results,
[
[+1, +1, +1],
[-1, +1, +1],
[+1, -1, +1],
[-1, -1, +1],
[+1, +1, -1],
[-1, +1, -1],
[+1, -1, -1],
[-1, -1, -1],
],
)


def test_sampler_simple_sample_expectation_requirements():
a = cirq.LineQubit(0)
sampler = cirq.Simulator(seed=1)
circuit = cirq.Circuit(cirq.H(a))
obs = cirq.X(a)
with pytest.raises(ValueError, match='at least one sample'):
_ = sampler.sample_expectation_values(circuit, [obs], num_samples=0)

with pytest.raises(ValueError, match='At least one observable'):
_ = sampler.sample_expectation_values(circuit, [], num_samples=1)

circuit.append(cirq.measure(a, key='out'))
with pytest.raises(ValueError, match='permit_terminal_measurements'):
_ = sampler.sample_expectation_values(circuit, [obs], num_samples=1)


@pytest.mark.asyncio
async def test_sampler_async_not_run_inline():
ran = False
Expand Down