Skip to content

Commit f92fc65

Browse files
committed
Sampler API
1 parent dcb59f3 commit f92fc65

File tree

4 files changed

+244
-19
lines changed

4 files changed

+244
-19
lines changed

cirq-core/cirq/work/observable_measurement.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262

6363

6464
def _with_parameterized_layers(
65-
circuit: 'cirq.Circuit',
65+
circuit: 'cirq.AbstractCircuit',
6666
qubits: Sequence['cirq.Qid'],
6767
needs_init_layer: bool,
6868
) -> 'cirq.Circuit':
@@ -84,9 +84,9 @@ def _with_parameterized_layers(
8484
meas_mom = ops.Moment([ops.measure(*qubits, key='z')])
8585
if needs_init_layer:
8686
total_circuit = circuits.Circuit([x_beg_mom, y_beg_mom])
87-
total_circuit += circuit.copy()
87+
total_circuit += circuit.unfreeze()
8888
else:
89-
total_circuit = circuit.copy()
89+
total_circuit = circuit.unfreeze()
9090
total_circuit.append([x_end_mom, y_end_mom, meas_mom])
9191
return total_circuit
9292

@@ -445,7 +445,7 @@ def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting
445445

446446

447447
def measure_grouped_settings(
448-
circuit: 'cirq.Circuit',
448+
circuit: 'cirq.AbstractCircuit',
449449
grouped_settings: Dict[InitObsSetting, List[InitObsSetting]],
450450
sampler: 'cirq.Sampler',
451451
stopping_criteria: StoppingCriteria,
@@ -523,10 +523,7 @@ def measure_grouped_settings(
523523
for max_setting, circuit_params in itertools.product(
524524
grouped_settings.keys(), circuit_sweep.param_tuples()
525525
):
526-
# The type annotation for Param is just `Iterable`.
527-
# We make sure that it's truly a tuple.
528526
circuit_params = dict(circuit_params)
529-
530527
meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
531528
accumulator = BitstringAccumulator(
532529
meas_spec=meas_spec,
@@ -616,8 +613,8 @@ def _parse_grouper(grouper: Union[str, GROUPER_T] = group_settings_greedy) -> GR
616613

617614

618615
def _get_all_qubits(
619-
circuit: circuits.Circuit,
620-
observables: Iterable[ops.PauliString],
616+
circuit: 'cirq.AbstractCircuit',
617+
observables: Iterable['cirq.PauliString'],
621618
) -> List['cirq.Qid']:
622619
"""Helper function for `measure_observables` to get all qubits from a circuit and a
623620
collection of observables."""
@@ -629,8 +626,8 @@ def _get_all_qubits(
629626

630627

631628
def measure_observables(
632-
circuit: circuits.Circuit,
633-
observables: Iterable[ops.PauliString],
629+
circuit: 'cirq.AbstractCircuit',
630+
observables: Iterable['cirq.PauliString'],
634631
sampler: Union['cirq.Simulator', 'cirq.Sampler'],
635632
stopping_criteria: Union[str, StoppingCriteria],
636633
stopping_criteria_val: Optional[float] = None,
@@ -642,7 +639,7 @@ def measure_observables(
642639
checkpoint: bool = False,
643640
checkpoint_fn: Optional[str] = None,
644641
checkpoint_other_fn: Optional[str] = None,
645-
):
642+
) -> List[BitstringAccumulator]:
646643
"""Measure a collection of PauliString observables for a state prepared by a Circuit.
647644
648645
If you need more control over the process, please see `measure_grouped_settings` for a
@@ -708,8 +705,8 @@ def measure_observables(
708705

709706

710707
def measure_observables_df(
711-
circuit: circuits.Circuit,
712-
observables: Iterable[ops.PauliString],
708+
circuit: 'cirq.AbstractCircuit',
709+
observables: Iterable['cirq.PauliString'],
713710
sampler: Union['cirq.Simulator', 'cirq.Sampler'],
714711
stopping_criteria: Union[str, StoppingCriteria],
715712
stopping_criteria_val: Optional[float] = None,

cirq-core/cirq/work/observable_settings.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Union, Iterable, Dict, TYPE_CHECKING, Tuple
15+
from typing import Union, Iterable, Dict, TYPE_CHECKING, ItemsView
1616

1717
from cirq import ops, value
1818

@@ -135,7 +135,7 @@ def _fix_precision(val: float, precision) -> int:
135135
return int(val * precision)
136136

137137

138-
def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7):
138+
def _hashable_param(param_tuples: ItemsView[str, float], precision=1e7):
139139
"""Hash circuit parameters using fixed precision.
140140
141141
Circuit parameters can be floats but we also need to use them as

cirq-core/cirq/work/sampler.py

+108-3
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414
"""Abstract base class for things sampling quantum circuits."""
1515

16-
from typing import List, Optional, TYPE_CHECKING, Union
1716
import abc
17+
from typing import List, Optional, TYPE_CHECKING, Union, Dict, FrozenSet
1818

1919
import pandas as pd
20-
21-
from cirq import study
20+
from cirq import study, ops
21+
from cirq.work.observable_measurement import measure_observables, RepetitionsStoppingCriteria
22+
from cirq.work.observable_settings import _hashable_param
2223

2324
if TYPE_CHECKING:
2425
import cirq
@@ -253,3 +254,107 @@ def run_batch(
253254
self.run_sweep(circuit, params=params, repetitions=repetitions)
254255
for circuit, params, repetitions in zip(programs, params_list, repetitions)
255256
]
257+
258+
def sample_expectation_values(
259+
self,
260+
program: 'cirq.AbstractCircuit',
261+
observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']],
262+
*,
263+
num_samples: int,
264+
params: 'cirq.Sweepable' = None,
265+
permit_terminal_measurements: bool = False,
266+
) -> List[List[float]]:
267+
"""Calculates estimated expectation values from samples of a circuit.
268+
269+
This is a minimal implementation for measuring observables, and is best reserved for
270+
simple use cases. For more complex use cases, consider upgrading to
271+
`cirq.work.observable_measurement`. Additional features provided by that toolkit include:
272+
- Chunking of submissions to support more than (max_shots) from
273+
Quantum Engine
274+
- Checkpointing so you don't lose your work halfway through a job
275+
- Measuring to a variance tolerance rather than a pre-specified
276+
number of repetitions
277+
- Readout error symmetrization and mitigation
278+
279+
This method can be run on any device or simulator that supports circuit sampling. Compare
280+
with `simulate_expectation_values` in simulator.py, which is limited to simulators
281+
but provides exact results.
282+
283+
Args:
284+
program: The circuit which prepares a state from which we sample expectation values.
285+
observables: A list of observables for which to calculate expectation values.
286+
num_samples: The number of samples to take. Increasing this value increases the
287+
statistical accuracy of the estimate.
288+
params: Parameters to run with the program.
289+
permit_terminal_measurements: If the provided circuit ends in a measurement, this
290+
method will generate an error unless this is set to True. This is meant to
291+
prevent measurements from ruining expectation value calculations.
292+
293+
Returns:
294+
A list of expectation-value lists. The outer index determines the sweep, and the inner
295+
index determines the observable. For instance, results[1][3] would select the fourth
296+
observable measured in the second sweep.
297+
"""
298+
if num_samples <= 0:
299+
raise ValueError(
300+
f'Expectation values require at least one sample. Received: {num_samples}.'
301+
)
302+
if not observables:
303+
raise ValueError('At least one observable must be provided.')
304+
if not permit_terminal_measurements and program.are_any_measurements_terminal():
305+
raise ValueError(
306+
'Provided circuit has terminal measurements, which may '
307+
'skew expectation values. If this is intentional, set '
308+
'permit_terminal_measurements=True.'
309+
)
310+
311+
# Wrap input into a list of pauli sum
312+
pauli_sums: List['cirq.PauliSum'] = (
313+
[ops.PauliSum.wrap(o) for o in observables]
314+
if isinstance(observables, List)
315+
else [ops.PauliSum.wrap(observables)]
316+
)
317+
del observables
318+
319+
# Flatten Pauli Sum into one big list of Pauli String
320+
# Keep track of which Pauli Sum each one was from.
321+
flat_pstrings: List['cirq.PauliString'] = []
322+
pstring_to_psum_i: Dict['cirq.PauliString', int] = {}
323+
for psum_i, pauli_sum in enumerate(pauli_sums):
324+
for pstring in pauli_sum:
325+
flat_pstrings.append(pstring)
326+
pstring_to_psum_i[pstring] = psum_i
327+
328+
# Flatten Circuit Sweep into one big list of Params.
329+
# Keep track of their indices so we can map back.
330+
circuit_sweep = study.UnitSweep if params is None else study.to_sweep(params)
331+
all_circuit_params: List[Dict[str, float]] = [
332+
dict(circuit_params) for circuit_params in circuit_sweep.param_tuples()
333+
]
334+
circuit_param_to_sweep_i: Dict[FrozenSet[str, float], int] = {
335+
_hashable_param(param.items()): i for i, param in enumerate(all_circuit_params)
336+
}
337+
del params
338+
339+
accumulators = measure_observables(
340+
circuit=program,
341+
observables=flat_pstrings,
342+
sampler=self,
343+
stopping_criteria=RepetitionsStoppingCriteria(total_repetitions=num_samples),
344+
readout_symmetrization=False,
345+
circuit_sweep=circuit_sweep,
346+
checkpoint=False,
347+
)
348+
349+
# Results are ordered by how they're grouped. Since we want the (circuit_sweep, pauli_sum)
350+
# nesting structure, we place the measured values according to the back-mappings we set up
351+
# above. We also do the sum operation to aggregate multiple PauliString measured values
352+
# for a given PauliSum.
353+
results: List[List[float]] = [[0] * len(pauli_sums) for _ in range(len(all_circuit_params))]
354+
for acc in accumulators:
355+
for res in acc.results:
356+
param_i = circuit_param_to_sweep_i[_hashable_param(res.circuit_params.items())]
357+
psum_i = pstring_to_psum_i[res.setting.observable]
358+
results[param_i][psum_i] += res.mean
359+
360+
return results

cirq-core/cirq/work/sampler_test.py

+123
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Tests for cirq.Sampler."""
15+
from typing import List
16+
1517
import pytest
1618

1719
import numpy as np
@@ -198,3 +200,124 @@ def test_sampler_run_batch_bad_input_lengths():
198200
_ = sampler.run_batch(
199201
[circuit1, circuit2], params_list=[params1, params2], repetitions=[1, 2, 3]
200202
)
203+
204+
205+
def test_sampler_simple_sample_expectation_values():
206+
a = cirq.LineQubit(0)
207+
sampler = cirq.Simulator()
208+
circuit = cirq.Circuit(cirq.H(a))
209+
obs = cirq.X(a)
210+
results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000)
211+
212+
assert np.allclose(results, [[1]])
213+
214+
215+
def test_sampler_sample_expectation_values_calculation():
216+
class DeterministicImbalancedStateSampler(cirq.Sampler):
217+
"""A simple, deterministic mock sampler.
218+
Pretends to sample from a state vector with a 3:1 balance between the
219+
probabilities of the |0) and |1) state.
220+
"""
221+
222+
def run_sweep(
223+
self,
224+
program: 'cirq.Circuit',
225+
params: 'cirq.Sweepable',
226+
repetitions: int = 1,
227+
) -> List['cirq.Result']:
228+
results = np.zeros((repetitions, 1), dtype=bool)
229+
for idx in range(repetitions // 4):
230+
results[idx][0] = 1
231+
return [
232+
cirq.Result(params=pr, measurements={'z': results})
233+
for pr in cirq.study.to_resolvers(params)
234+
]
235+
236+
a = cirq.LineQubit(0)
237+
sampler = DeterministicImbalancedStateSampler()
238+
# This circuit is not actually sampled, but the mock sampler above gives
239+
# a reasonable approximation of it.
240+
circuit = cirq.Circuit(cirq.X(a) ** (1 / 3))
241+
obs = cirq.Z(a)
242+
results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000)
243+
244+
# (0.75 * 1) + (0.25 * -1) = 0.5
245+
assert np.allclose(results, [[0.5]])
246+
247+
248+
def test_sampler_sample_expectation_values_multi_param():
249+
a = cirq.LineQubit(0)
250+
t = sympy.Symbol('t')
251+
sampler = cirq.Simulator(seed=1)
252+
circuit = cirq.Circuit(cirq.X(a) ** t)
253+
obs = cirq.Z(a)
254+
results = sampler.sample_expectation_values(
255+
circuit, [obs], num_samples=5, params=cirq.Linspace('t', 0, 2, 3)
256+
)
257+
258+
assert np.allclose(results, [[1], [-1], [1]])
259+
260+
261+
def test_sampler_sample_expectation_values_multi_qubit():
262+
q = cirq.LineQubit.range(3)
263+
sampler = cirq.Simulator(seed=1)
264+
circuit = cirq.Circuit(cirq.X(q[0]), cirq.X(q[1]), cirq.X(q[2]))
265+
obs = cirq.Z(q[0]) + cirq.Z(q[1]) + cirq.Z(q[2])
266+
results = sampler.sample_expectation_values(circuit, [obs], num_samples=5)
267+
268+
assert np.allclose(results, [[-3]])
269+
270+
271+
def test_sampler_sample_expectation_values_composite():
272+
# Tests multi-{param,qubit} sampling together in one circuit.
273+
q = cirq.LineQubit.range(3)
274+
t = [sympy.Symbol(f't{x}') for x in range(3)]
275+
276+
sampler = cirq.Simulator(seed=1)
277+
circuit = cirq.Circuit(
278+
cirq.X(q[0]) ** t[0],
279+
cirq.X(q[1]) ** t[1],
280+
cirq.X(q[2]) ** t[2],
281+
)
282+
283+
obs = [cirq.Z(q[x]) for x in range(3)]
284+
# t0 is in the inner loop to make bit-ordering easier below.
285+
params = ([{'t0': t0, 't1': t1, 't2': t2} for t2 in [0, 1] for t1 in [0, 1] for t0 in [0, 1]],)
286+
results = sampler.sample_expectation_values(
287+
circuit,
288+
obs,
289+
num_samples=5,
290+
params=params,
291+
)
292+
print('\n'.join(str(r) for r in results))
293+
294+
assert len(results) == 8
295+
assert np.allclose(
296+
results,
297+
[
298+
[+1, +1, +1],
299+
[-1, +1, +1],
300+
[+1, -1, +1],
301+
[-1, -1, +1],
302+
[+1, +1, -1],
303+
[-1, +1, -1],
304+
[+1, -1, -1],
305+
[-1, -1, -1],
306+
],
307+
)
308+
309+
310+
def test_sampler_simple_sample_expectation_requirements():
311+
a = cirq.LineQubit(0)
312+
sampler = cirq.Simulator(seed=1)
313+
circuit = cirq.Circuit(cirq.H(a))
314+
obs = cirq.X(a)
315+
with pytest.raises(ValueError, match='at least one sample'):
316+
_ = sampler.sample_expectation_values(circuit, [obs], num_samples=0)
317+
318+
with pytest.raises(ValueError, match='At least one observable'):
319+
_ = sampler.sample_expectation_values(circuit, [], num_samples=1)
320+
321+
circuit.append(cirq.measure(a, key='out'))
322+
with pytest.raises(ValueError, match='permit_terminal_measurements'):
323+
_ = sampler.sample_expectation_values(circuit, [obs], num_samples=1)

0 commit comments

Comments
 (0)