Skip to content

Commit 9d46602

Browse files
tanujkhattarrht
authored andcommitted
Add cirq.eject_z transformer to replace cirq.EjectZ (quantumlib#4955)
* Add eject_z transformer to replace EjectZ * Replaces usages of EjectZ with eject_z * Reorder cleanup transformers
1 parent cc91331 commit 9d46602

File tree

13 files changed

+636
-157
lines changed

13 files changed

+636
-157
lines changed

cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@
362362
decompose_two_qubit_interaction_into_four_fsim_gates,
363363
drop_empty_moments,
364364
drop_negligible_operations,
365+
eject_z,
365366
expand_composite,
366367
is_negligible_turn,
367368
map_moments,

cirq-core/cirq/ion/ion_decomposition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _cleanup_operations(operations: List[ops.Operation]):
5454
circuit = circuits.Circuit(operations)
5555
optimizers.merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit)
5656
optimizers.eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit)
57-
optimizers.eject_z.EjectZ().optimize_circuit(circuit)
57+
circuit = transformers.eject_z(circuit)
5858
circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST)
5959
return list(circuit.all_operations())
6060

cirq-core/cirq/optimizers/eject_z.py

+9-116
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,13 @@
1414

1515
"""An optimization pass that pushes Z gates later and later in the circuit."""
1616

17-
from typing import cast, Dict, Iterable, List, Optional, Tuple
18-
from collections import defaultdict
19-
import numpy as np
20-
import sympy
17+
from cirq import circuits, transformers
2118

22-
from cirq import circuits, ops, protocols
23-
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
24-
25-
26-
def _is_integer(n):
27-
return np.isclose(n, np.round(n))
28-
29-
30-
def _is_swaplike(op: ops.Operation):
31-
if isinstance(op.gate, ops.SwapPowGate):
32-
return op.gate.exponent == 1
33-
34-
if isinstance(op.gate, ops.ISwapPowGate):
35-
return _is_integer((op.gate.exponent - 1) / 2)
36-
37-
if isinstance(op.gate, ops.FSimGate):
38-
return _is_integer(op.gate.theta / np.pi - 1 / 2)
39-
40-
return False
19+
# from cirq.transformers import eject_z
20+
from cirq._compat import deprecated_class
4121

4222

23+
@deprecated_class(deadline='v1.0', fix='Use cirq.eject_z instead.')
4324
class EjectZ:
4425
"""Pushes Z gates towards the end of the circuit.
4526
@@ -62,96 +43,8 @@ def __init__(self, tolerance: float = 0.0, eject_parameterized: bool = False) ->
6243
self.eject_parameterized = eject_parameterized
6344

6445
def optimize_circuit(self, circuit: circuits.Circuit):
65-
# Tracks qubit phases (in half turns; multiply by pi to get radians).
66-
qubit_phase: Dict[ops.Qid, float] = defaultdict(lambda: 0)
67-
deletions: List[Tuple[int, ops.Operation]] = []
68-
replacements: List[Tuple[int, ops.Operation, ops.Operation]] = []
69-
insertions: List[Tuple[int, ops.Operation]] = []
70-
phased_xz_replacements: Dict[Tuple[int, ops.Qid], int] = {}
71-
72-
def dump_tracked_phase(qubits: Iterable[ops.Qid], index: int) -> None:
73-
"""Zeroes qubit_phase entries by emitting Z gates."""
74-
for q in qubits:
75-
p = qubit_phase[q]
76-
qubit_phase[q] = 0
77-
if single_qubit_decompositions.is_negligible_turn(p, self.tolerance):
78-
continue
79-
dumped = False
80-
moment_index = circuit.prev_moment_operating_on([q], index)
81-
if moment_index is not None:
82-
op = circuit.moments[moment_index][q]
83-
if op and isinstance(op.gate, ops.PhasedXZGate):
84-
# Attach z-rotation to replacing PhasedXZ gate.
85-
idx = phased_xz_replacements[moment_index, q]
86-
_, _, repl_op = replacements[idx]
87-
gate = cast(ops.PhasedXZGate, repl_op.gate)
88-
repl_op = gate.with_z_exponent(p * 2).on(q)
89-
replacements[idx] = (moment_index, op, repl_op)
90-
dumped = True
91-
if not dumped:
92-
# Add a new Z gate
93-
dump_op = ops.Z(q) ** (p * 2)
94-
insertions.append((index, dump_op))
95-
96-
for moment_index, moment in enumerate(circuit):
97-
for op in moment.operations:
98-
# Move Z gates into tracked qubit phases.
99-
h = _try_get_known_z_half_turns(op, self.eject_parameterized)
100-
if h is not None:
101-
q = op.qubits[0]
102-
qubit_phase[q] += h / 2
103-
deletions.append((moment_index, op))
104-
continue
105-
106-
# Z gate before measurement is a no-op. Drop tracked phase.
107-
if isinstance(op.gate, ops.MeasurementGate):
108-
for q in op.qubits:
109-
qubit_phase[q] = 0
110-
111-
# If there's no tracked phase, we can move on.
112-
phases = [qubit_phase[q] for q in op.qubits]
113-
if not isinstance(op.gate, ops.PhasedXZGate) and all(
114-
single_qubit_decompositions.is_negligible_turn(p, self.tolerance)
115-
for p in phases
116-
):
117-
continue
118-
119-
if _is_swaplike(op):
120-
a, b = op.qubits
121-
qubit_phase[a], qubit_phase[b] = qubit_phase[b], qubit_phase[a]
122-
continue
123-
124-
# Try to move the tracked phasing over the operation.
125-
phased_op = op
126-
for i, p in enumerate(phases):
127-
if not single_qubit_decompositions.is_negligible_turn(p, self.tolerance):
128-
phased_op = protocols.phase_by(phased_op, -p, i, default=None)
129-
if phased_op is not None:
130-
gate = phased_op.gate
131-
if isinstance(gate, ops.PhasedXZGate) and (
132-
self.eject_parameterized or not protocols.is_parameterized(gate.z_exponent)
133-
):
134-
qubit = phased_op.qubits[0]
135-
qubit_phase[qubit] += gate.z_exponent / 2
136-
phased_op = gate.with_z_exponent(0).on(qubit)
137-
repl_idx = len(replacements)
138-
phased_xz_replacements[moment_index, qubit] = repl_idx
139-
replacements.append((moment_index, op, phased_op))
140-
else:
141-
dump_tracked_phase(op.qubits, moment_index)
142-
143-
dump_tracked_phase(qubit_phase.keys(), len(circuit))
144-
circuit.batch_remove(deletions)
145-
circuit.batch_replace(replacements)
146-
circuit.batch_insert(insertions)
147-
148-
149-
def _try_get_known_z_half_turns(op: ops.Operation, eject_parameterized: bool) -> Optional[float]:
150-
if not isinstance(op, ops.GateOperation):
151-
return None
152-
if not isinstance(op.gate, ops.ZPowGate):
153-
return None
154-
h = op.gate.exponent
155-
if not eject_parameterized and isinstance(h, sympy.Basic):
156-
return None
157-
return h
46+
circuit._moments = [
47+
*transformers.eject_z(
48+
circuit, atol=self.tolerance, eject_parameterized=self.eject_parameterized
49+
)
50+
]

cirq-core/cirq/optimizers/eject_z_test.py

+21-31
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,36 @@
1616
import sympy
1717

1818
import cirq
19-
from cirq.optimizers.eject_z import _try_get_known_z_half_turns
2019

2120

2221
def assert_optimizes(
2322
before: cirq.Circuit, expected: cirq.Circuit, eject_parameterized: bool = False
2423
):
25-
opt = cirq.EjectZ(eject_parameterized=eject_parameterized)
24+
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
25+
opt = cirq.EjectZ(eject_parameterized=eject_parameterized)
2626

27-
if cirq.has_unitary(before):
28-
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
29-
before, expected, atol=1e-8
30-
)
27+
if cirq.has_unitary(before):
28+
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
29+
before, expected, atol=1e-8
30+
)
3131

32-
circuit = before.copy()
33-
opt.optimize_circuit(circuit)
34-
opt.optimize_circuit(expected)
32+
circuit = before.copy()
33+
opt.optimize_circuit(circuit)
34+
opt.optimize_circuit(expected)
3535

36-
cirq.testing.assert_same_circuits(circuit, expected)
36+
cirq.testing.assert_same_circuits(circuit, expected)
3737

38-
# And it should be idempotent.
39-
opt.optimize_circuit(circuit)
40-
cirq.testing.assert_same_circuits(circuit, expected)
38+
# And it should be idempotent.
39+
opt.optimize_circuit(circuit)
40+
cirq.testing.assert_same_circuits(circuit, expected)
4141

4242

4343
def assert_removes_all_z_gates(circuit: cirq.Circuit, eject_parameterized: bool = True):
44-
opt = cirq.EjectZ(eject_parameterized=eject_parameterized)
44+
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
45+
opt = cirq.EjectZ(eject_parameterized=eject_parameterized)
4546
optimized = circuit.copy()
4647
opt.optimize_circuit(optimized)
4748
for op in optimized.all_operations():
48-
assert _try_get_known_z_half_turns(op, eject_parameterized) is None
4949
if isinstance(op.gate, cirq.PhasedXZGate) and (
5050
eject_parameterized or not cirq.is_parameterized(op.gate.z_exponent)
5151
):
@@ -373,7 +373,8 @@ def test_swap():
373373
original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.SWAP(a, b)])
374374
optimized = original.copy()
375375

376-
cirq.EjectZ().optimize_circuit(optimized)
376+
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
377+
cirq.EjectZ().optimize_circuit(optimized)
377378
optimized = cirq.drop_empty_moments(optimized)
378379

379380
assert optimized[0].operations == (cirq.SWAP(a, b),)
@@ -384,19 +385,14 @@ def test_swap():
384385
)
385386

386387

387-
@pytest.mark.parametrize('exponent', (0, 2, 1.1, -2, -1.6))
388-
def test_not_a_swap(exponent):
389-
a, b = cirq.LineQubit.range(2)
390-
assert not cirq.optimizers.eject_z._is_swaplike(cirq.SWAP(a, b) ** exponent)
391-
392-
393388
@pytest.mark.parametrize('theta', (np.pi / 2, -np.pi / 2, np.pi / 2 + 5 * np.pi))
394389
def test_swap_fsim(theta):
395390
a, b = cirq.LineQubit.range(2)
396391
original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.FSimGate(theta=theta, phi=0.123).on(a, b)])
397392
optimized = original.copy()
398393

399-
cirq.EjectZ().optimize_circuit(optimized)
394+
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
395+
cirq.EjectZ().optimize_circuit(optimized)
400396
optimized = cirq.drop_empty_moments(optimized)
401397

402398
assert optimized[0].operations == (cirq.FSimGate(theta=theta, phi=0.123).on(a, b),)
@@ -407,19 +403,13 @@ def test_swap_fsim(theta):
407403
)
408404

409405

410-
@pytest.mark.parametrize('theta', (0, 5 * np.pi, -np.pi))
411-
def test_not_a_swap_fsim(theta):
412-
a, b = cirq.LineQubit.range(2)
413-
assert not cirq.optimizers.eject_z._is_swaplike(cirq.FSimGate(theta=theta, phi=0.456).on(a, b))
414-
415-
416406
@pytest.mark.parametrize('exponent', (1, -1))
417407
def test_swap_iswap(exponent):
418408
a, b = cirq.LineQubit.range(2)
419409
original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.ISWAP(a, b) ** exponent])
420410
optimized = original.copy()
421-
422-
cirq.EjectZ().optimize_circuit(optimized)
411+
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
412+
cirq.EjectZ().optimize_circuit(optimized)
423413
optimized = cirq.drop_empty_moments(optimized)
424414

425415
assert optimized[0].operations == (cirq.ISWAP(a, b) ** exponent,)

cirq-core/cirq/optimizers/merge_interactions_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
2929
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
3030
cirq.merge_single_qubit_gates_into_phased_x_z,
3131
cirq.EjectPhasedPaulis().optimize_circuit,
32-
cirq.EjectZ().optimize_circuit,
3332
]
3433
for post in followup_optimizations:
3534
post(actual)
3635
post(expected)
3736

3837
followup_transformers: List[cirq.TRANSFORMER] = [
38+
cirq.eject_z,
3939
cirq.drop_negligible_operations,
4040
cirq.drop_empty_moments,
4141
]

cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
4040
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
4141
cirq.merge_single_qubit_gates_into_phased_x_z,
4242
cirq.EjectPhasedPaulis().optimize_circuit,
43-
cirq.EjectZ().optimize_circuit,
4443
]
4544
for post in followup_optimizations:
4645
post(actual)
4746
post(expected)
4847

4948
followup_transformers: List[cirq.TRANSFORMER] = [
49+
cirq.eject_z,
5050
cirq.drop_negligible_operations,
5151
cirq.drop_empty_moments,
5252
]

cirq-core/cirq/transformers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949

5050
from cirq.transformers.drop_negligible_operations import drop_negligible_operations
5151

52+
from cirq.transformers.eject_z import eject_z
53+
5254
from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements
5355

5456
from cirq.transformers.transformer_api import (

cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
from cirq import ops, linalg, protocols, circuits
2525
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
26+
from cirq.transformers.eject_z import eject_z
2627
from cirq.optimizers import (
27-
eject_z,
2828
eject_phased_paulis,
2929
merge_single_qubit_gates,
3030
)
@@ -165,7 +165,7 @@ def _cleanup_operations(operations: Sequence[ops.Operation]):
165165
circuit = circuits.Circuit(operations)
166166
merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit)
167167
eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit)
168-
eject_z.EjectZ().optimize_circuit(circuit)
168+
circuit = eject_z(circuit)
169169
circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST)
170170
return list(circuit.all_operations())
171171

0 commit comments

Comments
 (0)