Skip to content

Commit 2693951

Browse files
authored
Improve support for recursively applying transformer primitives on circuit operations using deep=True (#5103)
- Part of fixing #5039 - Fixes multiple bugs and improves support for `deep=True` flag in transformer primitives.
1 parent caadb0c commit 2693951

File tree

2 files changed

+161
-56
lines changed

2 files changed

+161
-56
lines changed

cirq-core/cirq/transformers/transformer_primitives.py

+65-42
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def map_moments(
8181
):
8282
op_untagged = cast(circuits.CircuitOperation, op.untagged)
8383
mapped_op = op_untagged.replace(
84-
circuit=map_moments(op_untagged.mapped_circuit(), map_func, deep=deep).freeze()
85-
)
84+
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
85+
).with_tags(*op.tags)
8686
batch_replace.append((i, op, mapped_op))
8787
mutable_circuit = circuit.unfreeze(copy=True)
8888
mutable_circuit.batch_replace(batch_replace)
@@ -180,7 +180,8 @@ def map_operations_and_unroll(
180180
deep=deep,
181181
raise_if_add_qubits=raise_if_add_qubits,
182182
tags_to_ignore=tags_to_ignore,
183-
)
183+
),
184+
deep=deep,
184185
)
185186

186187

@@ -399,12 +400,6 @@ def merge_moments(
399400
return _create_target_circuit_type(merged_moments, circuit)
400401

401402

402-
def _check_circuit_op(op, tags_to_check: Optional[Sequence[Hashable]]) -> bool:
403-
return isinstance(op.untagged, circuits.CircuitOperation) and (
404-
tags_to_check is None or any(tag in op.tags for tag in tags_to_check)
405-
)
406-
407-
408403
def unroll_circuit_op(
409404
circuit: CIRCUIT_TYPE,
410405
*,
@@ -418,8 +413,8 @@ def unroll_circuit_op(
418413
419414
Args:
420415
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
421-
deep: If True, `unroll_circuit_op` is recursively called on all circuit operations matching
422-
`tags_to_check`.
416+
deep: If true, the transformer primitive will be recursively applied to all circuits
417+
wrapped inside circuit operations.
423418
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
424419
are unrolled.
425420
@@ -430,12 +425,18 @@ def unroll_circuit_op(
430425
def map_func(m: circuits.Moment, _: int):
431426
to_zip: List['cirq.AbstractCircuit'] = []
432427
for op in m:
433-
if _check_circuit_op(op, tags_to_check):
434-
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
428+
op_untagged = op.untagged
429+
if isinstance(op_untagged, circuits.CircuitOperation):
430+
if deep:
431+
op_untagged = op_untagged.replace(
432+
circuit=unroll_circuit_op(
433+
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
434+
)
435+
)
435436
to_zip.append(
436-
unroll_circuit_op(sub_circuit, deep=deep, tags_to_check=tags_to_check)
437-
if deep
438-
else sub_circuit
437+
op_untagged.mapped_circuit()
438+
if (tags_to_check is None or set(tags_to_check).intersection(op.tags))
439+
else circuits.Circuit(op_untagged.with_tags(*op.tags))
439440
)
440441
else:
441442
to_zip.append(circuits.Circuit(op))
@@ -458,27 +459,36 @@ def unroll_circuit_op_greedy_earliest(
458459
459460
Args:
460461
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
461-
deep: If True, `unroll_circuit_op_greedy_earliest` is recursively called on all circuit
462-
operations matching `tags_to_check`.
462+
deep: If true, the transformer primitive will be recursively applied to all circuits
463+
wrapped inside circuit operations.
463464
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
464465
are unrolled.
465466
466467
Returns:
467468
Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy.
468469
"""
469-
batch_removals = [*circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check))]
470-
batch_inserts = []
471-
for i, op in batch_removals:
472-
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
473-
sub_circuit = (
474-
unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check)
475-
if deep
476-
else sub_circuit
477-
)
478-
batch_inserts += [(i, sub_circuit.all_operations())]
470+
batch_replace = []
471+
batch_remove = []
472+
batch_insert = []
473+
for i, op in circuit.findall_operations(
474+
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
475+
):
476+
op_untagged = cast(circuits.CircuitOperation, op.untagged)
477+
if deep:
478+
op_untagged = op_untagged.replace(
479+
circuit=unroll_circuit_op_greedy_earliest(
480+
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
481+
)
482+
)
483+
if tags_to_check is None or set(tags_to_check).intersection(op.tags):
484+
batch_remove.append((i, op))
485+
batch_insert.append((i, op_untagged.mapped_circuit().all_operations()))
486+
elif deep:
487+
batch_replace.append((i, op, op_untagged.with_tags(*op.tags)))
479488
unrolled_circuit = circuit.unfreeze(copy=True)
480-
unrolled_circuit.batch_remove(batch_removals)
481-
unrolled_circuit.batch_insert(batch_inserts)
489+
unrolled_circuit.batch_replace(batch_replace)
490+
unrolled_circuit.batch_remove(batch_remove)
491+
unrolled_circuit.batch_insert(batch_insert)
482492
return _to_target_circuit_type(unrolled_circuit, circuit)
483493

484494

@@ -496,8 +506,8 @@ def unroll_circuit_op_greedy_frontier(
496506
497507
Args:
498508
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
499-
deep: If True, `unroll_circuit_op_greedy_frontier` is recursively called on all circuit
500-
operations matching `tags_to_check`.
509+
deep: If true, the transformer primitive will be recursively applied to all circuits
510+
wrapped inside circuit operations.
501511
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
502512
are unrolled.
503513
@@ -506,16 +516,29 @@ def unroll_circuit_op_greedy_frontier(
506516
"""
507517
unrolled_circuit = circuit.unfreeze(copy=True)
508518
frontier: Dict['cirq.Qid', int] = defaultdict(lambda: 0)
509-
for idx, op in circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check)):
510-
idx = max(idx, max(frontier[q] for q in op.qubits))
511-
unrolled_circuit.clear_operations_touching(op.qubits, [idx])
512-
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
513-
sub_circuit = (
514-
unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check)
515-
if deep
516-
else sub_circuit
517-
)
518-
frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier)
519+
idx = 0
520+
while idx < len(unrolled_circuit):
521+
for op in unrolled_circuit[idx].operations:
522+
# Don't touch stuff inserted by unrolling previous circuit ops.
523+
if not isinstance(op.untagged, circuits.CircuitOperation):
524+
continue
525+
if any(frontier[q] > idx for q in op.qubits):
526+
continue
527+
op_untagged = cast(circuits.CircuitOperation, op.untagged)
528+
if deep:
529+
op_untagged = op_untagged.replace(
530+
circuit=unroll_circuit_op_greedy_frontier(
531+
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
532+
)
533+
)
534+
if tags_to_check is None or set(tags_to_check).intersection(op.tags):
535+
unrolled_circuit.clear_operations_touching(op.qubits, [idx])
536+
frontier = unrolled_circuit.insert_at_frontier(
537+
op_untagged.mapped_circuit().all_operations(), idx, frontier
538+
)
539+
elif deep:
540+
unrolled_circuit.batch_replace([(idx, op, op_untagged.with_tags(*op.tags))])
541+
idx += 1
519542
return _to_target_circuit_type(unrolled_circuit, circuit)
520543

521544

cirq-core/cirq/transformers/transformer_primitives_test.py

+96-14
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
117117
)
118118

119119

120+
# pylint: disable=line-too-long
120121
def test_map_operations_deep_subcircuits():
121122
q = cirq.LineQubit.range(5)
122123
c_orig = cirq.Circuit(
@@ -127,9 +128,14 @@ def test_map_operations_deep_subcircuits():
127128
c_orig_with_circuit_ops = cirq.Circuit(
128129
cirq.CircuitOperation(
129130
cirq.FrozenCircuit(
130-
[cirq.CircuitOperation(cirq.FrozenCircuit(op)) for op in c_orig.all_operations()]
131+
[
132+
cirq.CircuitOperation(cirq.FrozenCircuit(op)).repeat(2).with_tags("internal")
133+
for op in c_orig.all_operations()
134+
]
131135
)
132136
)
137+
.repeat(6)
138+
.with_tags("external")
133139
)
134140

135141
def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
@@ -139,23 +145,73 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
139145
cirq.Z.on_each(*op.qubits),
140146
] if op.gate == cirq.CX else op
141147

142-
c_mapped = cirq.map_operations(c_orig_with_circuit_ops, map_func, deep=True)
143-
c_mapped = cirq.unroll_circuit_op(c_mapped, deep=True, tags_to_check=None)
144148
cirq.testing.assert_has_diagram(
145-
c_mapped,
149+
c_orig_with_circuit_ops,
146150
'''
147-
0: ───Z───@───Z───────────────
148-
149-
1: ───Z───X───Z───────────────
150-
151-
2: ───Z───X───Z───────────────
152-
153-
3: ───Z───@───Z───Z───@───Z───
154-
155-
4: ───────────────Z───X───Z───
151+
[ [ 0: ───@─── ] ]
152+
[ 0: ───[ │ ]────────────────────────────────────────────────────────────── ]
153+
[ [ 1: ───X─── ](loops=2)['internal'] ]
154+
[ │ ]
155+
[ 1: ───#2────────────────────────────────────────────────────────────────────────── ]
156+
[ ]
157+
[ [ 2: ───X─── ] ]
158+
0: ───[ 2: ───[ │ ]────────────────────────────────────────────────────────────── ]────────────────────────
159+
[ [ 3: ───@─── ](loops=2)['internal'] ]
160+
[ │ ]
161+
[ │ [ 3: ───@─── ] ]
162+
[ 3: ───#2────────────────────────────────────[ │ ]──────────────────────── ]
163+
[ [ 4: ───X─── ](loops=2)['internal'] ]
164+
[ │ ]
165+
[ 4: ─────────────────────────────────────────#2──────────────────────────────────── ](loops=6)['external']
166+
167+
1: ───#2────────────────────────────────────────────────────────────────────────────────────────────────────────────
168+
169+
2: ───#3────────────────────────────────────────────────────────────────────────────────────────────────────────────
170+
171+
3: ───#4────────────────────────────────────────────────────────────────────────────────────────────────────────────
172+
173+
4: ───#5────────────────────────────────────────────────────────────────────────────────────────────────────────────
156174
''',
157175
)
158176

177+
c_mapped = cirq.map_operations(c_orig_with_circuit_ops, map_func, deep=True)
178+
for unroller in [
179+
cirq.unroll_circuit_op,
180+
cirq.unroll_circuit_op_greedy_earliest,
181+
cirq.unroll_circuit_op_greedy_frontier,
182+
]:
183+
cirq.testing.assert_has_diagram(
184+
unroller(c_mapped, deep=True),
185+
'''
186+
[ [ 0: ───Z───@───Z─── ] ]
187+
[ 0: ───[ │ ]────────────────────────────────────────────────────────────────────── ]
188+
[ [ 1: ───Z───X───Z─── ](loops=2)['internal'] ]
189+
[ │ ]
190+
[ 1: ───#2────────────────────────────────────────────────────────────────────────────────────────── ]
191+
[ ]
192+
[ [ 2: ───Z───X───Z─── ] ]
193+
0: ───[ 2: ───[ │ ]────────────────────────────────────────────────────────────────────── ]────────────────────────
194+
[ [ 3: ───Z───@───Z─── ](loops=2)['internal'] ]
195+
[ │ ]
196+
[ │ [ 3: ───Z───@───Z─── ] ]
197+
[ 3: ───#2────────────────────────────────────────────[ │ ]──────────────────────── ]
198+
[ [ 4: ───Z───X───Z─── ](loops=2)['internal'] ]
199+
[ │ ]
200+
[ 4: ─────────────────────────────────────────────────#2──────────────────────────────────────────── ](loops=6)['external']
201+
202+
1: ───#2────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
203+
204+
2: ───#3────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
205+
206+
3: ───#4────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
207+
208+
4: ───#5────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
209+
''',
210+
)
211+
212+
213+
# pylint: enable=line-too-long
214+
159215

160216
def test_map_operations_respects_tags_to_ignore():
161217
q = cirq.LineQubit.range(2)
@@ -204,13 +260,29 @@ def test_unroll_circuit_op_and_variants():
204260
[cirq.Moment(cirq.CircuitOperation(cirq.FrozenCircuit(m))) for m in mapped_circuit[:-1]],
205261
mapped_circuit[-1],
206262
)
263+
cirq.testing.assert_has_diagram(
264+
mapped_circuit_deep,
265+
'''
266+
0: ───[ 0: ───X─── ]────────────────────────────────────────────────────────────X───
267+
268+
1: ────────────────────[ 1: ───[ 1: ───Z───Z─── ]['<mapped_circuit_op>']─── ]───────
269+
''',
270+
)
207271
for unroller in [
208272
cirq.unroll_circuit_op_greedy_earliest,
209273
cirq.unroll_circuit_op_greedy_frontier,
210274
cirq.unroll_circuit_op,
211275
]:
212276
cirq.testing.assert_same_circuits(
213-
unroller(mapped_circuit), unroller(mapped_circuit_deep, tags_to_check=None, deep=True)
277+
unroller(mapped_circuit), unroller(mapped_circuit_deep, deep=True, tags_to_check=None)
278+
)
279+
cirq.testing.assert_has_diagram(
280+
unroller(mapped_circuit_deep, deep=True),
281+
'''
282+
0: ───[ 0: ───X─── ]────────────────────────X───
283+
284+
1: ────────────────────[ 1: ───Z───Z─── ]───────
285+
''',
214286
)
215287

216288
cirq.testing.assert_has_diagram(
@@ -239,6 +311,16 @@ def test_unroll_circuit_op_and_variants():
239311
)
240312

241313

314+
def test_unroll_circuit_op_greedy_frontier_doesnt_touch_same_op_twice():
315+
q = cirq.NamedQubit("q")
316+
nested_ops = [cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q)))] * 5
317+
nested_circuit_op = cirq.CircuitOperation(cirq.FrozenCircuit(nested_ops))
318+
c = cirq.Circuit(nested_circuit_op, nested_circuit_op, nested_circuit_op)
319+
c_expected = cirq.Circuit(nested_ops, nested_ops, nested_ops)
320+
c_unrolled = cirq.unroll_circuit_op_greedy_frontier(c, tags_to_check=None)
321+
cirq.testing.assert_same_circuits(c_unrolled, c_expected)
322+
323+
242324
def test_unroll_circuit_op_deep():
243325
q0, q1, q2 = cirq.LineQubit.range(3)
244326
c = cirq.Circuit(

0 commit comments

Comments
 (0)