Skip to content

Commit d2f284d

Browse files
authored
Add support for deep=True to cirq.expand_composite transformer (#5119)
- Adds support to recursively run `cirq.expand_composite` transformer on circuits wrapped inside a circuit operation by setting deep=True in transformer context. - Note that this does not rely on `preserve_structure` argument of `protocols.decompose` because the latter does not support handling nested circuit operations tagged with a no-compile tag (the added tests would fail if we rely on protocols.decompose(preserve_structure=True) instead transformer primitives). Hence, I would argue that we should deprecate the preserve_structure=True flag in protocols.decompose in-favour of this transformer. cc @95-martin-orion - Part of #5039
1 parent b45e63d commit d2f284d

File tree

2 files changed

+84
-3
lines changed

2 files changed

+84
-3
lines changed

cirq-core/cirq/transformers/expand_composite.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from typing import Callable, Optional, TYPE_CHECKING
1818

19-
from cirq import ops, protocols
19+
from cirq import circuits, ops, protocols
2020
from cirq.transformers import transformer_api, transformer_primitives
2121

2222
if TYPE_CHECKING:
@@ -49,8 +49,17 @@ def expand_composite(
4949
"""
5050

5151
def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
52-
return protocols.decompose(op, keep=no_decomp, on_stuck_raise=None)
52+
if context and context.deep and isinstance(op.untagged, circuits.CircuitOperation):
53+
return op
54+
return protocols.decompose(
55+
op,
56+
keep=no_decomp,
57+
on_stuck_raise=None,
58+
)
5359

5460
return transformer_primitives.map_operations_and_unroll(
55-
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
61+
circuit,
62+
map_func,
63+
tags_to_ignore=context.tags_to_ignore if context else (),
64+
deep=context.deep if context else False,
5665
).unfreeze(copy=False)

cirq-core/cirq/transformers/expand_composite_test.py

+72
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,75 @@ def test_do_not_decompose_no_compile():
187187
c = cirq.Circuit(cirq.CNOT(q0, q1).with_tags("no_compile"))
188188
context = cirq.TransformerContext(tags_to_ignore=("no_compile",))
189189
assert_equal_mod_empty(c, cirq.expand_composite(c, context=context))
190+
191+
192+
def test_expands_composite_recursively_preserving_structur():
193+
q = cirq.LineQubit.range(2)
194+
c_nested = cirq.FrozenCircuit(
195+
cirq.SWAP(*q[:2]), cirq.SWAP(*q[:2]).with_tags("ignore"), cirq.SWAP(*q[:2])
196+
)
197+
c_nested_expanded = cirq.FrozenCircuit(
198+
[cirq.CNOT(*q), cirq.CNOT(*q[::-1]), cirq.CNOT(*q)],
199+
cirq.SWAP(*q[:2]).with_tags("ignore"),
200+
[cirq.CNOT(*q), cirq.CNOT(*q[::-1]), cirq.CNOT(*q)],
201+
)
202+
c_orig = cirq.Circuit(
203+
c_nested,
204+
cirq.CircuitOperation(
205+
cirq.FrozenCircuit(
206+
c_nested,
207+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
208+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("preserve_tag"),
209+
cirq.CircuitOperation(c_nested).repeat(7),
210+
c_nested,
211+
)
212+
)
213+
.repeat(4)
214+
.with_tags("ignore"),
215+
c_nested,
216+
cirq.CircuitOperation(
217+
cirq.FrozenCircuit(
218+
c_nested,
219+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
220+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("preserve_tag"),
221+
cirq.CircuitOperation(c_nested).repeat(7),
222+
c_nested,
223+
)
224+
)
225+
.repeat(5)
226+
.with_tags("preserve_tag"),
227+
c_nested,
228+
)
229+
c_expected = cirq.Circuit(
230+
c_nested_expanded,
231+
cirq.CircuitOperation(
232+
cirq.FrozenCircuit(
233+
c_nested,
234+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
235+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("preserve_tag"),
236+
cirq.CircuitOperation(c_nested).repeat(7),
237+
c_nested,
238+
)
239+
)
240+
.repeat(4)
241+
.with_tags("ignore"),
242+
c_nested_expanded,
243+
cirq.CircuitOperation(
244+
cirq.FrozenCircuit(
245+
c_nested_expanded,
246+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
247+
cirq.CircuitOperation(c_nested_expanded).repeat(6).with_tags("preserve_tag"),
248+
cirq.CircuitOperation(c_nested_expanded).repeat(7),
249+
c_nested_expanded,
250+
)
251+
)
252+
.repeat(5)
253+
.with_tags("preserve_tag"),
254+
c_nested_expanded,
255+
)
256+
257+
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
258+
c_expanded = cirq.expand_composite(
259+
c_orig, no_decomp=lambda op: op.gate == cirq.CNOT, context=context
260+
)
261+
cirq.testing.assert_same_circuits(c_expanded, c_expected)

0 commit comments

Comments
 (0)