18
18
component operations in order, including any nested CircuitOperations.
19
19
"""
20
20
from typing import (
21
- TYPE_CHECKING ,
22
21
AbstractSet ,
23
22
Callable ,
23
+ cast ,
24
24
Dict ,
25
25
FrozenSet ,
26
26
Iterator ,
27
27
List ,
28
28
Optional ,
29
29
Tuple ,
30
+ TYPE_CHECKING ,
30
31
Union ,
31
32
)
32
33
@@ -94,6 +95,12 @@ class CircuitOperation(ops.Operation):
94
95
will have its path prepended with the repetition id for each
95
96
repetition. When False, this will not happen and the measurement
96
97
key will be repeated.
98
+ repeat_until: A condition that will be tested after each iteration of
99
+ the subcircuit. The subcircuit will repeat until condition returns
100
+ True, but will always run at least once, and the measurement key
101
+ need not be defined prior to the subcircuit (but must be defined in
102
+ a measurement within the subcircuit). This field is incompatible
103
+ with repetitions or repetition_ids.
97
104
"""
98
105
99
106
_hash : Optional [int ] = dataclasses .field (default = None , init = False )
@@ -103,6 +110,9 @@ class CircuitOperation(ops.Operation):
103
110
_cached_control_keys : Optional [AbstractSet ['cirq.MeasurementKey' ]] = dataclasses .field (
104
111
default = None , init = False
105
112
)
113
+ _cached_mapped_single_loop : Optional ['cirq.Circuit' ] = dataclasses .field (
114
+ default = None , init = False
115
+ )
106
116
107
117
circuit : 'cirq.FrozenCircuit'
108
118
repetitions : int = 1
@@ -113,6 +123,7 @@ class CircuitOperation(ops.Operation):
113
123
parent_path : Tuple [str , ...] = dataclasses .field (default_factory = tuple )
114
124
extern_keys : FrozenSet ['cirq.MeasurementKey' ] = dataclasses .field (default_factory = frozenset )
115
125
use_repetition_ids : bool = True
126
+ repeat_until : Optional ['cirq.Condition' ] = dataclasses .field (default = None )
116
127
117
128
def __post_init__ (self ):
118
129
if not isinstance (self .circuit , circuits .FrozenCircuit ):
@@ -148,6 +159,14 @@ def __post_init__(self):
148
159
if q_new .dimension != q .dimension :
149
160
raise ValueError (f'Qid dimension conflict.\n From qid: { q } \n To qid: { q_new } ' )
150
161
162
+ if self .repeat_until :
163
+ if self .use_repetition_ids or self .repetitions != 1 :
164
+ raise ValueError ('Cannot use repetitions with repeat_until' )
165
+ if protocols .measurement_key_objs (self ._mapped_single_loop ()).isdisjoint (
166
+ self .repeat_until .keys
167
+ ):
168
+ raise ValueError ('Infinite loop: condition is not modified in subcircuit.' )
169
+
151
170
# Ensure that param_resolver is converted to an actual ParamResolver.
152
171
object .__setattr__ (self , 'param_resolver' , study .ParamResolver (self .param_resolver ))
153
172
@@ -174,6 +193,7 @@ def __eq__(self, other) -> bool:
174
193
and self .repetition_ids == other .repetition_ids
175
194
and self .parent_path == other .parent_path
176
195
and self .use_repetition_ids == other .use_repetition_ids
196
+ and self .repeat_until == other .repeat_until
177
197
)
178
198
179
199
# Methods for getting post-mapping properties of the contained circuit.
@@ -223,6 +243,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
223
243
if not protocols .control_keys (self .circuit )
224
244
else protocols .control_keys (self .mapped_circuit ())
225
245
)
246
+ if self .repeat_until is not None :
247
+ keys |= frozenset (self .repeat_until .keys ) - self ._measurement_key_objs_ ()
226
248
object .__setattr__ (self , '_cached_control_keys' , keys )
227
249
return self ._cached_control_keys # type: ignore
228
250
@@ -235,6 +257,27 @@ def _parameter_names_(self) -> AbstractSet[str]:
235
257
)
236
258
}
237
259
260
+ def _mapped_single_loop (self , repetition_id : Optional [str ] = None ) -> 'cirq.Circuit' :
261
+ if self ._cached_mapped_single_loop is None :
262
+ circuit = self .circuit .unfreeze ()
263
+ if self .qubit_map :
264
+ circuit = circuit .transform_qubits (lambda q : self .qubit_map .get (q , q ))
265
+ if self .repetitions < 0 :
266
+ circuit = circuit ** - 1
267
+ if self .measurement_key_map :
268
+ circuit = protocols .with_measurement_key_mapping (circuit , self .measurement_key_map )
269
+ if self .param_resolver :
270
+ circuit = protocols .resolve_parameters (
271
+ circuit , self .param_resolver , recursive = False
272
+ )
273
+ object .__setattr__ (self , '_cached_mapped_single_loop' , circuit )
274
+ circuit = cast (circuits .Circuit , self ._cached_mapped_single_loop )
275
+ if repetition_id :
276
+ circuit = protocols .with_rescoped_keys (circuit , (repetition_id ,))
277
+ return protocols .with_rescoped_keys (
278
+ circuit , self .parent_path , bindable_keys = self .extern_keys
279
+ )
280
+
238
281
def mapped_circuit (self , deep : bool = False ) -> 'cirq.Circuit' :
239
282
"""Applies all maps to the contained circuit and returns the result.
240
283
@@ -249,24 +292,12 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
249
292
"""
250
293
if self .repetitions == 0 :
251
294
return circuits .Circuit ()
252
- circuit = self .circuit .unfreeze ()
253
- if self .qubit_map :
254
- circuit = circuit .transform_qubits (lambda q : self .qubit_map .get (q , q ))
255
- if self .repetitions < 0 :
256
- circuit = circuit ** - 1
257
- if self .measurement_key_map :
258
- circuit = protocols .with_measurement_key_mapping (circuit , self .measurement_key_map )
259
- if self .param_resolver :
260
- circuit = protocols .resolve_parameters (circuit , self .param_resolver , recursive = False )
261
- if self .repetition_ids is not None :
262
- if not self .use_repetition_ids or not protocols .is_measurement (circuit ):
263
- circuit = circuit * abs (self .repetitions )
264
- else :
265
- circuit = circuits .Circuit (
266
- protocols .with_rescoped_keys (circuit , (rep ,)) for rep in self .repetition_ids
267
- )
268
- circuit = protocols .with_rescoped_keys (
269
- circuit , self .parent_path , bindable_keys = self .extern_keys
295
+ circuit = (
296
+ circuits .Circuit (self ._mapped_single_loop (rep ) for rep in self .repetition_ids )
297
+ if self .repetition_ids is not None
298
+ and self .use_repetition_ids
299
+ and protocols .is_measurement (self .circuit )
300
+ else self ._mapped_single_loop () * abs (self .repetitions )
270
301
)
271
302
if deep :
272
303
circuit = circuit .map_operations (
@@ -282,8 +313,16 @@ def _decompose_(self) -> Iterator['cirq.Operation']:
282
313
return self .mapped_circuit (deep = False ).all_operations ()
283
314
284
315
def _act_on_ (self , args : 'cirq.OperationTarget' ) -> bool :
285
- for op in self ._decompose_ ():
286
- protocols .act_on (op , args )
316
+ if self .repeat_until :
317
+ circuit = self ._mapped_single_loop ()
318
+ while True :
319
+ for op in circuit .all_operations ():
320
+ protocols .act_on (op , args )
321
+ if self .repeat_until .resolve (args .classical_data ):
322
+ break
323
+ else :
324
+ for op in self ._decompose_ ():
325
+ protocols .act_on (op , args )
287
326
return True
288
327
289
328
# Methods for string representation of the operation.
@@ -305,6 +344,8 @@ def __repr__(self):
305
344
args += f'repetition_ids={ proper_repr (self .repetition_ids )} ,\n '
306
345
if not self .use_repetition_ids :
307
346
args += 'use_repetition_ids=False,\n '
347
+ if self .repeat_until :
348
+ args += f'repeat_until={ self .repeat_until !r} ,\n '
308
349
indented_args = args .replace ('\n ' , '\n ' )
309
350
return f'cirq.CircuitOperation({ indented_args [:- 4 ]} )'
310
351
@@ -337,6 +378,8 @@ def dict_str(d: Dict) -> str:
337
378
args .append (f'loops={ self .repetitions } ' )
338
379
if not self .use_repetition_ids :
339
380
args .append ('no_rep_ids' )
381
+ if self .repeat_until :
382
+ args .append (f'until={ self .repeat_until } ' )
340
383
if not args :
341
384
return circuit_msg
342
385
return f'{ circuit_msg } ({ ", " .join (args )} )'
@@ -375,6 +418,8 @@ def _json_dict_(self):
375
418
}
376
419
if not self .use_repetition_ids :
377
420
resp ['use_repetition_ids' ] = False
421
+ if self .repeat_until :
422
+ resp ['repeat_until' ] = self .repeat_until
378
423
return resp
379
424
380
425
@classmethod
@@ -388,10 +433,11 @@ def _from_json_dict_(
388
433
repetition_ids ,
389
434
parent_path = (),
390
435
use_repetition_ids = True ,
436
+ repeat_until = None ,
391
437
** kwargs ,
392
438
):
393
439
return (
394
- cls (circuit , use_repetition_ids = use_repetition_ids )
440
+ cls (circuit , use_repetition_ids = use_repetition_ids , repeat_until = repeat_until )
395
441
.with_qubit_mapping (dict (qubit_map ))
396
442
.with_measurement_key_mapping (measurement_key_map )
397
443
.with_params (param_resolver )
0 commit comments