13
13
# limitations under the License.
14
14
15
15
import itertools
16
- from typing import (
17
- Any ,
18
- Dict ,
19
- Iterable ,
20
- List ,
21
- Mapping ,
22
- Optional ,
23
- Sequence ,
24
- Tuple ,
25
- TYPE_CHECKING ,
26
- Union ,
27
- )
16
+ from collections import defaultdict
17
+ from typing import Any , Dict , Iterable , List , Optional , Sequence , Tuple , TYPE_CHECKING , Union
28
18
29
19
import numpy as np
30
20
@@ -43,30 +33,32 @@ class _MeasurementQid(ops.Qid):
43
33
Exactly one qubit will be created per qubit in the measurement gate.
44
34
"""
45
35
46
- def __init__ (self , key : Union [str , 'cirq.MeasurementKey' ], qid : 'cirq.Qid' ):
36
+ def __init__ (self , key : Union [str , 'cirq.MeasurementKey' ], qid : 'cirq.Qid' , index : int = 0 ):
47
37
"""Initializes the qubit.
48
38
49
39
Args:
50
40
key: The key of the measurement gate being deferred.
51
41
qid: One qubit that is being measured. Each deferred measurement
52
42
should create one new _MeasurementQid per qubit being measured
53
43
by that gate.
44
+ index: For repeated measurement keys, this represents the index of that measurement.
54
45
"""
55
46
self ._key = value .MeasurementKey .parse_serialized (key ) if isinstance (key , str ) else key
56
47
self ._qid = qid
48
+ self ._index = index
57
49
58
50
@property
59
51
def dimension (self ) -> int :
60
52
return self ._qid .dimension
61
53
62
54
def _comparison_key (self ) -> Any :
63
- return str (self ._key ), self ._qid ._comparison_key ()
55
+ return str (self ._key ), self ._index , self . _qid ._comparison_key ()
64
56
65
57
def __str__ (self ) -> str :
66
- return f"M('{ self ._key } ', q={ self ._qid } )"
58
+ return f"M('{ self ._key } [ { self . _index } ] ', q={ self ._qid } )"
67
59
68
60
def __repr__ (self ) -> str :
69
- return f'_MeasurementQid({ self ._key !r} , { self ._qid !r} )'
61
+ return f'_MeasurementQid({ self ._key !r} , { self ._qid !r} , { self . _index } )'
70
62
71
63
72
64
@transformer_api .transformer
@@ -102,16 +94,18 @@ def defer_measurements(
102
94
103
95
circuit = transformer_primitives .unroll_circuit_op (circuit , deep = True , tags_to_check = None )
104
96
terminal_measurements = {op for _ , op in find_terminal_measurements (circuit )}
105
- measurement_qubits : Dict ['cirq.MeasurementKey' , List ['_MeasurementQid' ]] = {}
97
+ measurement_qubits : Dict ['cirq.MeasurementKey' , List [Tuple ['cirq.Qid' , ...]]] = defaultdict (
98
+ list
99
+ )
106
100
107
101
def defer (op : 'cirq.Operation' , _ ) -> 'cirq.OP_TREE' :
108
102
if op in terminal_measurements :
109
103
return op
110
104
gate = op .gate
111
105
if isinstance (gate , ops .MeasurementGate ):
112
106
key = value .MeasurementKey .parse_serialized (gate .key )
113
- targets = [_MeasurementQid (key , q ) for q in op .qubits ]
114
- measurement_qubits [key ] = targets
107
+ targets = [_MeasurementQid (key , q , len ( measurement_qubits [ key ]) ) for q in op .qubits ]
108
+ measurement_qubits [key ]. append ( tuple ( targets ))
115
109
cxs = [_mod_add (q , target ) for q , target in zip (op .qubits , targets )]
116
110
confusions = [
117
111
_ConfusionChannel (m , [op .qubits [i ].dimension for i in indexes ]).on (
@@ -125,10 +119,24 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
125
119
return [defer (op , None ) for op in protocols .decompose_once (op )]
126
120
elif op .classical_controls :
127
121
# Convert to a quantum control
128
- keys = sorted (set (key for c in op .classical_controls for key in c .keys ))
129
- for key in keys :
122
+
123
+ # First create a sorted set of the indexed keys for this control.
124
+ keys = sorted (
125
+ set (
126
+ indexed_key
127
+ for condition in op .classical_controls
128
+ for indexed_key in (
129
+ [(condition .key , condition .index )]
130
+ if isinstance (condition , value .KeyCondition )
131
+ else [(k , - 1 ) for k in condition .keys ]
132
+ )
133
+ )
134
+ )
135
+ for key , index in keys :
130
136
if key not in measurement_qubits :
131
137
raise ValueError (f'Deferred measurement for key={ key } not found.' )
138
+ if index >= len (measurement_qubits [key ]) or index < - len (measurement_qubits [key ]):
139
+ raise ValueError (f'Invalid index for { key } ' )
132
140
133
141
# Try every possible datastore state (exponential in the number of keys) against the
134
142
# condition, and the ones that work are the control values for the new op.
@@ -140,12 +148,11 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
140
148
141
149
# Rearrange these into the format expected by SumOfProducts
142
150
products = [
143
- [i for key in keys for i in store .records [key ][ 0 ]]
151
+ [val for k , i in keys for val in store .records [k ][ i ]]
144
152
for store in compatible_datastores
145
153
]
146
-
147
154
control_values = ops .SumOfProducts (products )
148
- qs = [q for key in keys for q in measurement_qubits [key ]]
155
+ qs = [q for k , i in keys for q in measurement_qubits [k ][ i ]]
149
156
return op .without_classical_controls ().controlled_by (* qs , control_values = control_values )
150
157
return op
151
158
@@ -155,14 +162,15 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
155
162
tags_to_ignore = context .tags_to_ignore if context else (),
156
163
raise_if_add_qubits = False ,
157
164
).unfreeze ()
158
- for k , qubits in measurement_qubits .items ():
159
- circuit .append (ops .measure (* qubits , key = k ))
165
+ for k , qubits_list in measurement_qubits .items ():
166
+ for qubits in qubits_list :
167
+ circuit .append (ops .measure (* qubits , key = k ))
160
168
return circuit
161
169
162
170
163
171
def _all_possible_datastore_states (
164
- keys : Iterable ['cirq.MeasurementKey' ],
165
- measurement_qubits : Mapping ['cirq.MeasurementKey' , Iterable [ 'cirq.Qid' ]],
172
+ keys : Iterable [Tuple [ 'cirq.MeasurementKey' , int ] ],
173
+ measurement_qubits : Dict ['cirq.MeasurementKey' , List [ Tuple [ 'cirq.Qid' , ...] ]],
166
174
) -> Iterable ['cirq.ClassicalDataStoreReader' ]:
167
175
"""The cartesian product of all possible DataStore states for the given keys."""
168
176
# First we get the list of all possible values. So if we have a key mapped to qubits of shape
@@ -179,17 +187,28 @@ def _all_possible_datastore_states(
179
187
# ((1, 1), (0,)),
180
188
# ((1, 1), (1,)),
181
189
# ((1, 1), (2,))]
182
- all_values = itertools .product (
190
+ all_possible_measurements = itertools .product (
183
191
* [
184
- tuple (itertools .product (* [range (q .dimension ) for q in measurement_qubits [k ]]))
185
- for k in keys
192
+ tuple (itertools .product (* [range (q .dimension ) for q in measurement_qubits [k ][ i ] ]))
193
+ for k , i in keys
186
194
]
187
195
)
188
- # Then we create the ClassicalDataDictionaryStore for each of the above.
189
- for sequences in all_values :
190
- lookup = {k : [sequence ] for k , sequence in zip (keys , sequences )}
196
+ # Then we create the ClassicalDataDictionaryStore for each of the above. A `measurement_list`
197
+ # is a single row of the above example, and can be zipped with `keys`.
198
+ for measurement_list in all_possible_measurements :
199
+ # Initialize a set of measurement records for this iteration. This will have the same shape
200
+ # as `measurement_qubits` but zeros for all measurements.
201
+ records = {
202
+ key : [(0 ,) * len (qubits ) for qubits in qubits_list ]
203
+ for key , qubits_list in measurement_qubits .items ()
204
+ }
205
+ # Set the measurement values from the current row of the above, for each key/index we care
206
+ # about.
207
+ for (k , i ), measurement in zip (keys , measurement_list ):
208
+ records [k ][i ] = measurement
209
+ # Finally yield this sample to the consumer.
191
210
yield value .ClassicalDataDictionaryStore (
192
- _records = lookup , _measured_qubits = { k : [ tuple ( measurement_qubits [ k ])] for k in keys }
211
+ _records = records , _measured_qubits = measurement_qubits
193
212
)
194
213
195
214
0 commit comments