18
18
component operations in order, including any nested CircuitOperations.
19
19
"""
20
20
21
- from typing import TYPE_CHECKING , AbstractSet , Callable , Dict , Optional , Tuple , Union
21
+ from typing import TYPE_CHECKING , AbstractSet , Callable , Dict , List , Optional , Tuple , Union
22
22
23
23
import dataclasses
24
24
import numpy as np
33
33
INT_TYPE = Union [int , np .integer ]
34
34
35
35
36
+ def default_repetition_ids (repetitions : int ) -> Optional [List [str ]]:
37
+ if abs (repetitions ) > 1 :
38
+ return [str (i ) for i in range (abs (repetitions ))]
39
+ return None
40
+
41
+
42
+ def cartesian_product_of_string_lists (list1 : Optional [List [str ]], list2 : Optional [List [str ]]):
43
+ if list1 is None and list2 is None :
44
+ return None # coverage: ignore
45
+ if list1 is None :
46
+ return list2 # coverage: ignore
47
+ if list2 is None :
48
+ return list1
49
+ return [f'{ first } -{ second } ' for first in list1 for second in list2 ]
50
+
51
+
36
52
@dataclasses .dataclass (frozen = True )
37
53
class CircuitOperation (ops .Operation ):
38
54
"""An operation that encapsulates a circuit.
@@ -47,6 +63,10 @@ class CircuitOperation(ops.Operation):
47
63
qubit_map: Remappings for qubits in the circuit.
48
64
measurement_key_map: Remappings for measurement keys in the circuit.
49
65
param_resolver: Resolved values for parameters in the circuit.
66
+ repetition_ids: List of identifiers for each repetition of the
67
+ CircuitOperation. If populated, the length should be equal to the
68
+ repetitions. If not populated and abs(`repetitions`) > 1, it is
69
+ initialized to strings for numbers in `range(repetitions)`.
50
70
"""
51
71
52
72
_hash : Optional [int ] = dataclasses .field (default = None , init = False )
@@ -56,10 +76,28 @@ class CircuitOperation(ops.Operation):
56
76
qubit_map : Dict ['cirq.Qid' , 'cirq.Qid' ] = dataclasses .field (default_factory = dict )
57
77
measurement_key_map : Dict [str , str ] = dataclasses .field (default_factory = dict )
58
78
param_resolver : study .ParamResolver = study .ParamResolver ()
79
+ repetition_ids : Optional [List [str ]] = dataclasses .field (default = None )
59
80
60
81
def __post_init__ (self ):
61
82
if not isinstance (self .circuit , circuits .FrozenCircuit ):
62
83
raise TypeError (f'Expected circuit of type FrozenCircuit, got: { type (self .circuit )!r} ' )
84
+
85
+ # Ensure that the circuit is invertible if the repetitions are negative.
86
+ if self .repetitions < 0 :
87
+ try :
88
+ protocols .inverse (self .circuit .unfreeze ())
89
+ except TypeError :
90
+ raise ValueError (f'repetitions are negative but the circuit is not invertible' )
91
+
92
+ # Initialize repetition_ids to default, if unspecified. Else, validate their length.
93
+ loop_size = abs (self .repetitions )
94
+ if not self .repetition_ids :
95
+ object .__setattr__ (self , 'repetition_ids' , self ._default_repetition_ids ())
96
+ elif len (self .repetition_ids ) != loop_size :
97
+ raise ValueError (
98
+ f'Expected repetition_ids to be a list of length { loop_size } , '
99
+ f'got: { self .repetition_ids } '
100
+ )
63
101
# Ensure that param_resolver is converted to an actual ParamResolver.
64
102
object .__setattr__ (self , 'param_resolver' , study .ParamResolver (self .param_resolver ))
65
103
@@ -83,6 +121,7 @@ def __eq__(self, other) -> bool:
83
121
and self .measurement_key_map == other .measurement_key_map
84
122
and self .param_resolver == other .param_resolver
85
123
and self .repetitions == other .repetitions
124
+ and self .repetition_ids == other .repetition_ids
86
125
)
87
126
88
127
# Methods for getting post-mapping properties of the contained circuit.
@@ -93,6 +132,9 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
93
132
ordered_qubits = ops .QubitOrder .DEFAULT .order_for (self .circuit .all_qubits ())
94
133
return tuple (self .qubit_map .get (q , q ) for q in ordered_qubits )
95
134
135
+ def _default_repetition_ids (self ) -> Optional [List [str ]]:
136
+ return default_repetition_ids (self .repetitions )
137
+
96
138
def _qid_shape_ (self ) -> Tuple [int , ...]:
97
139
return tuple (q .dimension for q in self .qubits )
98
140
@@ -117,8 +159,39 @@ def _decompose_(self) -> 'cirq.OP_TREE':
117
159
result = result ** - 1
118
160
result = protocols .with_measurement_key_mapping (result , self .measurement_key_map )
119
161
result = protocols .resolve_parameters (result , self .param_resolver , recursive = False )
120
-
121
- return list (result .all_operations ()) * abs (self .repetitions )
162
+ # repetition_ids don't need to be taken into account if the circuit has no measurements
163
+ # or if repetition_ids are unset.
164
+ if self .repetition_ids is None or not protocols .is_measurement (result ):
165
+ return list (result .all_operations ()) * abs (self .repetitions )
166
+ # If it's a measurement circuit with repetitions/repetition_ids, prefix the repetition_ids
167
+ # to measurements. Details at https://tinyurl.com/measurement-repeated-circuitop.
168
+ ops = [] # type: List[cirq.Operation]
169
+ for parent_id in self .repetition_ids :
170
+ for op in result .all_operations ():
171
+ if isinstance (op , CircuitOperation ):
172
+ # For a CircuitOperation, prefix the current repetition_id to the children
173
+ # repetition_ids.
174
+ ops .append (
175
+ op .with_repetition_ids (
176
+ # If `op.repetition_ids` is None, this will return `[parent_id]`.
177
+ cartesian_product_of_string_lists ([parent_id ], op .repetition_ids )
178
+ )
179
+ )
180
+ elif protocols .is_measurement (op ):
181
+ # For a non-CircuitOperation measurement, prefix the current repetition_id
182
+ # to the children measurement keys. Implemented by creating a mapping and
183
+ # using the with_measurement_key_mapping protocol.
184
+ ops .append (
185
+ protocols .with_measurement_key_mapping (
186
+ op ,
187
+ key_map = {
188
+ key : f'{ parent_id } -{ key } ' for key in protocols .measurement_keys (op )
189
+ },
190
+ )
191
+ )
192
+ else :
193
+ ops .append (op )
194
+ return ops
122
195
123
196
# Methods for string representation of the operation.
124
197
@@ -132,6 +205,9 @@ def __repr__(self):
132
205
args += f'measurement_key_map={ proper_repr (self .measurement_key_map )} ,\n '
133
206
if self .param_resolver :
134
207
args += f'param_resolver={ proper_repr (self .param_resolver )} ,\n '
208
+ if self .repetition_ids != self ._default_repetition_ids ():
209
+ # Default repetition_ids need not be specified.
210
+ args += f'repetition_ids={ proper_repr (self .repetition_ids )} ,\n '
135
211
indented_args = args .replace ('\n ' , '\n ' )
136
212
return f'cirq.CircuitOperation({ indented_args [:- 4 ]} )'
137
213
@@ -155,7 +231,11 @@ def dict_str(d: Dict) -> str:
155
231
args .append (f'key_map={ dict_str (self .measurement_key_map )} ' )
156
232
if self .param_resolver :
157
233
args .append (f'params={ self .param_resolver .param_dict } ' )
158
- if self .repetitions != 1 :
234
+ if self .repetition_ids != self ._default_repetition_ids ():
235
+ # Default repetition_ids need not be specified.
236
+ args .append (f'repetition_ids={ self .repetition_ids } ' )
237
+ elif self .repetitions != 1 :
238
+ # Only add loops if we haven't added repetition_ids.
159
239
args .append (f'loops={ self .repetitions } ' )
160
240
if not args :
161
241
return f'{ header } \n { circuit_msg } '
@@ -173,6 +253,7 @@ def __hash__(self):
173
253
frozenset (self .qubit_map .items ()),
174
254
frozenset (self .measurement_key_map .items ()),
175
255
self .param_resolver ,
256
+ tuple ([] if self .repetition_ids is None else self .repetition_ids ),
176
257
)
177
258
),
178
259
)
@@ -188,53 +269,95 @@ def _json_dict_(self):
188
269
'qubit_map' : sorted (self .qubit_map .items ()),
189
270
'measurement_key_map' : self .measurement_key_map ,
190
271
'param_resolver' : self .param_resolver ,
272
+ 'repetition_ids' : self .repetition_ids ,
191
273
}
192
274
193
275
@classmethod
194
276
def _from_json_dict_ (
195
- cls , circuit , repetitions , qubit_map , measurement_key_map , param_resolver , ** kwargs
277
+ cls ,
278
+ circuit ,
279
+ repetitions ,
280
+ qubit_map ,
281
+ measurement_key_map ,
282
+ param_resolver ,
283
+ repetition_ids ,
284
+ ** kwargs ,
196
285
):
197
286
return (
198
287
cls (circuit )
199
288
.with_qubit_mapping (dict (qubit_map ))
200
289
.with_measurement_key_mapping (measurement_key_map )
201
290
.with_params (param_resolver )
202
- .repeat (repetitions )
291
+ .repeat (repetitions , repetition_ids )
203
292
)
204
293
205
294
# Methods for constructing a similar object with one field modified.
206
295
207
296
def repeat (
208
297
self ,
209
- repetitions : INT_TYPE ,
298
+ repetitions : Optional [INT_TYPE ] = None ,
299
+ repetition_ids : Optional [List [str ]] = None ,
210
300
) -> 'CircuitOperation' :
211
301
"""Returns a copy of this operation repeated 'repetitions' times.
302
+ Each repetition instance will be identified by a single repetition_id.
212
303
213
304
Args:
214
305
repetitions: Number of times this operation should repeat. This
215
- is multiplied with any pre-existing repetitions.
306
+ is multiplied with any pre-existing repetitions. If unset, it
307
+ defaults to the length of `repetition_ids`.
308
+ repetition_ids: List of IDs, one for each repetition. If unset,
309
+ defaults to `default_repetition_ids(repetitions)`.
216
310
217
311
Returns:
218
- A copy of this operation repeated 'repetitions' times.
312
+ A copy of this operation repeated `repetitions` times with the
313
+ appropriate `repetition_ids`. The output `repetition_ids` are the
314
+ cartesian product of input `repetition_ids` with the base
315
+ operation's `repetition_ids`. If the base operation has unset
316
+ `repetition_ids` (indicates {-1, 0, 1} `repetitions` with no custom
317
+ IDs), the input `repetition_ids` are directly used.
219
318
220
319
Raises:
221
320
TypeError: `repetitions` is not an integer value.
222
- NotImplementedError: The operation contains measurements and
223
- cannot have repetitions .
321
+ ValueError: Unexpected length of `repetition_ids`.
322
+ ValueError: Both `repetitions` and `repetition_ids` are None .
224
323
"""
324
+ if repetitions is None :
325
+ if repetition_ids is None :
326
+ raise ValueError ('At least one of repetitions and repetition_ids must be set' )
327
+ repetitions = len (repetition_ids )
328
+
225
329
if not isinstance (repetitions , (int , np .integer )):
226
330
raise TypeError ('Only integer repetitions are allowed.' )
227
- if repetitions == 1 :
331
+
332
+ repetitions = int (repetitions )
333
+
334
+ if repetitions == 1 and repetition_ids is None :
228
335
# As CircuitOperation is immutable, this can safely return the original.
229
336
return self
230
- repetitions = int (repetitions )
231
- if protocols .is_measurement (self .circuit ):
232
- raise NotImplementedError ('Loops over measurements are not supported.' )
233
- return self .replace (repetitions = self .repetitions * repetitions )
337
+
338
+ expected_repetition_id_length = abs (repetitions )
339
+ # The eventual number of repetitions of the returned CircuitOperation.
340
+ final_repetitions = self .repetitions * repetitions
341
+
342
+ if repetition_ids is None :
343
+ repetition_ids = default_repetition_ids (expected_repetition_id_length )
344
+ elif len (repetition_ids ) != expected_repetition_id_length :
345
+ raise ValueError (
346
+ f'Expected repetition_ids={ repetition_ids } length to be '
347
+ f'{ expected_repetition_id_length } '
348
+ )
349
+
350
+ # If `self.repetition_ids` is None, this will just return `repetition_ids`.
351
+ repetition_ids = cartesian_product_of_string_lists (repetition_ids , self .repetition_ids )
352
+
353
+ return self .replace (repetitions = final_repetitions , repetition_ids = repetition_ids )
234
354
235
355
def __pow__ (self , power : int ) -> 'CircuitOperation' :
236
356
return self .repeat (power )
237
357
358
+ def with_repetition_ids (self , repetition_ids : List [str ]) -> 'CircuitOperation' :
359
+ return self .replace (repetition_ids = repetition_ids )
360
+
238
361
def with_qubit_mapping (
239
362
self ,
240
363
qubit_map : Union [Dict ['cirq.Qid' , 'cirq.Qid' ], Callable [['cirq.Qid' ], 'cirq.Qid' ]],
0 commit comments