Skip to content

Commit c20c99b

Browse files
authored
Add cirq.eject_phased_paulis transformer to replace cirq.EjectPhasedPaulis (#4958)
* Add cirq.eject_phased_paulis transformer to replace cirq.EjectPhasedPaulis * Add CCO tests, support PhasedXZGates * Support PhasedXZGates equivalent to z rotations and update docstrings
1 parent 2d2226a commit c20c99b

File tree

12 files changed

+976
-321
lines changed

12 files changed

+976
-321
lines changed

cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@
364364
decompose_two_qubit_interaction_into_four_fsim_gates,
365365
drop_empty_moments,
366366
drop_negligible_operations,
367+
eject_phased_paulis,
367368
eject_z,
368369
expand_composite,
369370
is_negligible_turn,

cirq-core/cirq/ion/ion_decomposition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def two_qubit_matrix_to_ion_operations(
5353
def _cleanup_operations(operations: List[ops.Operation]):
5454
circuit = circuits.Circuit(operations)
5555
optimizers.merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit)
56-
optimizers.eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit)
56+
circuit = transformers.eject_phased_paulis(circuit)
5757
circuit = transformers.eject_z(circuit)
5858
circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST)
5959
return list(circuit.all_operations())

cirq-core/cirq/optimizers/eject_phased_paulis.py

+7-299
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,10 @@
1515
"""Pushes 180 degree rotations around axes in the XY plane later in the circuit.
1616
"""
1717

18-
from typing import Optional, cast, TYPE_CHECKING, Iterable, Tuple, Dict, List
19-
import sympy
20-
21-
from cirq import circuits, ops, value, protocols
22-
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
23-
24-
if TYPE_CHECKING:
25-
import cirq
26-
27-
28-
class _OptimizerState:
29-
def __init__(self):
30-
# The phases of the W gates currently being pushed along each qubit.
31-
self.held_w_phases: Dict[ops.Qid, value.TParamVal] = {}
32-
33-
# Accumulated commands to batch-apply to the circuit later.
34-
self.deletions: List[Tuple[int, ops.Operation]] = []
35-
self.inline_intos: List[Tuple[int, ops.Operation]] = []
36-
self.insertions: List[Tuple[int, ops.Operation]] = []
18+
from cirq import _compat, circuits, transformers
3719

3820

21+
@_compat.deprecated_class(deadline='v1.0', fix='Use cirq.eject_phased_paulis instead.')
3922
class EjectPhasedPaulis:
4023
"""Pushes X, Y, and PhasedX gates towards the end of the circuit.
4124
@@ -60,283 +43,8 @@ def __init__(self, tolerance: float = 1e-8, eject_parameterized: bool = False) -
6043
self.eject_parameterized = eject_parameterized
6144

6245
def optimize_circuit(self, circuit: circuits.Circuit):
63-
state = _OptimizerState()
64-
65-
for moment_index, moment in enumerate(circuit):
66-
for op in moment.operations:
67-
affected = [q for q in op.qubits if q in state.held_w_phases]
68-
69-
# Collect, phase, and merge Ws.
70-
w = _try_get_known_phased_pauli(op, no_symbolic=not self.eject_parameterized)
71-
if w is not None:
72-
if single_qubit_decompositions.is_negligible_turn(
73-
(w[0] - 1) / 2, self.tolerance
74-
):
75-
_potential_cross_whole_w(moment_index, op, self.tolerance, state)
76-
else:
77-
_potential_cross_partial_w(moment_index, op, state)
78-
continue
79-
80-
if not affected:
81-
continue
82-
83-
# Absorb Z rotations.
84-
t = _try_get_known_z_half_turns(op, no_symbolic=not self.eject_parameterized)
85-
if t is not None:
86-
_absorb_z_into_w(moment_index, op, state)
87-
continue
88-
89-
# Dump coherent flips into measurement bit flips.
90-
if isinstance(op.gate, ops.MeasurementGate):
91-
_dump_into_measurement(moment_index, op, state)
92-
93-
# Cross CZs using kickback.
94-
if (
95-
_try_get_known_cz_half_turns(op, no_symbolic=not self.eject_parameterized)
96-
is not None
97-
):
98-
if len(affected) == 1:
99-
_single_cross_over_cz(moment_index, op, affected[0], state)
100-
else:
101-
_double_cross_over_cz(op, state)
102-
continue
103-
104-
# Don't know how to handle this situation. Dump the gates.
105-
_dump_held(op.qubits, moment_index, state)
106-
107-
# Put anything that's still held at the end of the circuit.
108-
_dump_held(state.held_w_phases.keys(), len(circuit), state)
109-
110-
circuit.batch_remove(state.deletions)
111-
circuit.batch_insert_into(state.inline_intos)
112-
circuit.batch_insert(state.insertions)
113-
114-
115-
def _absorb_z_into_w(moment_index: int, op: ops.Operation, state: _OptimizerState) -> None:
116-
"""Absorbs a Z^t gate into a W(a) flip.
117-
118-
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
119-
120-
Uses the following identity:
121-
───W(a)───Z^t───
122-
≡ ───W(a)───────────Z^t/2──────────Z^t/2─── (split Z)
123-
≡ ───W(a)───W(a)───Z^-t/2───W(a)───Z^t/2─── (flip Z)
124-
≡ ───W(a)───W(a)──────────W(a+t/2)───────── (phase W)
125-
≡ ────────────────────────W(a+t/2)───────── (cancel Ws)
126-
≡ ───W(a+t/2)───
127-
"""
128-
t = cast(value.TParamVal, _try_get_known_z_half_turns(op))
129-
q = op.qubits[0]
130-
state.held_w_phases[q] += t / 2
131-
state.deletions.append((moment_index, op))
132-
133-
134-
def _dump_held(qubits: Iterable[ops.Qid], moment_index: int, state: _OptimizerState):
135-
# Note: sorting is to avoid non-determinism in the insertion order.
136-
for q in sorted(qubits):
137-
p = state.held_w_phases.get(q)
138-
if p is not None:
139-
dump_op = ops.PhasedXPowGate(phase_exponent=p).on(q)
140-
state.insertions.append((moment_index, dump_op))
141-
state.held_w_phases.pop(q, None)
142-
143-
144-
def _dump_into_measurement(moment_index: int, op: ops.Operation, state: _OptimizerState) -> None:
145-
measurement = cast(ops.MeasurementGate, cast(ops.GateOperation, op).gate)
146-
new_measurement = measurement.with_bits_flipped(
147-
*[i for i, q in enumerate(op.qubits) if q in state.held_w_phases]
148-
).on(*op.qubits)
149-
for q in op.qubits:
150-
state.held_w_phases.pop(q, None)
151-
state.deletions.append((moment_index, op))
152-
state.inline_intos.append((moment_index, new_measurement))
153-
154-
155-
def _potential_cross_whole_w(
156-
moment_index: int, op: ops.Operation, tolerance: float, state: _OptimizerState
157-
) -> None:
158-
"""Grabs or cancels a held W gate against an existing W gate.
159-
160-
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
161-
162-
Uses the following identity:
163-
───W(a)───W(b)───
164-
≡ ───Z^-a───X───Z^a───Z^-b───X───Z^b───
165-
≡ ───Z^-a───Z^-a───Z^b───X───X───Z^b───
166-
≡ ───Z^-a───Z^-a───Z^b───Z^b───
167-
≡ ───Z^2(b-a)───
168-
"""
169-
state.deletions.append((moment_index, op))
170-
171-
_, phase_exponent = cast(
172-
Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op)
173-
)
174-
q = op.qubits[0]
175-
a = state.held_w_phases.get(q, None)
176-
b = phase_exponent
177-
178-
if a is None:
179-
# Collect the gate.
180-
state.held_w_phases[q] = b
181-
else:
182-
# Cancel the gate.
183-
del state.held_w_phases[q]
184-
t = 2 * (b - a)
185-
if not single_qubit_decompositions.is_negligible_turn(t / 2, tolerance):
186-
leftover_phase = ops.Z(q) ** t
187-
state.inline_intos.append((moment_index, leftover_phase))
188-
189-
190-
def _potential_cross_partial_w(
191-
moment_index: int, op: ops.Operation, state: _OptimizerState
192-
) -> None:
193-
"""Cross the held W over a partial W gate.
194-
195-
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
196-
197-
Uses the following identity:
198-
───W(a)───W(b)^t───
199-
≡ ───Z^-a───X───Z^a───W(b)^t────── (expand W(a))
200-
≡ ───Z^-a───X───W(b-a)^t───Z^a──── (move Z^a across, phasing axis)
201-
≡ ───Z^-a───W(a-b)^t───X───Z^a──── (move X across, negating axis angle)
202-
≡ ───W(2a-b)^t───Z^-a───X───Z^a─── (move Z^-a across, phasing axis)
203-
≡ ───W(2a-b)^t───W(a)───
204-
"""
205-
a = state.held_w_phases.get(op.qubits[0], None)
206-
if a is None:
207-
return
208-
exponent, phase_exponent = cast(
209-
Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op)
210-
)
211-
new_op = ops.PhasedXPowGate(exponent=exponent, phase_exponent=2 * a - phase_exponent).on(
212-
op.qubits[0]
213-
)
214-
state.deletions.append((moment_index, op))
215-
state.inline_intos.append((moment_index, new_op))
216-
217-
218-
def _single_cross_over_cz(
219-
moment_index: int, op: ops.Operation, qubit_with_w: 'cirq.Qid', state: _OptimizerState
220-
) -> None:
221-
"""Crosses exactly one W flip over a partial CZ.
222-
223-
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
224-
225-
Uses the following identity:
226-
227-
──────────@─────
228-
229-
───W(a)───@^t───
230-
231-
232-
≡ ───@──────O──────@────────────────────
233-
| | │ (split into on/off cases)
234-
───W(a)───W(a)───@^t──────────────────
235-
236-
≡ ───@─────────────@─────────────O──────
237-
| │ | (off doesn't interact with on)
238-
───W(a)──────────@^t───────────W(a)───
239-
240-
≡ ───────────Z^t───@──────@──────O──────
241-
│ | | (crossing causes kickback)
242-
─────────────────@^-t───W(a)───W(a)─── (X Z^t X Z^-t = exp(pi t) I)
243-
244-
≡ ───────────Z^t───@────────────────────
245-
│ (merge on/off cases)
246-
─────────────────@^-t───W(a)──────────
247-
248-
≡ ───Z^t───@──────────────
249-
250-
─────────@^-t───W(a)────
251-
"""
252-
t = cast(value.TParamVal, _try_get_known_cz_half_turns(op))
253-
other_qubit = op.qubits[0] if qubit_with_w == op.qubits[1] else op.qubits[1]
254-
negated_cz = ops.CZ(*op.qubits) ** -t
255-
kickback = ops.Z(other_qubit) ** t
256-
257-
state.deletions.append((moment_index, op))
258-
state.inline_intos.append((moment_index, negated_cz))
259-
state.insertions.append((moment_index, kickback))
260-
261-
262-
def _double_cross_over_cz(op: ops.Operation, state: _OptimizerState) -> None:
263-
"""Crosses two W flips over a partial CZ.
264-
265-
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
266-
267-
Uses the following identity:
268-
269-
───W(a)───@─────
270-
271-
───W(b)───@^t───
272-
273-
274-
≡ ──────────@────────────W(a)───
275-
│ (single-cross top W over CZ)
276-
───W(b)───@^-t─────────Z^t────
277-
278-
279-
≡ ──────────@─────Z^-t───W(a)───
280-
│ (single-cross bottom W over CZ)
281-
──────────@^t───W(b)───Z^t────
282-
283-
284-
≡ ──────────@─────W(a)───Z^t────
285-
│ (flip over Z^-t)
286-
──────────@^t───W(b)───Z^t────
287-
288-
289-
≡ ──────────@─────W(a+t/2)──────
290-
│ (absorb Zs into Ws)
291-
──────────@^t───W(b+t/2)──────
292-
293-
≡ ───@─────W(a+t/2)───
294-
295-
───@^t───W(b+t/2)───
296-
"""
297-
t = cast(value.TParamVal, _try_get_known_cz_half_turns(op))
298-
for q in op.qubits:
299-
state.held_w_phases[q] = cast(value.TParamVal, state.held_w_phases[q]) + t / 2
300-
301-
302-
def _try_get_known_cz_half_turns(
303-
op: ops.Operation, no_symbolic: bool = False
304-
) -> Optional[value.TParamVal]:
305-
if not isinstance(op, ops.GateOperation) or not isinstance(op.gate, ops.CZPowGate):
306-
return None
307-
h = op.gate.exponent
308-
if no_symbolic and isinstance(h, sympy.Basic):
309-
return None
310-
return h
311-
312-
313-
def _try_get_known_phased_pauli(
314-
op: ops.Operation, no_symbolic: bool = False
315-
) -> Optional[Tuple[value.TParamVal, value.TParamVal]]:
316-
if (no_symbolic and protocols.is_parameterized(op)) or not isinstance(op, ops.GateOperation):
317-
return None
318-
gate = op.gate
319-
320-
if isinstance(gate, ops.PhasedXPowGate):
321-
e = gate.exponent
322-
p = gate.phase_exponent
323-
elif isinstance(gate, ops.YPowGate):
324-
e = gate.exponent
325-
p = 0.5
326-
elif isinstance(gate, ops.XPowGate):
327-
e = gate.exponent
328-
p = 0.0
329-
else:
330-
return None
331-
return value.canonicalize_half_turns(e), value.canonicalize_half_turns(p)
332-
333-
334-
def _try_get_known_z_half_turns(
335-
op: ops.Operation, no_symbolic: bool = False
336-
) -> Optional[value.TParamVal]:
337-
if not isinstance(op, ops.GateOperation) or not isinstance(op.gate, ops.ZPowGate):
338-
return None
339-
h = op.gate.exponent
340-
if no_symbolic and isinstance(h, sympy.Basic):
341-
return None
342-
return h
46+
circuit._moments = [
47+
*transformers.eject_phased_paulis(
48+
circuit, atol=self.tolerance, eject_parameterized=self.eject_parameterized
49+
)
50+
]

cirq-core/cirq/optimizers/eject_phased_paulis_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ def assert_optimizes(
2626
compare_unitaries: bool = True,
2727
eject_parameterized: bool = False,
2828
):
29-
opt = cirq.EjectPhasedPaulis(eject_parameterized=eject_parameterized)
29+
with cirq.testing.assert_deprecated("Use cirq.eject_phased_paulis", deadline='v1.0'):
30+
opt = cirq.EjectPhasedPaulis(eject_parameterized=eject_parameterized)
3031

3132
circuit = before.copy()
33+
expected = cirq.drop_empty_moments(expected)
3234
opt.optimize_circuit(circuit)
3335

3436
# They should have equivalent effects.

cirq-core/cirq/optimizers/merge_interactions_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
2828
# Ignore differences that would be caught by follow-up optimizations.
2929
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
3030
cirq.merge_single_qubit_gates_into_phased_x_z,
31-
cirq.EjectPhasedPaulis().optimize_circuit,
3231
]
3332
for post in followup_optimizations:
3433
post(actual)
3534
post(expected)
3635

3736
followup_transformers: List[cirq.TRANSFORMER] = [
37+
cirq.eject_phased_paulis,
3838
cirq.eject_z,
3939
cirq.drop_negligible_operations,
4040
cirq.drop_empty_moments,

cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
3939
# Ignore differences that would be caught by follow-up optimizations.
4040
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
4141
cirq.merge_single_qubit_gates_into_phased_x_z,
42-
cirq.EjectPhasedPaulis().optimize_circuit,
4342
]
4443
for post in followup_optimizations:
4544
post(actual)
4645
post(expected)
4746

4847
followup_transformers: List[cirq.TRANSFORMER] = [
48+
cirq.eject_phased_paulis,
4949
cirq.eject_z,
5050
cirq.drop_negligible_operations,
5151
cirq.drop_empty_moments,

cirq-core/cirq/transformers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545

4646
from cirq.transformers.expand_composite import expand_composite
4747

48+
from cirq.transformers.eject_phased_paulis import eject_phased_paulis
49+
4850
from cirq.transformers.drop_empty_moments import drop_empty_moments
4951

5052
from cirq.transformers.drop_negligible_operations import drop_negligible_operations

0 commit comments

Comments
 (0)