-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add Sampler API for expectation values #3910
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,17 +13,59 @@ | |
# limitations under the License. | ||
"""Abstract base class for things sampling quantum circuits.""" | ||
|
||
from typing import List, Optional, TYPE_CHECKING, Union | ||
from typing import Dict, List, Optional, TYPE_CHECKING, Union | ||
import abc | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from cirq import study | ||
from cirq import ops, study | ||
from cirq.work import group_settings_greedy, observables_to_settings | ||
from cirq.work.observable_measurement import ( | ||
_get_params_for_setting, | ||
_with_parameterized_layers, | ||
) | ||
from cirq.work.observable_measurement_data import ( | ||
_stats_from_measurements, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
import cirq | ||
|
||
|
||
def get_samples_mask_for_parameterization( | ||
samples: pd.DataFrame, | ||
param_resolver: 'cirq.ParamResolverOrSimilarType', | ||
mask: Optional[pd.Series] = None, | ||
) -> pd.Series: | ||
"""Generates a 'mask' for a given parameterization of a sample. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is a mask for a parameterization? I think this should be explained. |
||
|
||
Results for the given parameterization can be extracted with: | ||
|
||
samples.loc[series, measure_key] | ||
|
||
where 'series' is the output of this function and 'measure_key' is the | ||
measurement key to extract results for. | ||
|
||
Args: | ||
samples: a pandas DataFrame containing sampling results and their | ||
associated parameterizations. | ||
param_resolver: the specific parameterization of 'samples' to create a | ||
mask for. | ||
mask: a base mask to generate from. Any indices excluded by this mask | ||
will be excluded by the final mask. If this is left unspecified, | ||
it defaults to a blank mask (include-all). | ||
Using a previous result of this method as the 'mask' for | ||
subsequent calls allows construction of a complete mask from | ||
multiple partial parameterizations of the sampling results. | ||
""" | ||
series = mask if mask is not None else np.ones(len(samples), dtype=bool) | ||
pr = study.ParamResolver(param_resolver) | ||
for k in pr: | ||
series = series & (samples[k] == pr.value_of(k)) | ||
return series | ||
|
||
|
||
class Sampler(metaclass=abc.ABCMeta): | ||
"""Something capable of sampling quantum circuits. Simulator or hardware.""" | ||
|
||
|
@@ -133,6 +175,133 @@ def sample( | |
|
||
return pd.concat(results) | ||
|
||
def sample_expectation_values( | ||
self, | ||
program: 'cirq.Circuit', | ||
observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], | ||
*, | ||
num_samples: int, | ||
params: 'cirq.Sweepable' = None, | ||
permit_terminal_measurements: bool = False, | ||
) -> List[List[float]]: | ||
"""Calculates estimated expectation values from samples of a circuit. | ||
|
||
This is a minimal implementation for measuring observables, and is best | ||
reserved for simple use cases. For more complex use cases, consider | ||
upgrading to `cirq/work/observable_measurement.py`. Additional features | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change to the user visible name instead of the source code name, such as cirq.work.measure_grouped_settings |
||
provided by that toolkit include: | ||
- Chunking of submissions to support more than (max_shots) from | ||
Quantum Engine | ||
- Checkpointing so you don't lose your work halfway through a job | ||
- Measuring to a variance tolerance rather than a pre-specified | ||
number of repetitions | ||
- Readout error symmetrization and mitigation | ||
|
||
This method can be run on any device or simulator that supports circuit | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider reordering or clarifying this. It's unclear which "this method" you are referring to (observable measurement or sample_expectation_values). Maybe consider putting the pointing to observable measurement at the end. I think that makes more sense to explain what this function does before pointing out alternatives. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clarified which method this refers to. I prefer to keep this paragraph at the top of the docstring, since all but the simplest use cases can benefit from using the Observable Measurement toolkit instead. To that end, there's an argument to be made for dropping this API altogether: it doesn't provide functionality outside what will be possible with the Observable Measurement toolkit, and potentially distracts users from the "correct" tools for getting expectation values. As it stands, it primarily serves as a simple reference implementation, as pointed out above. |
||
sampling. Compare with `simulate_expectation_values` in simulator.py, | ||
which is limited to simulators but provides exact results. | ||
|
||
Args: | ||
program: The circuit to simulate. | ||
observables: A list of observables for which to calculate | ||
expectation values. | ||
num_samples: The number of samples to take. Increasing this value | ||
increases the accuracy of the estimate. | ||
params: Parameters to run with the program. | ||
permit_terminal_measurements: If the provided circuit ends in a | ||
measurement, this method will generate an error unless this | ||
is set to True. This is meant to prevent measurements from | ||
ruining expectation value calculations. | ||
|
||
Returns: | ||
A list of expectation-value lists. The outer index determines the | ||
sweep, and the inner index determines the observable. For instance, | ||
results[1][3] would select the fourth observable measured in the | ||
second sweep. | ||
""" | ||
if num_samples <= 0: | ||
raise ValueError( | ||
f'Expectation values require at least one sample. Received: {num_samples}.' | ||
) | ||
if not observables: | ||
raise ValueError('At least one observable must be provided.') | ||
if not permit_terminal_measurements and program.are_any_measurements_terminal(): | ||
raise ValueError( | ||
'Provided circuit has terminal measurements, which may ' | ||
'skew expectation values. If this is intentional, set ' | ||
'permit_terminal_measurements=True.' | ||
) | ||
qubits = ops.QubitOrder.DEFAULT.order_for(program.all_qubits()) | ||
qmap = {q: i for i, q in enumerate(qubits)} | ||
num_qubits = len(qubits) | ||
psums = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optional: consider avoiding abbreviations and use pauli_sums to be more clear. Optional since this does make the pauli_string_aggregate variable name kind of long. |
||
[ops.PauliSum.wrap(o) for o in observables] | ||
if isinstance(observables, List) | ||
else [ops.PauliSum.wrap(observables)] | ||
) | ||
|
||
obs_pstrings: List[ops.PauliString] = [] | ||
for psum in psums: | ||
# This dict will be the PAULI_STRING_LIKE for a new PauliString | ||
# representing this PauliSum. The PauliString is necessary for | ||
# compatibility with `observables_to_settings`. | ||
pstring_aggregate: Dict['cirq.Qid', 'cirq.PAULI_GATE_LIKE'] = {} | ||
for pstring in psum: | ||
pstring_aggregate.update(pstring.items()) | ||
obs_pstrings.append(ops.PauliString(pstring_aggregate)) | ||
|
||
# List of observable settings with the same indexing as 'observables'. | ||
obs_settings = list(observables_to_settings(obs_pstrings, qubits)) | ||
# Map of setting-groups to the settings they cover. | ||
sampling_groups = group_settings_greedy(obs_settings) | ||
sampling_params = { | ||
max_setting: _get_params_for_setting( | ||
setting=max_setting, | ||
flips=[False] * num_qubits, | ||
qubits=qubits, | ||
needs_init_layer=False, | ||
) | ||
for max_setting in sampling_groups | ||
} | ||
|
||
# Parameterized circuit for observable measurement. | ||
mod_program = _with_parameterized_layers(program, qubits, needs_init_layer=False) | ||
|
||
input_params = list(params) if params else {} | ||
num_input_param_values = len(list(study.to_resolvers(input_params))) | ||
# Pairing of input sweeps with each required observable rotation. | ||
sweeps = study.to_sweeps(input_params) | ||
mod_sweep = study.ListSweep(sampling_params.values()) | ||
all_sweeps = [study.Product(sweep, mod_sweep) for sweep in sweeps] | ||
|
||
# Results sampled from the modified circuit. Parameterization ensures | ||
# that all 'z' results map directly to observables. | ||
samples = self.sample(mod_program, repetitions=num_samples, params=all_sweeps) | ||
results: List[List[float]] = [[0] * len(psums) for _ in range(num_input_param_values)] | ||
|
||
for max_setting, grouped_settings in sampling_groups.items(): | ||
# Filter 'samples' down to results matching each 'max_setting'. | ||
series = get_samples_mask_for_parameterization(samples, sampling_params[max_setting]) | ||
|
||
for sweep_idx, pr in enumerate(study.to_resolvers(input_params)): | ||
# Filter 'series' down to results matching each sweep. | ||
subseries = get_samples_mask_for_parameterization(samples, pr, series) | ||
bitstrings = np.asarray( | ||
[ | ||
list(np.binary_repr(intval).zfill(num_qubits)) | ||
for intval in samples.loc[subseries, 'z'] | ||
], | ||
dtype=np.uint8, | ||
) | ||
for setting in grouped_settings: | ||
obs_idx = obs_settings.index(setting) | ||
results[sweep_idx][obs_idx] = sum( | ||
_stats_from_measurements(bitstrings, qmap, pstr, atol=1e-8)[0] | ||
for pstr in psums[obs_idx] | ||
) | ||
|
||
return results | ||
|
||
@abc.abstractmethod | ||
def run_sweep( | ||
self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is bad practice to import private functions from other modules. It breaks naming assumptions. We should either change the name to indicate that it is no longer private or remove the import (or create a new method that is public in those modules).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, although this behavior is consistent with the observable measurements imports (all files are under
cirq/work
):Cirq/cirq/work/observable_measurement_data.py
Lines 23 to 29 in 845836a
Should I send a "prequel" PR that fixes this before moving forward with this one?