Skip to content

Commit 155f607

Browse files
authored
Add support for deep=True to cirq.stratified_circuit transformer (#5117)
- Adds support to recursively run `cirq.stratified_circuit` transformer on circuits wrapped inside a circuit operation by setting deep=True in transformer context. - Part of #5039
1 parent 869d83b commit 155f607

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

cirq-core/cirq/transformers/stratify.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
]
4646

4747

48-
@transformer_api.transformer
48+
@transformer_api.transformer(add_deep_support=True)
4949
def stratified_circuit(
5050
circuit: 'cirq.AbstractCircuit',
5151
*,

cirq-core/cirq/transformers/stratify_test.py

+37
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,43 @@ def test_complex_circuit():
295295
)
296296

297297

298+
def test_complex_circuit_deep():
299+
q = cirq.LineQubit.range(5)
300+
c_nested = cirq.FrozenCircuit(
301+
cirq.Moment(
302+
cirq.X(q[0]).with_tags("ignore"),
303+
cirq.ISWAP(q[1], q[2]).with_tags("ignore"),
304+
cirq.Z(q[4]),
305+
),
306+
cirq.Moment(cirq.Z(q[1]), cirq.ISWAP(q[3], q[4])),
307+
cirq.Moment(cirq.ISWAP(q[0], q[1]), cirq.X(q[3])),
308+
cirq.Moment(cirq.X.on_each(q[0])),
309+
)
310+
c_nested_stratified = cirq.FrozenCircuit(
311+
cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.ISWAP(q[1], q[2]).with_tags("ignore")),
312+
cirq.Moment(cirq.Z.on_each(q[1], q[4])),
313+
cirq.Moment(cirq.ISWAP(*q[:2]), cirq.ISWAP(*q[3:])),
314+
cirq.Moment(cirq.X.on_each(q[0], q[3])),
315+
)
316+
c_orig = cirq.Circuit(
317+
c_nested,
318+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
319+
c_nested,
320+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("preserve_tag"),
321+
c_nested,
322+
)
323+
c_expected = cirq.Circuit(
324+
c_nested_stratified,
325+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
326+
c_nested_stratified,
327+
cirq.CircuitOperation(c_nested_stratified).repeat(6).with_tags("preserve_tag"),
328+
c_nested_stratified,
329+
)
330+
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
331+
c_stratified = cirq.stratified_circuit(c_orig, context=context, categories=[cirq.X, cirq.Z])
332+
cirq.testing.assert_same_circuits(c_stratified, c_expected)
333+
334+
298335
def test_no_categories_earliest_insert():
299336
q1, q2, q3, q4, q5 = cirq.LineQubit.range(5)
300337
input_circuit = cirq.Circuit(

0 commit comments

Comments
 (0)