@@ -81,8 +81,8 @@ def map_moments(
81
81
):
82
82
op_untagged = cast (circuits .CircuitOperation , op .untagged )
83
83
mapped_op = op_untagged .replace (
84
- circuit = map_moments (op_untagged .mapped_circuit () , map_func , deep = deep ). freeze ( )
85
- )
84
+ circuit = map_moments (op_untagged .circuit , map_func , deep = deep )
85
+ ). with_tags ( * op . tags )
86
86
batch_replace .append ((i , op , mapped_op ))
87
87
mutable_circuit = circuit .unfreeze (copy = True )
88
88
mutable_circuit .batch_replace (batch_replace )
@@ -180,7 +180,8 @@ def map_operations_and_unroll(
180
180
deep = deep ,
181
181
raise_if_add_qubits = raise_if_add_qubits ,
182
182
tags_to_ignore = tags_to_ignore ,
183
- )
183
+ ),
184
+ deep = deep ,
184
185
)
185
186
186
187
@@ -399,12 +400,6 @@ def merge_moments(
399
400
return _create_target_circuit_type (merged_moments , circuit )
400
401
401
402
402
- def _check_circuit_op (op , tags_to_check : Optional [Sequence [Hashable ]]) -> bool :
403
- return isinstance (op .untagged , circuits .CircuitOperation ) and (
404
- tags_to_check is None or any (tag in op .tags for tag in tags_to_check )
405
- )
406
-
407
-
408
403
def unroll_circuit_op (
409
404
circuit : CIRCUIT_TYPE ,
410
405
* ,
@@ -418,8 +413,8 @@ def unroll_circuit_op(
418
413
419
414
Args:
420
415
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
421
- deep: If True, `unroll_circuit_op` is recursively called on all circuit operations matching
422
- `tags_to_check` .
416
+ deep: If true, the transformer primitive will be recursively applied to all circuits
417
+ wrapped inside circuit operations .
423
418
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
424
419
are unrolled.
425
420
@@ -430,12 +425,18 @@ def unroll_circuit_op(
430
425
def map_func (m : circuits .Moment , _ : int ):
431
426
to_zip : List ['cirq.AbstractCircuit' ] = []
432
427
for op in m :
433
- if _check_circuit_op (op , tags_to_check ):
434
- sub_circuit = cast (circuits .CircuitOperation , op .untagged ).mapped_circuit ()
428
+ op_untagged = op .untagged
429
+ if isinstance (op_untagged , circuits .CircuitOperation ):
430
+ if deep :
431
+ op_untagged = op_untagged .replace (
432
+ circuit = unroll_circuit_op (
433
+ op_untagged .circuit , deep = deep , tags_to_check = tags_to_check
434
+ )
435
+ )
435
436
to_zip .append (
436
- unroll_circuit_op ( sub_circuit , deep = deep , tags_to_check = tags_to_check )
437
- if deep
438
- else sub_circuit
437
+ op_untagged . mapped_circuit ( )
438
+ if ( tags_to_check is None or set ( tags_to_check ). intersection ( op . tags ))
439
+ else circuits . Circuit ( op_untagged . with_tags ( * op . tags ))
439
440
)
440
441
else :
441
442
to_zip .append (circuits .Circuit (op ))
@@ -458,27 +459,36 @@ def unroll_circuit_op_greedy_earliest(
458
459
459
460
Args:
460
461
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
461
- deep: If True, `unroll_circuit_op_greedy_earliest` is recursively called on all circuit
462
- operations matching `tags_to_check` .
462
+ deep: If true, the transformer primitive will be recursively applied to all circuits
463
+ wrapped inside circuit operations .
463
464
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
464
465
are unrolled.
465
466
466
467
Returns:
467
468
Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy.
468
469
"""
469
- batch_removals = [* circuit .findall_operations (lambda op : _check_circuit_op (op , tags_to_check ))]
470
- batch_inserts = []
471
- for i , op in batch_removals :
472
- sub_circuit = cast (circuits .CircuitOperation , op .untagged ).mapped_circuit ()
473
- sub_circuit = (
474
- unroll_circuit_op_greedy_earliest (sub_circuit , deep = deep , tags_to_check = tags_to_check )
475
- if deep
476
- else sub_circuit
477
- )
478
- batch_inserts += [(i , sub_circuit .all_operations ())]
470
+ batch_replace = []
471
+ batch_remove = []
472
+ batch_insert = []
473
+ for i , op in circuit .findall_operations (
474
+ lambda o : isinstance (o .untagged , circuits .CircuitOperation )
475
+ ):
476
+ op_untagged = cast (circuits .CircuitOperation , op .untagged )
477
+ if deep :
478
+ op_untagged = op_untagged .replace (
479
+ circuit = unroll_circuit_op_greedy_earliest (
480
+ op_untagged .circuit , deep = deep , tags_to_check = tags_to_check
481
+ )
482
+ )
483
+ if tags_to_check is None or set (tags_to_check ).intersection (op .tags ):
484
+ batch_remove .append ((i , op ))
485
+ batch_insert .append ((i , op_untagged .mapped_circuit ().all_operations ()))
486
+ elif deep :
487
+ batch_replace .append ((i , op , op_untagged .with_tags (* op .tags )))
479
488
unrolled_circuit = circuit .unfreeze (copy = True )
480
- unrolled_circuit .batch_remove (batch_removals )
481
- unrolled_circuit .batch_insert (batch_inserts )
489
+ unrolled_circuit .batch_replace (batch_replace )
490
+ unrolled_circuit .batch_remove (batch_remove )
491
+ unrolled_circuit .batch_insert (batch_insert )
482
492
return _to_target_circuit_type (unrolled_circuit , circuit )
483
493
484
494
@@ -496,8 +506,8 @@ def unroll_circuit_op_greedy_frontier(
496
506
497
507
Args:
498
508
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
499
- deep: If True, `unroll_circuit_op_greedy_frontier` is recursively called on all circuit
500
- operations matching `tags_to_check` .
509
+ deep: If true, the transformer primitive will be recursively applied to all circuits
510
+ wrapped inside circuit operations .
501
511
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
502
512
are unrolled.
503
513
@@ -506,16 +516,29 @@ def unroll_circuit_op_greedy_frontier(
506
516
"""
507
517
unrolled_circuit = circuit .unfreeze (copy = True )
508
518
frontier : Dict ['cirq.Qid' , int ] = defaultdict (lambda : 0 )
509
- for idx , op in circuit .findall_operations (lambda op : _check_circuit_op (op , tags_to_check )):
510
- idx = max (idx , max (frontier [q ] for q in op .qubits ))
511
- unrolled_circuit .clear_operations_touching (op .qubits , [idx ])
512
- sub_circuit = cast (circuits .CircuitOperation , op .untagged ).mapped_circuit ()
513
- sub_circuit = (
514
- unroll_circuit_op_greedy_earliest (sub_circuit , deep = deep , tags_to_check = tags_to_check )
515
- if deep
516
- else sub_circuit
517
- )
518
- frontier = unrolled_circuit .insert_at_frontier (sub_circuit .all_operations (), idx , frontier )
519
+ idx = 0
520
+ while idx < len (unrolled_circuit ):
521
+ for op in unrolled_circuit [idx ].operations :
522
+ # Don't touch stuff inserted by unrolling previous circuit ops.
523
+ if not isinstance (op .untagged , circuits .CircuitOperation ):
524
+ continue
525
+ if any (frontier [q ] > idx for q in op .qubits ):
526
+ continue
527
+ op_untagged = cast (circuits .CircuitOperation , op .untagged )
528
+ if deep :
529
+ op_untagged = op_untagged .replace (
530
+ circuit = unroll_circuit_op_greedy_frontier (
531
+ op_untagged .circuit , deep = deep , tags_to_check = tags_to_check
532
+ )
533
+ )
534
+ if tags_to_check is None or set (tags_to_check ).intersection (op .tags ):
535
+ unrolled_circuit .clear_operations_touching (op .qubits , [idx ])
536
+ frontier = unrolled_circuit .insert_at_frontier (
537
+ op_untagged .mapped_circuit ().all_operations (), idx , frontier
538
+ )
539
+ elif deep :
540
+ unrolled_circuit .batch_replace ([(idx , op , op_untagged .with_tags (* op .tags ))])
541
+ idx += 1
519
542
return _to_target_circuit_type (unrolled_circuit , circuit )
520
543
521
544
0 commit comments