Skip to content

Commit ca4bb72

Browse files
authored
Add support for deep=True to cirq.eject_phased_paulis transformer (#5116)
- Adds support to recursively run `cirq.eject_phased_paulis` transformer on circuits wrapped inside a circuit operation by setting deep=True in transformer context. - Part of #5039
1 parent d2f284d commit ca4bb72

File tree

2 files changed

+33
-19
lines changed

2 files changed

+33
-19
lines changed

cirq-core/cirq/transformers/eject_phased_paulis.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import cirq
2727

2828

29-
@transformer_api.transformer
29+
@transformer_api.transformer(add_deep_support=True)
3030
def eject_phased_paulis(
3131
circuit: 'cirq.AbstractCircuit',
3232
*,

cirq-core/cirq/transformers/eject_phased_paulis_test.py

+32-18
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from typing import Iterable, cast
1515

16+
import dataclasses
1617
import numpy as np
1718
import pytest
1819
import sympy
@@ -53,29 +54,42 @@ def assert_optimizes(
5354
)
5455

5556
# And match the expected circuit.
56-
assert circuit == expected, (
57-
"Circuit wasn't optimized as expected.\n"
58-
"INPUT:\n"
59-
"{}\n"
60-
"\n"
61-
"EXPECTED OUTPUT:\n"
62-
"{}\n"
63-
"\n"
64-
"ACTUAL OUTPUT:\n"
65-
"{}\n"
66-
"\n"
67-
"EXPECTED OUTPUT (detailed):\n"
68-
"{!r}\n"
69-
"\n"
70-
"ACTUAL OUTPUT (detailed):\n"
71-
"{!r}"
72-
).format(before, expected, circuit, expected, circuit)
57+
cirq.testing.assert_same_circuits(circuit, expected)
7358

7459
# And it should be idempotent.
7560
circuit = cirq.eject_phased_paulis(
7661
circuit, eject_parameterized=eject_parameterized, context=context
7762
)
78-
assert circuit == expected
63+
cirq.testing.assert_same_circuits(circuit, expected)
64+
65+
# Nested sub-circuits should also get optimized.
66+
q = before.all_qubits()
67+
c_nested = cirq.Circuit(
68+
[cirq.PhasedXPowGate(phase_exponent=0.5).on_each(*q), (cirq.Z ** 0.5).on_each(*q)],
69+
cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore"),
70+
[cirq.Y.on_each(*q), cirq.X.on_each(*q)],
71+
cirq.CircuitOperation(before.freeze()).repeat(3).with_tags("preserve_tag"),
72+
)
73+
c_expected = cirq.Circuit(
74+
cirq.PhasedXPowGate(phase_exponent=0.75).on_each(*q),
75+
cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore")),
76+
cirq.Z.on_each(*q),
77+
cirq.Moment(cirq.CircuitOperation(expected.freeze()).repeat(3).with_tags("preserve_tag")),
78+
)
79+
if context is None:
80+
context = cirq.TransformerContext(tags_to_ignore=("ignore",), deep=True)
81+
else:
82+
context = dataclasses.replace(
83+
context, tags_to_ignore=context.tags_to_ignore + ("ignore",), deep=True
84+
)
85+
c_nested = cirq.eject_phased_paulis(
86+
c_nested, context=context, eject_parameterized=eject_parameterized
87+
)
88+
cirq.testing.assert_same_circuits(c_nested, c_expected)
89+
c_nested = cirq.eject_phased_paulis(
90+
c_nested, context=context, eject_parameterized=eject_parameterized
91+
)
92+
cirq.testing.assert_same_circuits(c_nested, c_expected)
7993

8094

8195
def quick_circuit(*moments: Iterable[cirq.OP_TREE]) -> cirq.Circuit:

0 commit comments

Comments
 (0)