|
13 | 13 | # limitations under the License.
|
14 | 14 | """Abstract base class for things sampling quantum circuits."""
|
15 | 15 |
|
16 |
| -from typing import List, Optional, TYPE_CHECKING, Union |
17 | 16 | import abc
|
| 17 | +from typing import List, Optional, TYPE_CHECKING, Union, Dict, FrozenSet |
18 | 18 |
|
19 | 19 | import pandas as pd
|
20 |
| - |
21 |
| -from cirq import study |
| 20 | +from cirq import study, ops |
| 21 | +from cirq.work.observable_measurement import measure_observables, RepetitionsStoppingCriteria |
| 22 | +from cirq.work.observable_settings import _hashable_param |
22 | 23 |
|
23 | 24 | if TYPE_CHECKING:
|
24 | 25 | import cirq
|
@@ -253,3 +254,107 @@ def run_batch(
|
253 | 254 | self.run_sweep(circuit, params=params, repetitions=repetitions)
|
254 | 255 | for circuit, params, repetitions in zip(programs, params_list, repetitions)
|
255 | 256 | ]
|
| 257 | + |
| 258 | + def sample_expectation_values( |
| 259 | + self, |
| 260 | + program: 'cirq.AbstractCircuit', |
| 261 | + observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], |
| 262 | + *, |
| 263 | + num_samples: int, |
| 264 | + params: 'cirq.Sweepable' = None, |
| 265 | + permit_terminal_measurements: bool = False, |
| 266 | + ) -> List[List[float]]: |
| 267 | + """Calculates estimated expectation values from samples of a circuit. |
| 268 | +
|
| 269 | + This is a minimal implementation for measuring observables, and is best reserved for |
| 270 | + simple use cases. For more complex use cases, consider upgrading to |
| 271 | + `cirq.work.observable_measurement`. Additional features provided by that toolkit include: |
| 272 | + - Chunking of submissions to support more than (max_shots) from |
| 273 | + Quantum Engine |
| 274 | + - Checkpointing so you don't lose your work halfway through a job |
| 275 | + - Measuring to a variance tolerance rather than a pre-specified |
| 276 | + number of repetitions |
| 277 | + - Readout error symmetrization and mitigation |
| 278 | +
|
| 279 | + This method can be run on any device or simulator that supports circuit sampling. Compare |
| 280 | + with `simulate_expectation_values` in simulator.py, which is limited to simulators |
| 281 | + but provides exact results. |
| 282 | +
|
| 283 | + Args: |
| 284 | + program: The circuit which prepares a state from which we sample expectation values. |
| 285 | + observables: A list of observables for which to calculate expectation values. |
| 286 | + num_samples: The number of samples to take. Increasing this value increases the |
| 287 | + statistical accuracy of the estimate. |
| 288 | + params: Parameters to run with the program. |
| 289 | + permit_terminal_measurements: If the provided circuit ends in a measurement, this |
| 290 | + method will generate an error unless this is set to True. This is meant to |
| 291 | + prevent measurements from ruining expectation value calculations. |
| 292 | +
|
| 293 | + Returns: |
| 294 | + A list of expectation-value lists. The outer index determines the sweep, and the inner |
| 295 | + index determines the observable. For instance, results[1][3] would select the fourth |
| 296 | + observable measured in the second sweep. |
| 297 | + """ |
| 298 | + if num_samples <= 0: |
| 299 | + raise ValueError( |
| 300 | + f'Expectation values require at least one sample. Received: {num_samples}.' |
| 301 | + ) |
| 302 | + if not observables: |
| 303 | + raise ValueError('At least one observable must be provided.') |
| 304 | + if not permit_terminal_measurements and program.are_any_measurements_terminal(): |
| 305 | + raise ValueError( |
| 306 | + 'Provided circuit has terminal measurements, which may ' |
| 307 | + 'skew expectation values. If this is intentional, set ' |
| 308 | + 'permit_terminal_measurements=True.' |
| 309 | + ) |
| 310 | + |
| 311 | + # Wrap input into a list of pauli sum |
| 312 | + pauli_sums: List['cirq.PauliSum'] = ( |
| 313 | + [ops.PauliSum.wrap(o) for o in observables] |
| 314 | + if isinstance(observables, List) |
| 315 | + else [ops.PauliSum.wrap(observables)] |
| 316 | + ) |
| 317 | + del observables |
| 318 | + |
| 319 | + # Flatten Pauli Sum into one big list of Pauli String |
| 320 | + # Keep track of which Pauli Sum each one was from. |
| 321 | + flat_pstrings: List['cirq.PauliString'] = [] |
| 322 | + pstring_to_psum_i: Dict['cirq.PauliString', int] = {} |
| 323 | + for psum_i, pauli_sum in enumerate(pauli_sums): |
| 324 | + for pstring in pauli_sum: |
| 325 | + flat_pstrings.append(pstring) |
| 326 | + pstring_to_psum_i[pstring] = psum_i |
| 327 | + |
| 328 | + # Flatten Circuit Sweep into one big list of Params. |
| 329 | + # Keep track of their indices so we can map back. |
| 330 | + circuit_sweep = study.UnitSweep if params is None else study.to_sweep(params) |
| 331 | + all_circuit_params: List[Dict[str, float]] = [ |
| 332 | + dict(circuit_params) for circuit_params in circuit_sweep.param_tuples() |
| 333 | + ] |
| 334 | + circuit_param_to_sweep_i: Dict[FrozenSet[str, float], int] = { |
| 335 | + _hashable_param(param.items()): i for i, param in enumerate(all_circuit_params) |
| 336 | + } |
| 337 | + del params |
| 338 | + |
| 339 | + accumulators = measure_observables( |
| 340 | + circuit=program, |
| 341 | + observables=flat_pstrings, |
| 342 | + sampler=self, |
| 343 | + stopping_criteria=RepetitionsStoppingCriteria(total_repetitions=num_samples), |
| 344 | + readout_symmetrization=False, |
| 345 | + circuit_sweep=circuit_sweep, |
| 346 | + checkpoint=False, |
| 347 | + ) |
| 348 | + |
| 349 | + # Results are ordered by how they're grouped. Since we want the (circuit_sweep, pauli_sum) |
| 350 | + # nesting structure, we place the measured values according to the back-mappings we set up |
| 351 | + # above. We also do the sum operation to aggregate multiple PauliString measured values |
| 352 | + # for a given PauliSum. |
| 353 | + results: List[List[float]] = [[0] * len(pauli_sums) for _ in range(len(all_circuit_params))] |
| 354 | + for acc in accumulators: |
| 355 | + for res in acc.results: |
| 356 | + param_i = circuit_param_to_sweep_i[_hashable_param(res.circuit_params.items())] |
| 357 | + psum_i = pstring_to_psum_i[res.setting.observable] |
| 358 | + results[param_i][psum_i] += res.mean |
| 359 | + |
| 360 | + return results |
0 commit comments