Skip to content

Commit ec68a55

Browse files
authored
Add support for repeated keys in result protos (#5907)
Review: @wcourtney
1 parent f35a1e2 commit ec68a55

File tree

5 files changed

+206
-72
lines changed

5 files changed

+206
-72
lines changed

Diff for: cirq-google/cirq_google/api/v2/result.proto

+8-4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ message MeasurementResult {
4646
// The measurement key for the measurement.
4747
string key = 1;
4848

49+
// Number of instances of this key in each circuit repetition.
50+
int32 instances = 3;
51+
4952
// For each qubit that is measured, these are the measurement results.
5053
repeated QubitMeasurementResult qubit_measurement_results = 2;
5154
}
@@ -55,14 +58,15 @@ message QubitMeasurementResult {
5558
// Which qubit was measured.
5659
Qubit qubit = 1;
5760

58-
// These are the results of a measurement on a qubit. Measurement results are
59-
// repetitions number of bits, where the repetitions are define in the
60-
// sweep result message.
61+
// These are the results of a measurement on a qubit. The number of bits
62+
// measured is equal to repetitions * instances, where repetitions is defined
63+
// in the SweepResult message, and instances is defined in the MeasurementResult
64+
// message.
6165
//
6266
// The bytes in this field are constructed as follows:
6367
//
6468
// 1. The results of the measurements produce a list of bits ordered by
65-
// the round of repetition.
69+
// the round of repetition and instance within a round.
6670
//
6771
// 2. This list is broken up into blocks of 8 bits, with the final block
6872
// potentially not being a full 8 bits.

Diff for: cirq-google/cirq_google/api/v2/result_pb2.py

+17-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: cirq-google/cirq_google/api/v2/result_pb2.pyi

+4-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: cirq-google/cirq_google/api/v2/results.py

+31-34
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import cast, Dict, Hashable, Iterable, Iterator, List, Optional, Sequence, Set
14+
from typing import cast, Dict, Hashable, Iterable, List, Optional, Sequence
1515
from collections import OrderedDict
1616
import dataclasses
1717
import numpy as np
@@ -28,18 +28,15 @@ class MeasureInfo:
2828
Attributes:
2929
key: String identifying this measurement.
3030
qubits: List of measured qubits, in order.
31-
slot: The location of this measurement within the program. For circuits,
32-
this is just the moment index; for schedules it is the start time
33-
of the measurement. This is used internally when scheduling on
34-
hardware so that we can combine measurements that occupy the same
35-
slot.
31+
instances: The number of times a given key occurs in a circuit.
3632
invert_mask: a list of booleans describing whether the results should
3733
be flipped for each of the qubits in the qubits field.
34+
tags: Tags applied to this measurement gate.
3835
"""
3936

4037
key: str
4138
qubits: List[cirq.GridQubit]
42-
slot: int
39+
instances: int
4340
invert_mask: List[bool]
4441
tags: List[Hashable]
4542

@@ -54,34 +51,32 @@ def find_measurements(program: cirq.AbstractCircuit) -> List[MeasureInfo]:
5451
NotImplementedError: If the program is of a type that is not recognized.
5552
ValueError: If there is a duplicate measurement key.
5653
"""
57-
measurements: List[MeasureInfo] = []
58-
keys: Set[str] = set()
59-
60-
if isinstance(program, cirq.AbstractCircuit):
61-
measure_iter = _circuit_measurements(program)
62-
else:
54+
if not isinstance(program, cirq.AbstractCircuit):
6355
raise NotImplementedError(f'Unrecognized program type: {type(program)}')
6456

65-
for m in measure_iter:
66-
if m.key in keys:
67-
raise ValueError(f'Duplicate measurement key: {m.key}')
68-
keys.add(m.key)
69-
measurements.append(m)
70-
71-
return measurements
72-
73-
74-
def _circuit_measurements(circuit: cirq.AbstractCircuit) -> Iterator[MeasureInfo]:
75-
for i, moment in enumerate(circuit):
57+
measurements: Dict[str, MeasureInfo] = {}
58+
for moment in program:
7659
for op in moment:
7760
if isinstance(op.gate, cirq.MeasurementGate):
78-
yield MeasureInfo(
61+
m = MeasureInfo(
7962
key=op.gate.key,
8063
qubits=_grid_qubits(op),
81-
slot=i,
64+
instances=1,
8265
invert_mask=list(op.gate.full_invert_mask()),
8366
tags=list(op.tags),
8467
)
68+
prev_m = measurements.get(m.key)
69+
if prev_m is None:
70+
measurements[m.key] = m
71+
else:
72+
if (
73+
m.qubits != prev_m.qubits
74+
or m.invert_mask != prev_m.invert_mask
75+
or m.tags != prev_m.tags
76+
):
77+
raise ValueError(f"Incompatible repeated keys: {m}, {prev_m}")
78+
prev_m.instances += 1
79+
return list(measurements.values())
8580

8681

8782
def _grid_qubits(op: cirq.Operation) -> List[cirq.GridQubit]:
@@ -137,16 +132,18 @@ def results_to_proto(
137132
sweep_result.repetitions = trial_result.repetitions
138133
elif trial_result.repetitions != sweep_result.repetitions:
139134
raise ValueError('Different numbers of repetitions in one sweep.')
135+
reps = sweep_result.repetitions
140136
pr = sweep_result.parameterized_results.add()
141137
pr.params.assignments.update(trial_result.params.param_dict)
142138
for m in measurements:
143139
mr = pr.measurement_results.add()
144140
mr.key = m.key
145-
m_data = trial_result.measurements[m.key]
141+
mr.instances = m.instances
142+
m_data = trial_result.records[m.key]
146143
for i, qubit in enumerate(m.qubits):
147144
qmr = mr.qubit_measurement_results.add()
148145
qmr.qubit.id = v2.qubit_to_proto_id(qubit)
149-
qmr.results = pack_bits(m_data[:, i])
146+
qmr.results = pack_bits(m_data[:, :, i].reshape(reps * m.instances))
150147
return out
151148

152149

@@ -193,22 +190,22 @@ def _trial_sweep_from_proto(
193190

194191
trial_sweep: List[cirq.Result] = []
195192
for pr in msg.parameterized_results:
196-
m_data: Dict[str, np.ndarray] = {}
193+
records: Dict[str, np.ndarray] = {}
197194
for mr in pr.measurement_results:
195+
instances = max(mr.instances, 1)
198196
qubit_results: OrderedDict[cirq.GridQubit, np.ndarray] = OrderedDict()
199197
for qmr in mr.qubit_measurement_results:
200198
qubit = v2.grid_qubit_from_proto_id(qmr.qubit.id)
201199
if qubit in qubit_results:
202200
raise ValueError(f'Qubit already exists: {qubit}.')
203-
qubit_results[qubit] = unpack_bits(qmr.results, msg.repetitions)
201+
qubit_results[qubit] = unpack_bits(qmr.results, msg.repetitions * instances)
204202
if measure_map:
205203
ordered_results = [qubit_results[qubit] for qubit in measure_map[mr.key].qubits]
206204
else:
207205
ordered_results = list(qubit_results.values())
208-
m_data[mr.key] = np.array(ordered_results).transpose()
206+
shape = (msg.repetitions, instances, len(qubit_results))
207+
records[mr.key] = np.array(ordered_results).transpose().reshape(shape)
209208
trial_sweep.append(
210-
cirq.ResultDict(
211-
params=cirq.ParamResolver(dict(pr.params.assignments)), measurements=m_data
212-
)
209+
cirq.ResultDict(params=cirq.ParamResolver(dict(pr.params.assignments)), records=records)
213210
)
214211
return trial_sweep

0 commit comments

Comments
 (0)