Skip to content

Commit 6af5387

Browse files
authored
Add support for deep=True flag to remaining transformer primitives (#5106)
1 parent 2693951 commit 6af5387

File tree

2 files changed

+158
-16
lines changed

2 files changed

+158
-16
lines changed

cirq-core/cirq/transformers/transformer_primitives.py

+61-2
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,17 @@ def map_moments(
6060
circuit: CIRCUIT_TYPE,
6161
map_func: Callable[[circuits.Moment, int], Union[circuits.Moment, Sequence[circuits.Moment]]],
6262
*,
63+
tags_to_ignore: Sequence[Hashable] = (),
6364
deep: bool = False,
6465
) -> CIRCUIT_TYPE:
6566
"""Applies local transformation on moments, by calling `map_func(moment)` for each moment.
6667
6768
Args:
6869
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
6970
map_func: Mapping function from (cirq.Moment, moment_index) to a sequence of moments.
71+
tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
72+
ignored when recursively applying the transformer primitive to sub-circuits, given
73+
deep=True.
7074
deep: If true, `map_func` will be recursively applied to circuits wrapped inside
7175
any circuit operations contained within `circuit`.
7276
@@ -79,6 +83,8 @@ def map_moments(
7983
for i, op in circuit.findall_operations(
8084
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
8185
):
86+
if set(op.tags).intersection(tags_to_ignore):
87+
continue
8288
op_untagged = cast(circuits.CircuitOperation, op.untagged)
8389
mapped_op = op_untagged.replace(
8490
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
@@ -190,6 +196,7 @@ def merge_operations(
190196
merge_func: Callable[[ops.Operation, ops.Operation], Optional[ops.Operation]],
191197
*,
192198
tags_to_ignore: Sequence[Hashable] = (),
199+
deep: bool = False,
193200
) -> CIRCUIT_TYPE:
194201
"""Merges operations in a circuit by calling `merge_func` iteratively on operations.
195202
@@ -226,6 +233,8 @@ def merge_operations(
226233
tags_to_ignore: Sequence of tags which should be ignored while applying `merge_func` on
227234
tagged operations -- i.e. `merge_func(op1, op2)` will be called only if both `op1` and
228235
`op2` satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
236+
deep: If true, the transformer primitive will be recursively applied to all circuits
237+
wrapped inside circuit operations.
229238
230239
231240
Returns:
@@ -235,9 +244,11 @@ def merge_operations(
235244
ValueError if the merged operation acts on new qubits outside the set of qubits
236245
corresponding to the original operations to be merged.
237246
"""
247+
_circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit"
248+
tags_to_ignore_set = set(tags_to_ignore) | {_circuit_op_tag}
238249

239250
def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Operation]:
240-
if not all(set(op.tags).isdisjoint(tags_to_ignore) for op in [op1, op2]):
251+
if not all(tags_to_ignore_set.isdisjoint(op.tags) for op in [op1, op2]):
241252
return None
242253
new_op = merge_func(op1, op2)
243254
qubit_set = frozenset(op1.qubits + op2.qubits)
@@ -252,6 +263,23 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
252263
for current_moment in circuit:
253264
new_moment = circuits.Moment()
254265
for op in sorted(current_moment.operations, key=lambda op: op.qubits):
266+
if (
267+
deep
268+
and isinstance(op.untagged, circuits.CircuitOperation)
269+
and tags_to_ignore_set.isdisjoint(op.tags)
270+
):
271+
op_untagged = op.untagged
272+
new_moment = new_moment.with_operation(
273+
op_untagged.replace(
274+
circuit=merge_operations(
275+
op_untagged.circuit,
276+
merge_func,
277+
tags_to_ignore=tags_to_ignore,
278+
deep=True,
279+
)
280+
).with_tags(*op.tags, _circuit_op_tag)
281+
)
282+
continue
255283
op_qs = set(op.qubits)
256284
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
257285
if idx is not None and op_qs.issubset(ret_circuit[idx][op_qs].operations[0].qubits):
@@ -279,6 +307,12 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
279307
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
280308
new_moment = new_moment.with_operation(op)
281309
ret_circuit += new_moment
310+
if deep:
311+
ret_circuit = map_operations(
312+
ret_circuit,
313+
lambda o, _: o.untagged.with_tags(*(set(o.tags) - {_circuit_op_tag})),
314+
deep=True,
315+
)
282316
return _to_target_circuit_type(ret_circuit, circuit)
283317

284318

@@ -288,6 +322,7 @@ def merge_operations_to_circuit_op(
288322
*,
289323
tags_to_ignore: Sequence[Hashable] = (),
290324
merged_circuit_op_tag: str = "Merged connected component",
325+
deep: bool = False,
291326
) -> CIRCUIT_TYPE:
292327
"""Merges connected components of operations and wraps each component into a circuit operation.
293328
@@ -307,6 +342,8 @@ def merge_operations_to_circuit_op(
307342
potential candidates for any connected component.
308343
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
309344
components.
345+
deep: If true, the transformer primitive will be recursively applied to all circuits
346+
wrapped inside circuit operations.
310347
311348
Returns:
312349
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
@@ -329,7 +366,7 @@ def get_ops(op: 'cirq.Operation'):
329366
merged_circuit_op_tag
330367
)
331368

332-
return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore)
369+
return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=deep)
333370

334371

335372
def merge_k_qubit_unitaries_to_circuit_op(
@@ -338,6 +375,7 @@ def merge_k_qubit_unitaries_to_circuit_op(
338375
*,
339376
tags_to_ignore: Sequence[Hashable] = (),
340377
merged_circuit_op_tag: Optional[str] = None,
378+
deep: bool = False,
341379
) -> CIRCUIT_TYPE:
342380
"""Merges connected components of operations, acting on <= k qubits, into circuit operations.
343381
@@ -353,6 +391,8 @@ def merge_k_qubit_unitaries_to_circuit_op(
353391
potential candidates for any connected component.
354392
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
355393
components. A default tag is applied if left None.
394+
deep: If true, the transformer primitive will be recursively applied to all circuits
395+
wrapped inside circuit operations.
356396
357397
Returns:
358398
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
@@ -370,12 +410,16 @@ def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']
370410
can_merge,
371411
tags_to_ignore=tags_to_ignore,
372412
merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.",
413+
deep=deep,
373414
)
374415

375416

376417
def merge_moments(
377418
circuit: CIRCUIT_TYPE,
378419
merge_func: Callable[[circuits.Moment, circuits.Moment], Optional[circuits.Moment]],
420+
*,
421+
tags_to_ignore: Sequence[Hashable] = (),
422+
deep: bool = False,
379423
) -> CIRCUIT_TYPE:
380424
"""Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`.
381425
@@ -384,12 +428,27 @@ def merge_moments(
384428
merge_func: Callable to determine whether two adjacent moments in the circuit should be
385429
merged. If the moments can be merged, the callable should return the merged moment,
386430
else None.
431+
tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
432+
ignored when recursively applying the transformer primitive to sub-circuits, given
433+
deep=True.
434+
deep: If true, the transformer primitive will be recursively applied to all circuits
435+
wrapped inside circuit operations.
387436
388437
Returns:
389438
Copy of input circuit with merged moments.
390439
"""
391440
if not circuit:
392441
return circuit
442+
if deep:
443+
circuit = map_operations(
444+
circuit,
445+
lambda op, _: op.untagged.replace(
446+
circuit=merge_moments(op.untagged.circuit, merge_func, deep=deep)
447+
).with_tags(*op.tags)
448+
if isinstance(op.untagged, circuits.CircuitOperation)
449+
else op,
450+
tags_to_ignore=tags_to_ignore,
451+
)
393452
merged_moments: List[circuits.Moment] = [circuit[0]]
394453
for current_moment in circuit[1:]:
395454
merged_moment = merge_func(merged_moments[-1], current_moment)

cirq-core/cirq/transformers/transformer_primitives_test.py

+97-14
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,35 @@ def test_map_moments_drop_empty_moments():
399399
cirq.testing.assert_same_circuits(c_mapped, cirq.Circuit(c[0], c[0]))
400400

401401

402+
def test_map_moments_drop_empty_moments_deep():
403+
op = cirq.X(cirq.NamedQubit("q"))
404+
c_nested = cirq.FrozenCircuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op))
405+
c_orig = cirq.Circuit(
406+
c_nested,
407+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
408+
c_nested,
409+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
410+
)
411+
c_expected = cirq.Circuit(
412+
[op, op],
413+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
414+
[op, op],
415+
cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(5).with_tags("preserve_tag"),
416+
)
417+
c_mapped = cirq.map_moments(
418+
c_orig, lambda m, i: [] if len(m) == 0 else [m], deep=True, tags_to_ignore=("ignore",)
419+
)
420+
cirq.testing.assert_same_circuits(c_mapped, c_expected)
421+
422+
423+
def _merge_z_moments_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
424+
if any(op.gate != cirq.Z for m in [m1, m2] for op in m):
425+
return None
426+
return cirq.Moment(
427+
cirq.Z(q) for q in (m1.qubits | m2.qubits) if m1.operates_on([q]) ^ m2.operates_on([q])
428+
)
429+
430+
402431
def test_merge_moments():
403432
q = cirq.LineQubit.range(3)
404433
c_orig = cirq.Circuit(
@@ -419,21 +448,8 @@ def test_merge_moments():
419448
''',
420449
)
421450

422-
def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
423-
def is_z_moment(m):
424-
return all(op.gate == cirq.Z for op in m)
425-
426-
if not (is_z_moment(m1) and is_z_moment(m2)):
427-
return None
428-
qubits = m1.qubits | m2.qubits
429-
430-
def mul(op1, op2):
431-
return (op1 or op2) if not (op1 and op2) else cirq.decompose_once(op1 * op2)
432-
433-
return cirq.Moment(mul(m1.operation_at(q), m2.operation_at(q)) for q in qubits)
434-
435451
cirq.testing.assert_has_diagram(
436-
cirq.merge_moments(c_orig, merge_func),
452+
cirq.merge_moments(c_orig, _merge_z_moments_func),
437453
'''
438454
0: ───────@───────
439455
@@ -444,6 +460,35 @@ def mul(op1, op2):
444460
)
445461

446462

463+
def test_merge_moments_deep():
464+
q = cirq.LineQubit.range(3)
465+
c_z_moments = cirq.Circuit(
466+
[cirq.Z.on_each(q[0], q[1]), cirq.Z.on_each(q[1], q[2]), cirq.Z.on_each(q[1], q[0])],
467+
strategy=cirq.InsertStrategy.NEW_THEN_INLINE,
468+
)
469+
merged_z_moment = cirq.Moment(cirq.Z.on_each(*q[1:]))
470+
c_nested_circuit = cirq.FrozenCircuit(c_z_moments, cirq.CCX(*q), c_z_moments)
471+
c_merged_circuit = cirq.FrozenCircuit(merged_z_moment, cirq.CCX(*q), merged_z_moment)
472+
c_orig = cirq.Circuit(
473+
cirq.CircuitOperation(c_nested_circuit).repeat(5).with_tags("ignore"),
474+
c_nested_circuit,
475+
cirq.CircuitOperation(c_nested_circuit).repeat(6).with_tags("preserve_tag"),
476+
c_nested_circuit,
477+
cirq.CircuitOperation(c_nested_circuit).repeat(7),
478+
)
479+
c_expected = cirq.Circuit(
480+
cirq.CircuitOperation(c_nested_circuit).repeat(5).with_tags("ignore"),
481+
c_merged_circuit,
482+
cirq.CircuitOperation(c_merged_circuit).repeat(6).with_tags("preserve_tag"),
483+
c_merged_circuit,
484+
cirq.CircuitOperation(c_merged_circuit).repeat(7),
485+
)
486+
cirq.testing.assert_same_circuits(
487+
cirq.merge_moments(c_orig, _merge_z_moments_func, tags_to_ignore=("ignore",), deep=True),
488+
c_expected,
489+
)
490+
491+
447492
def test_merge_moments_empty_moment_as_intermediate_step():
448493
q = cirq.NamedQubit("q")
449494
c_orig = cirq.Circuit([cirq.X(q), cirq.Y(q), cirq.Z(q)] * 2, cirq.X(q) ** 0.5)
@@ -543,7 +588,45 @@ def merge_func(op1, op2):
543588
)
544589

545590

591+
def test_merge_operations_deep():
592+
q = cirq.LineQubit.range(2)
593+
h_cz_y = [cirq.H(q[0]), cirq.CZ(*q), cirq.Y(q[1])]
594+
m_cz_m = [cirq.Moment(), cirq.Moment(cirq.CZ(*q)), cirq.Moment()]
595+
c_orig = cirq.Circuit(
596+
h_cz_y,
597+
cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])),
598+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
599+
[cirq.CNOT(*q), cirq.CNOT(*q)],
600+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(4),
601+
[cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)],
602+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(5).with_tags("preserve_tag"),
603+
)
604+
c_expected = cirq.Circuit(
605+
m_cz_m,
606+
cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
607+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
608+
[cirq.CNOT(*q), cirq.CNOT(*q)],
609+
cirq.CircuitOperation(cirq.FrozenCircuit(m_cz_m)).repeat(4),
610+
[cirq.CZ(*q), cirq.Moment(), cirq.Moment()],
611+
cirq.CircuitOperation(cirq.FrozenCircuit(m_cz_m)).repeat(5).with_tags("preserve_tag"),
612+
strategy=cirq.InsertStrategy.NEW,
613+
)
614+
615+
def merge_func(op1, op2):
616+
"""Artificial example where a CZ will absorb any merge-able operation."""
617+
for op in [op1, op2]:
618+
if op.gate == cirq.CZ:
619+
return op
620+
return None
621+
622+
cirq.testing.assert_same_circuits(
623+
cirq.merge_operations(c_orig, merge_func, tags_to_ignore=["ignore"], deep=True), c_expected
624+
)
625+
626+
546627
# pylint: disable=line-too-long
628+
629+
547630
def test_merge_operations_to_circuit_op_merges_connected_component():
548631
c_orig = _create_circuit_to_merge()
549632
cirq.testing.assert_has_diagram(

0 commit comments

Comments
 (0)