11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- from typing import cast , Dict , Hashable , Iterable , Iterator , List , Optional , Sequence , Set
14
+ from typing import cast , Dict , Hashable , Iterable , List , Optional , Sequence
15
15
from collections import OrderedDict
16
16
import dataclasses
17
17
import numpy as np
@@ -28,18 +28,15 @@ class MeasureInfo:
28
28
Attributes:
29
29
key: String identifying this measurement.
30
30
qubits: List of measured qubits, in order.
31
- slot: The location of this measurement within the program. For circuits,
32
- this is just the moment index; for schedules it is the start time
33
- of the measurement. This is used internally when scheduling on
34
- hardware so that we can combine measurements that occupy the same
35
- slot.
31
+ instances: The number of times a given key occurs in a circuit.
36
32
invert_mask: a list of booleans describing whether the results should
37
33
be flipped for each of the qubits in the qubits field.
34
+ tags: Tags applied to this measurement gate.
38
35
"""
39
36
40
37
key : str
41
38
qubits : List [cirq .GridQubit ]
42
- slot : int
39
+ instances : int
43
40
invert_mask : List [bool ]
44
41
tags : List [Hashable ]
45
42
@@ -54,34 +51,32 @@ def find_measurements(program: cirq.AbstractCircuit) -> List[MeasureInfo]:
54
51
NotImplementedError: If the program is of a type that is not recognized.
55
52
ValueError: If there is a duplicate measurement key.
56
53
"""
57
- measurements : List [MeasureInfo ] = []
58
- keys : Set [str ] = set ()
59
-
60
- if isinstance (program , cirq .AbstractCircuit ):
61
- measure_iter = _circuit_measurements (program )
62
- else :
54
+ if not isinstance (program , cirq .AbstractCircuit ):
63
55
raise NotImplementedError (f'Unrecognized program type: { type (program )} ' )
64
56
65
- for m in measure_iter :
66
- if m .key in keys :
67
- raise ValueError (f'Duplicate measurement key: { m .key } ' )
68
- keys .add (m .key )
69
- measurements .append (m )
70
-
71
- return measurements
72
-
73
-
74
- def _circuit_measurements (circuit : cirq .AbstractCircuit ) -> Iterator [MeasureInfo ]:
75
- for i , moment in enumerate (circuit ):
57
+ measurements : Dict [str , MeasureInfo ] = {}
58
+ for moment in program :
76
59
for op in moment :
77
60
if isinstance (op .gate , cirq .MeasurementGate ):
78
- yield MeasureInfo (
61
+ m = MeasureInfo (
79
62
key = op .gate .key ,
80
63
qubits = _grid_qubits (op ),
81
- slot = i ,
64
+ instances = 1 ,
82
65
invert_mask = list (op .gate .full_invert_mask ()),
83
66
tags = list (op .tags ),
84
67
)
68
+ prev_m = measurements .get (m .key )
69
+ if prev_m is None :
70
+ measurements [m .key ] = m
71
+ else :
72
+ if (
73
+ m .qubits != prev_m .qubits
74
+ or m .invert_mask != prev_m .invert_mask
75
+ or m .tags != prev_m .tags
76
+ ):
77
+ raise ValueError (f"Incompatible repeated keys: { m } , { prev_m } " )
78
+ prev_m .instances += 1
79
+ return list (measurements .values ())
85
80
86
81
87
82
def _grid_qubits (op : cirq .Operation ) -> List [cirq .GridQubit ]:
@@ -137,16 +132,18 @@ def results_to_proto(
137
132
sweep_result .repetitions = trial_result .repetitions
138
133
elif trial_result .repetitions != sweep_result .repetitions :
139
134
raise ValueError ('Different numbers of repetitions in one sweep.' )
135
+ reps = sweep_result .repetitions
140
136
pr = sweep_result .parameterized_results .add ()
141
137
pr .params .assignments .update (trial_result .params .param_dict )
142
138
for m in measurements :
143
139
mr = pr .measurement_results .add ()
144
140
mr .key = m .key
145
- m_data = trial_result .measurements [m .key ]
141
+ mr .instances = m .instances
142
+ m_data = trial_result .records [m .key ]
146
143
for i , qubit in enumerate (m .qubits ):
147
144
qmr = mr .qubit_measurement_results .add ()
148
145
qmr .qubit .id = v2 .qubit_to_proto_id (qubit )
149
- qmr .results = pack_bits (m_data [:, i ] )
146
+ qmr .results = pack_bits (m_data [:, :, i ]. reshape ( reps * m . instances ) )
150
147
return out
151
148
152
149
@@ -193,22 +190,22 @@ def _trial_sweep_from_proto(
193
190
194
191
trial_sweep : List [cirq .Result ] = []
195
192
for pr in msg .parameterized_results :
196
- m_data : Dict [str , np .ndarray ] = {}
193
+ records : Dict [str , np .ndarray ] = {}
197
194
for mr in pr .measurement_results :
195
+ instances = max (mr .instances , 1 )
198
196
qubit_results : OrderedDict [cirq .GridQubit , np .ndarray ] = OrderedDict ()
199
197
for qmr in mr .qubit_measurement_results :
200
198
qubit = v2 .grid_qubit_from_proto_id (qmr .qubit .id )
201
199
if qubit in qubit_results :
202
200
raise ValueError (f'Qubit already exists: { qubit } .' )
203
- qubit_results [qubit ] = unpack_bits (qmr .results , msg .repetitions )
201
+ qubit_results [qubit ] = unpack_bits (qmr .results , msg .repetitions * instances )
204
202
if measure_map :
205
203
ordered_results = [qubit_results [qubit ] for qubit in measure_map [mr .key ].qubits ]
206
204
else :
207
205
ordered_results = list (qubit_results .values ())
208
- m_data [mr .key ] = np .array (ordered_results ).transpose ()
206
+ shape = (msg .repetitions , instances , len (qubit_results ))
207
+ records [mr .key ] = np .array (ordered_results ).transpose ().reshape (shape )
209
208
trial_sweep .append (
210
- cirq .ResultDict (
211
- params = cirq .ParamResolver (dict (pr .params .assignments )), measurements = m_data
212
- )
209
+ cirq .ResultDict (params = cirq .ParamResolver (dict (pr .params .assignments )), records = records )
213
210
)
214
211
return trial_sweep
0 commit comments