Skip to content

Commit 0a4b4aa

Browse files
tanujkhattarrht
authored andcommitted
Add support for deep=True to cirq.align_left and cirq.align_right transformers (quantumlib#5112)
- Adds support to recursively run `cirq.align_left` and `cirq.align_right` transformers on circuits wrapped inside a circuit operation by setting `deep=True` in transformer context. - Part of quantumlib#5039
1 parent ac47c1c commit 0a4b4aa

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

Diff for: cirq-core/cirq/transformers/align.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Transformer passes which align operations to the left or right of the circuit."""
1616

17+
import dataclasses
1718
from typing import Optional, TYPE_CHECKING
1819
from cirq import circuits, ops
1920
from cirq.transformers import transformer_api
@@ -22,7 +23,7 @@
2223
import cirq
2324

2425

25-
@transformer_api.transformer
26+
@transformer_api.transformer(add_deep_support=True)
2627
def align_left(
2728
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
2829
) -> 'cirq.Circuit':
@@ -54,7 +55,7 @@ def align_left(
5455
return ret
5556

5657

57-
@transformer_api.transformer
58+
@transformer_api.transformer(add_deep_support=True)
5859
def align_right(
5960
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
6061
) -> 'cirq.Circuit':
@@ -70,4 +71,6 @@ def align_right(
7071
Returns:
7172
Copy of the transformed input circuit.
7273
"""
74+
if context is not None and context.deep is True:
75+
context = dataclasses.replace(context, deep=False)
7376
return align_left(circuit[::-1], context=context)[::-1]

Diff for: cirq-core/cirq/transformers/align_test.py

+68
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,41 @@ def test_align_left_no_compile_context():
7171
)
7272

7373

74+
def test_align_left_deep():
75+
q1, q2 = cirq.LineQubit.range(2)
76+
c_nested = cirq.FrozenCircuit(
77+
[
78+
cirq.Moment([cirq.X(q1)]),
79+
cirq.Moment([cirq.Y(q2)]),
80+
cirq.Moment([cirq.Z(q1), cirq.Y(q2).with_tags("nocompile")]),
81+
cirq.Moment([cirq.Y(q1)]),
82+
cirq.measure(q2, key='a'),
83+
cirq.Z(q1).with_classical_controls('a'),
84+
]
85+
)
86+
c_nested_aligned = cirq.FrozenCircuit(
87+
cirq.Moment(cirq.X(q1), cirq.Y(q2)),
88+
cirq.Moment(cirq.Z(q1)),
89+
cirq.Moment([cirq.Y(q1), cirq.Y(q2).with_tags("nocompile")]),
90+
cirq.measure(q2, key='a'),
91+
cirq.Z(q1).with_classical_controls('a'),
92+
)
93+
c_orig = cirq.Circuit(
94+
c_nested,
95+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
96+
c_nested,
97+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
98+
)
99+
c_expected = cirq.Circuit(
100+
c_nested_aligned,
101+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
102+
c_nested_aligned,
103+
cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"),
104+
)
105+
context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True)
106+
cirq.testing.assert_same_circuits(cirq.align_left(c_orig, context=context), c_expected)
107+
108+
74109
def test_align_left_subset_of_operations():
75110
q1 = cirq.NamedQubit('q1')
76111
q2 = cirq.NamedQubit('q2')
@@ -133,6 +168,39 @@ def test_align_right_no_compile_context():
133168
)
134169

135170

171+
def test_align_right_deep():
172+
q1, q2 = cirq.LineQubit.range(2)
173+
c_nested = cirq.FrozenCircuit(
174+
cirq.Moment([cirq.X(q1)]),
175+
cirq.Moment([cirq.Y(q1), cirq.X(q2).with_tags("nocompile")]),
176+
cirq.Moment([cirq.X(q2)]),
177+
cirq.Moment([cirq.Y(q1)]),
178+
cirq.measure(q1, key='a'),
179+
cirq.Z(q2).with_classical_controls('a'),
180+
)
181+
c_nested_aligned = cirq.FrozenCircuit(
182+
cirq.Moment([cirq.X(q1), cirq.X(q2).with_tags("nocompile")]),
183+
[cirq.Y(q1), cirq.Y(q1)],
184+
cirq.Moment(cirq.measure(q1, key='a'), cirq.X(q2)),
185+
cirq.Z(q2).with_classical_controls('a'),
186+
)
187+
c_orig = cirq.Circuit(
188+
c_nested,
189+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
190+
c_nested,
191+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
192+
)
193+
c_expected = cirq.Circuit(
194+
c_nested_aligned,
195+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
196+
cirq.Moment(),
197+
c_nested_aligned,
198+
cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"),
199+
)
200+
context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True)
201+
cirq.testing.assert_same_circuits(cirq.align_right(c_orig, context=context), c_expected)
202+
203+
136204
def test_classical_control():
137205
q0, q1 = cirq.LineQubit.range(2)
138206
circuit = cirq.Circuit(

0 commit comments

Comments
 (0)