Skip to content

Commit 89e7210

Browse files
authored
Add support for deep=True to cirq.optimize_for_target_gateset transformer (#5124)
- Adds support for `deep=True` flag to `cirq.optimize_for_target_gateset` which enables optimizing circuits preserving the sub-circuit structure (i.e. without unrolling circuit operations). - Part of #5039
1 parent 518d828 commit 89e7210

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

cirq-core/cirq/transformers/optimize_for_target_gateset.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
"""Transformers to rewrite a circuit using gates from a given target gateset."""
1616

17-
from typing import Optional, Callable, TYPE_CHECKING
17+
from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING
1818

19+
from cirq import circuits
1920
from cirq.protocols import decompose_protocol as dp
2021
from cirq.transformers import transformer_api, transformer_primitives
2122

@@ -38,6 +39,7 @@ def _decompose_operations_to_target_gateset(
3839
gateset: Optional['cirq.Gateset'] = None,
3940
decomposer: Callable[['cirq.Operation', int], dp.DecomposeResult] = lambda *_: NotImplemented,
4041
ignore_failures: bool = True,
42+
tags_to_decompose: Sequence[Hashable] = (),
4143
) -> 'cirq.Circuit':
4244
"""Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`.
4345
@@ -56,6 +58,8 @@ def _decompose_operations_to_target_gateset(
5658
- `None` or `NotImplemented` if does not know how to decompose a given `op`.
5759
ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
5860
conversion failures raise a ValueError.
61+
tags_to_decompose: `cirq.CircuitOperation`s tagged with any of `tags_to_decompose` will
62+
be decomposed even if context.deep is True.
5963
6064
Returns:
6165
An equivalent circuit containing gates accepted by `gateset`.
@@ -65,6 +69,13 @@ def _decompose_operations_to_target_gateset(
6569
"""
6670

6771
def map_func(op: 'cirq.Operation', moment_index: int):
72+
if (
73+
context
74+
and context.deep
75+
and isinstance(op.untagged, circuits.CircuitOperation)
76+
and set(op.tags).isdisjoint(tags_to_decompose)
77+
):
78+
return op
6879
return dp.decompose(
6980
op,
7081
intercepting_decomposer=lambda o: decomposer(o, moment_index),
@@ -77,7 +88,10 @@ def map_func(op: 'cirq.Operation', moment_index: int):
7788
)
7889

7990
return transformer_primitives.map_operations_and_unroll(
80-
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
91+
circuit,
92+
map_func,
93+
tags_to_ignore=context.tags_to_ignore if context else (),
94+
deep=context.deep if context else False,
8195
).unfreeze(copy=False)
8296

8397

@@ -122,6 +136,7 @@ def optimize_for_target_gateset(
122136
gateset=gateset,
123137
decomposer=gateset.decompose_to_target_gateset,
124138
ignore_failures=ignore_failures,
139+
tags_to_decompose=(gateset._intermediate_result_tag,),
125140
)
126141

127142
for transformer in gateset.postprocess_transformers:

cirq-core/cirq/transformers/optimize_for_target_gateset_test.py

+50
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,53 @@ def test_optimize_for_target_gateset():
196196
_ = cirq.optimize_for_target_gateset(
197197
c_orig, gateset=gateset, context=context, ignore_failures=False
198198
)
199+
200+
201+
def test_optimize_for_target_gateset_deep():
202+
q0, q1 = cirq.LineQubit.range(2)
203+
c_nested = cirq.FrozenCircuit(cirq.CX(q0, q1))
204+
c_orig = cirq.Circuit(
205+
cirq.CircuitOperation(
206+
cirq.FrozenCircuit(cirq.H(q0), cirq.CircuitOperation(c_nested).repeat(3))
207+
).repeat(5)
208+
)
209+
c_expected = cirq.Circuit(
210+
cirq.CircuitOperation(
211+
cirq.FrozenCircuit(
212+
cirq.single_qubit_matrix_to_phxz(cirq.unitary(cirq.H(q0))).on(q0),
213+
cirq.CircuitOperation(
214+
cirq.FrozenCircuit(
215+
cirq.MatrixGate(c_nested.unitary(qubit_order=[q0, q1]), name="M").on(q0, q1)
216+
)
217+
).repeat(3),
218+
)
219+
).repeat(5)
220+
)
221+
gateset = MatrixGateTargetGateset()
222+
context = cirq.TransformerContext(deep=True)
223+
c_new = cirq.optimize_for_target_gateset(c_orig, gateset=gateset, context=context)
224+
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_new, c_expected)
225+
cirq.testing.assert_has_diagram(
226+
c_orig,
227+
'''
228+
[ [ 0: ───@─── ] ]
229+
[ 0: ───H───[ │ ]──────────── ]
230+
0: ───[ [ 1: ───X─── ](loops=3) ]────────────
231+
[ │ ]
232+
[ 1: ───────#2──────────────────────── ](loops=5)
233+
234+
1: ───#2──────────────────────────────────────────────────
235+
''',
236+
)
237+
cirq.testing.assert_has_diagram(
238+
c_new,
239+
'''
240+
[ [ 0: ───M[1]─── ] ]
241+
[ 0: ───PhXZ(a=-0.5,x=0.5,z=-1)───[ │ ]──────────── ]
242+
0: ───[ [ 1: ───M[2]─── ](loops=3) ]────────────
243+
[ │ ]
244+
[ 1: ─────────────────────────────#2─────────────────────────── ](loops=5)
245+
246+
1: ───#2───────────────────────────────────────────────────────────────────────────
247+
''',
248+
)

0 commit comments

Comments
 (0)