@@ -60,13 +60,17 @@ def map_moments(
60
60
circuit : CIRCUIT_TYPE ,
61
61
map_func : Callable [[circuits .Moment , int ], Union [circuits .Moment , Sequence [circuits .Moment ]]],
62
62
* ,
63
+ tags_to_ignore : Sequence [Hashable ] = (),
63
64
deep : bool = False ,
64
65
) -> CIRCUIT_TYPE :
65
66
"""Applies local transformation on moments, by calling `map_func(moment)` for each moment.
66
67
67
68
Args:
68
69
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
69
70
map_func: Mapping function from (cirq.Moment, moment_index) to a sequence of moments.
71
+ tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
72
+ ignored when recursively applying the transformer primitive to sub-circuits, given
73
+ deep=True.
70
74
deep: If true, `map_func` will be recursively applied to circuits wrapped inside
71
75
any circuit operations contained within `circuit`.
72
76
@@ -79,6 +83,8 @@ def map_moments(
79
83
for i , op in circuit .findall_operations (
80
84
lambda o : isinstance (o .untagged , circuits .CircuitOperation )
81
85
):
86
+ if set (op .tags ).intersection (tags_to_ignore ):
87
+ continue
82
88
op_untagged = cast (circuits .CircuitOperation , op .untagged )
83
89
mapped_op = op_untagged .replace (
84
90
circuit = map_moments (op_untagged .circuit , map_func , deep = deep )
@@ -190,6 +196,7 @@ def merge_operations(
190
196
merge_func : Callable [[ops .Operation , ops .Operation ], Optional [ops .Operation ]],
191
197
* ,
192
198
tags_to_ignore : Sequence [Hashable ] = (),
199
+ deep : bool = False ,
193
200
) -> CIRCUIT_TYPE :
194
201
"""Merges operations in a circuit by calling `merge_func` iteratively on operations.
195
202
@@ -226,6 +233,8 @@ def merge_operations(
226
233
tags_to_ignore: Sequence of tags which should be ignored while applying `merge_func` on
227
234
tagged operations -- i.e. `merge_func(op1, op2)` will be called only if both `op1` and
228
235
`op2` satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
236
+ deep: If true, the transformer primitive will be recursively applied to all circuits
237
+ wrapped inside circuit operations.
229
238
230
239
231
240
Returns:
@@ -235,9 +244,11 @@ def merge_operations(
235
244
ValueError if the merged operation acts on new qubits outside the set of qubits
236
245
corresponding to the original operations to be merged.
237
246
"""
247
+ _circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit"
248
+ tags_to_ignore_set = set (tags_to_ignore ) | {_circuit_op_tag }
238
249
239
250
def apply_merge_func (op1 : ops .Operation , op2 : ops .Operation ) -> Optional [ops .Operation ]:
240
- if not all (set (op .tags ). isdisjoint ( tags_to_ignore ) for op in [op1 , op2 ]):
251
+ if not all (tags_to_ignore_set . isdisjoint (op .tags ) for op in [op1 , op2 ]):
241
252
return None
242
253
new_op = merge_func (op1 , op2 )
243
254
qubit_set = frozenset (op1 .qubits + op2 .qubits )
@@ -252,6 +263,23 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
252
263
for current_moment in circuit :
253
264
new_moment = circuits .Moment ()
254
265
for op in sorted (current_moment .operations , key = lambda op : op .qubits ):
266
+ if (
267
+ deep
268
+ and isinstance (op .untagged , circuits .CircuitOperation )
269
+ and tags_to_ignore_set .isdisjoint (op .tags )
270
+ ):
271
+ op_untagged = op .untagged
272
+ new_moment = new_moment .with_operation (
273
+ op_untagged .replace (
274
+ circuit = merge_operations (
275
+ op_untagged .circuit ,
276
+ merge_func ,
277
+ tags_to_ignore = tags_to_ignore ,
278
+ deep = True ,
279
+ )
280
+ ).with_tags (* op .tags , _circuit_op_tag )
281
+ )
282
+ continue
255
283
op_qs = set (op .qubits )
256
284
idx = ret_circuit .prev_moment_operating_on (tuple (op_qs ))
257
285
if idx is not None and op_qs .issubset (ret_circuit [idx ][op_qs ].operations [0 ].qubits ):
@@ -279,6 +307,12 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
279
307
idx = ret_circuit .prev_moment_operating_on (tuple (op_qs ))
280
308
new_moment = new_moment .with_operation (op )
281
309
ret_circuit += new_moment
310
+ if deep :
311
+ ret_circuit = map_operations (
312
+ ret_circuit ,
313
+ lambda o , _ : o .untagged .with_tags (* (set (o .tags ) - {_circuit_op_tag })),
314
+ deep = True ,
315
+ )
282
316
return _to_target_circuit_type (ret_circuit , circuit )
283
317
284
318
@@ -288,6 +322,7 @@ def merge_operations_to_circuit_op(
288
322
* ,
289
323
tags_to_ignore : Sequence [Hashable ] = (),
290
324
merged_circuit_op_tag : str = "Merged connected component" ,
325
+ deep : bool = False ,
291
326
) -> CIRCUIT_TYPE :
292
327
"""Merges connected components of operations and wraps each component into a circuit operation.
293
328
@@ -307,6 +342,8 @@ def merge_operations_to_circuit_op(
307
342
potential candidates for any connected component.
308
343
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
309
344
components.
345
+ deep: If true, the transformer primitive will be recursively applied to all circuits
346
+ wrapped inside circuit operations.
310
347
311
348
Returns:
312
349
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
@@ -329,7 +366,7 @@ def get_ops(op: 'cirq.Operation'):
329
366
merged_circuit_op_tag
330
367
)
331
368
332
- return merge_operations (circuit , merge_func , tags_to_ignore = tags_to_ignore )
369
+ return merge_operations (circuit , merge_func , tags_to_ignore = tags_to_ignore , deep = deep )
333
370
334
371
335
372
def merge_k_qubit_unitaries_to_circuit_op (
@@ -338,6 +375,7 @@ def merge_k_qubit_unitaries_to_circuit_op(
338
375
* ,
339
376
tags_to_ignore : Sequence [Hashable ] = (),
340
377
merged_circuit_op_tag : Optional [str ] = None ,
378
+ deep : bool = False ,
341
379
) -> CIRCUIT_TYPE :
342
380
"""Merges connected components of operations, acting on <= k qubits, into circuit operations.
343
381
@@ -353,6 +391,8 @@ def merge_k_qubit_unitaries_to_circuit_op(
353
391
potential candidates for any connected component.
354
392
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
355
393
components. A default tag is applied if left None.
394
+ deep: If true, the transformer primitive will be recursively applied to all circuits
395
+ wrapped inside circuit operations.
356
396
357
397
Returns:
358
398
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
@@ -370,12 +410,16 @@ def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']
370
410
can_merge ,
371
411
tags_to_ignore = tags_to_ignore ,
372
412
merged_circuit_op_tag = merged_circuit_op_tag or f"Merged { k } q unitary connected component." ,
413
+ deep = deep ,
373
414
)
374
415
375
416
376
417
def merge_moments (
377
418
circuit : CIRCUIT_TYPE ,
378
419
merge_func : Callable [[circuits .Moment , circuits .Moment ], Optional [circuits .Moment ]],
420
+ * ,
421
+ tags_to_ignore : Sequence [Hashable ] = (),
422
+ deep : bool = False ,
379
423
) -> CIRCUIT_TYPE :
380
424
"""Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`.
381
425
@@ -384,12 +428,27 @@ def merge_moments(
384
428
merge_func: Callable to determine whether two adjacent moments in the circuit should be
385
429
merged. If the moments can be merged, the callable should return the merged moment,
386
430
else None.
431
+ tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
432
+ ignored when recursively applying the transformer primitive to sub-circuits, given
433
+ deep=True.
434
+ deep: If true, the transformer primitive will be recursively applied to all circuits
435
+ wrapped inside circuit operations.
387
436
388
437
Returns:
389
438
Copy of input circuit with merged moments.
390
439
"""
391
440
if not circuit :
392
441
return circuit
442
+ if deep :
443
+ circuit = map_operations (
444
+ circuit ,
445
+ lambda op , _ : op .untagged .replace (
446
+ circuit = merge_moments (op .untagged .circuit , merge_func , deep = deep )
447
+ ).with_tags (* op .tags )
448
+ if isinstance (op .untagged , circuits .CircuitOperation )
449
+ else op ,
450
+ tags_to_ignore = tags_to_ignore ,
451
+ )
393
452
merged_moments : List [circuits .Moment ] = [circuit [0 ]]
394
453
for current_moment in circuit [1 :]:
395
454
merged_moment = merge_func (merged_moments [- 1 ], current_moment )
0 commit comments