Skip to content

Commit cad82b3

Browse files
authored
Add cirq.merge_k_qubit_unitaries transformer to replace cirq.MergeSingleQubitGates optimizer (quantumlib#4986)
* Replace cirq.MergeSingleQubitGates optimizer with cirq.merge_single_qubit_gates transformer * Add merge_k_qubit_unitaries primitive
1 parent 77843b1 commit cad82b3

16 files changed

+591
-57
lines changed

cirq/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,14 @@
372372
map_moments,
373373
map_operations,
374374
map_operations_and_unroll,
375+
merge_k_qubit_unitaries,
375376
merge_k_qubit_unitaries_to_circuit_op,
376377
merge_moments,
377378
merge_operations,
378379
merge_operations_to_circuit_op,
380+
merge_single_qubit_gates_to_phased_x_and_z,
381+
merge_single_qubit_gates_to_phxz,
382+
merge_single_qubit_moments_to_phxz,
379383
prepare_two_qubit_state_using_cz,
380384
prepare_two_qubit_state_using_sqrt_iswap,
381385
single_qubit_matrix_to_gates,

cirq/contrib/paulistring/convert_gate_set.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def converted_gate_set(
2828
"""
2929
conv_circuit = circuits.Circuit(circuit)
3030
optimizers.ConvertToCzAndSingleGates().optimize_circuit(conv_circuit)
31-
optimizers.MergeSingleQubitGates().optimize_circuit(conv_circuit)
31+
conv_circuit = transformers.merge_k_qubit_unitaries(conv_circuit, k=1)
3232
ConvertToPauliStringPhasors(
3333
ignore_failures=True,
3434
keep_clifford=not no_clifford_gates,

cirq/devices/noise_model_test.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def test_noise_composition():
130130
a, b, c = cirq.LineQubit.range(3)
131131
noise_z = cirq.ConstantQubitNoiseModel(cirq.Z)
132132
noise_inv_s = cirq.ConstantQubitNoiseModel(cirq.S ** -1)
133-
merge = cirq.optimizers.merge_single_qubit_gates_into_phased_x_z
134133
base_moments = [cirq.Moment([cirq.X(a)]), cirq.Moment([cirq.Y(b)]), cirq.Moment([cirq.H(c)])]
135134
circuit_z = cirq.Circuit(noise_z.noisy_moments(base_moments, [a, b, c]))
136135
circuit_s = cirq.Circuit(noise_inv_s.noisy_moments(base_moments, [a, b, c]))
@@ -147,9 +146,9 @@ def test_noise_composition():
147146
)
148147

149148
# All of the gates will be the same, just out of order. Merging fixes this.
150-
merge(actual_zs)
151-
merge(actual_sz)
152-
merge(expected_circuit)
149+
actual_zs = cirq.merge_single_qubit_gates_to_phased_x_and_z(actual_zs)
150+
actual_sz = cirq.merge_single_qubit_gates_to_phased_x_and_z(actual_sz)
151+
expected_circuit = cirq.merge_single_qubit_gates_to_phased_x_and_z(expected_circuit)
153152
assert_equivalent_op_tree(actual_zs, actual_sz)
154153
assert_equivalent_op_tree(actual_zs, expected_circuit)
155154

cirq/ion/convert_to_ion_gates.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import numpy as np
1616

17-
from cirq import ops, protocols, optimizers, circuits, transformers
17+
from cirq import ops, protocols, circuits, transformers
1818
from cirq.ion import ms, two_qubit_matrix_to_ion_operations, ion_device
1919

2020

@@ -86,6 +86,4 @@ def convert_circuit(self, circuit: circuits.Circuit) -> circuits.Circuit:
8686
for moment in circuit:
8787
for op in moment.operations:
8888
new_circuit.append(self.convert_one(op))
89-
optimizers.merge_single_qubit_gates_into_phased_x_z(new_circuit)
90-
91-
return new_circuit
89+
return transformers.merge_single_qubit_gates_to_phased_x_and_z(new_circuit)

cirq/ion/ion_decomposition.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import numpy as np
2525

26-
from cirq import ops, linalg, protocols, optimizers, circuits, transformers
26+
from cirq import ops, linalg, protocols, circuits, transformers
2727
from cirq.ion import ms
2828

2929
if TYPE_CHECKING:
@@ -52,7 +52,7 @@ def two_qubit_matrix_to_ion_operations(
5252

5353
def _cleanup_operations(operations: List[ops.Operation]):
5454
circuit = circuits.Circuit(operations)
55-
optimizers.merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit)
55+
circuit = transformers.merge_single_qubit_gates_to_phased_x_and_z(circuit)
5656
circuit = transformers.eject_phased_paulis(circuit)
5757
circuit = transformers.eject_z(circuit)
5858
circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST)

cirq/optimizers/merge_interactions_test.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Callable, List
15+
from typing import List
1616

1717
import pytest
1818
import sympy
@@ -26,14 +26,8 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
2626
opt.optimize_circuit(actual)
2727

2828
# Ignore differences that would be caught by follow-up optimizations.
29-
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
30-
cirq.merge_single_qubit_gates_into_phased_x_z,
31-
]
32-
for post in followup_optimizations:
33-
post(actual)
34-
post(expected)
35-
3629
followup_transformers: List[cirq.TRANSFORMER] = [
30+
cirq.merge_single_qubit_gates_to_phased_x_and_z,
3731
cirq.eject_phased_paulis,
3832
cirq.eject_z,
3933
cirq.drop_negligible_operations,

cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Callable, List
15+
from typing import List
1616

1717
import pytest
1818

@@ -37,14 +37,8 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
3737
opt.optimize_circuit(actual)
3838

3939
# Ignore differences that would be caught by follow-up optimizations.
40-
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
41-
cirq.merge_single_qubit_gates_into_phased_x_z,
42-
]
43-
for post in followup_optimizations:
44-
post(actual)
45-
post(expected)
46-
4740
followup_transformers: List[cirq.TRANSFORMER] = [
41+
cirq.merge_single_qubit_gates_to_phased_x_and_z,
4842
cirq.eject_phased_paulis,
4943
cirq.eject_z,
5044
cirq.drop_negligible_operations,

cirq/optimizers/merge_single_qubit_gates.py

+10-14
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818

1919
import numpy as np
2020

21-
from cirq import ops, linalg, protocols, circuits
22-
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
21+
from cirq import ops, linalg, protocols, circuits, _compat, transformers
2322

2423
if TYPE_CHECKING:
2524
import cirq
2625

2726

27+
@_compat.deprecated_class(deadline='v1.0', fix='Use cirq.merge_k_qubit_unitaries instead.')
2828
class MergeSingleQubitGates(circuits.PointOptimizer):
2929
"""Optimizes runs of adjacent unitary 1-qubit operations."""
3030

@@ -101,6 +101,9 @@ def optimization_at(
101101
)
102102

103103

104+
@_compat.deprecated(
105+
deadline='v1.0', fix='Use cirq.merge_single_qubit_gates_to_phased_x_and_z instead.'
106+
)
104107
def merge_single_qubit_gates_into_phased_x_z(circuit: circuits.Circuit, atol: float = 1e-8) -> None:
105108
"""Canonicalizes runs of single-qubit rotations in a circuit.
106109
@@ -113,14 +116,12 @@ def merge_single_qubit_gates_into_phased_x_z(circuit: circuits.Circuit, atol: fl
113116
atol: Absolute tolerance to angle error. Larger values allow more
114117
negligible gates to be dropped, smaller values increase accuracy.
115118
"""
116-
117-
def synth(qubit: 'cirq.Qid', matrix: np.ndarray) -> List[ops.Operation]:
118-
out_gates = single_qubit_decompositions.single_qubit_matrix_to_phased_x_z(matrix, atol)
119-
return [gate(qubit) for gate in out_gates]
120-
121-
MergeSingleQubitGates(synthesizer=synth).optimize_circuit(circuit)
119+
circuit._moments = [
120+
*transformers.merge_single_qubit_gates_to_phased_x_and_z(circuit, atol=atol)
121+
]
122122

123123

124+
@_compat.deprecated(deadline='v1.0', fix='Use cirq.merge_single_qubit_gates_to_phxz instead.')
124125
def merge_single_qubit_gates_into_phxz(
125126
circuit: circuits.Circuit,
126127
atol: float = 1e-8,
@@ -135,9 +136,4 @@ def merge_single_qubit_gates_into_phxz(
135136
atol: Absolute tolerance to angle error. Larger values allow more
136137
negligible gates to be dropped, smaller values increase accuracy.
137138
"""
138-
139-
def synth(qubit: 'cirq.Qid', matrix: np.ndarray) -> List[ops.Operation]:
140-
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(matrix, atol)
141-
return [gate(qubit)] if gate else []
142-
143-
MergeSingleQubitGates(synthesizer=synth).optimize_circuit(circuit)
139+
circuit._moments = [*transformers.merge_single_qubit_gates_to_phxz(circuit, atol=atol)]

cirq/optimizers/merge_single_qubit_gates_test.py

+26-13
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def assert_optimizes(
2323
before: cirq.Circuit,
2424
expected: cirq.Circuit,
2525
optimizer: Optional[Callable[[cirq.Circuit], None]] = None,
26+
deprecated_msg: str = "Use cirq.merge_k_qubit_unitaries",
2627
):
27-
if optimizer is None:
28-
optimizer = cirq.MergeSingleQubitGates().optimize_circuit
29-
optimizer(before)
28+
with cirq.testing.assert_deprecated(deprecated_msg, deadline='v1.0'):
29+
if optimizer is None:
30+
optimizer = cirq.MergeSingleQubitGates().optimize_circuit
31+
optimizer(before)
3032

3133
# Ignore differences that would be caught by follow-up optimizations.
3234
followup_transformers = [cirq.drop_negligible_operations, cirq.drop_empty_moments]
@@ -38,7 +40,8 @@ def assert_optimizes(
3840

3941

4042
def test_leaves_singleton():
41-
m = cirq.MergeSingleQubitGates()
43+
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
44+
m = cirq.MergeSingleQubitGates()
4245
q = cirq.NamedQubit('q')
4346
c = cirq.Circuit([cirq.Moment([cirq.X(q)])])
4447

@@ -48,12 +51,16 @@ def test_leaves_singleton():
4851

4952

5053
def test_not_both():
51-
with pytest.raises(ValueError):
52-
_ = cirq.MergeSingleQubitGates(synthesizer=lambda *args: None, rewriter=lambda *args: None)
54+
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
55+
with pytest.raises(ValueError):
56+
_ = cirq.MergeSingleQubitGates(
57+
synthesizer=lambda *args: None, rewriter=lambda *args: None
58+
)
5359

5460

5561
def test_combines_sequence():
56-
m = cirq.MergeSingleQubitGates()
62+
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
63+
m = cirq.MergeSingleQubitGates()
5764
q = cirq.NamedQubit('q')
5865
c = cirq.Circuit(cirq.X(q) ** 0.5, cirq.Z(q) ** 0.5, cirq.X(q) ** -0.5)
5966

@@ -83,7 +90,8 @@ def test_removes_identity_sequence():
8390

8491

8592
def test_stopped_at_2qubit():
86-
m = cirq.MergeSingleQubitGates()
93+
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
94+
m = cirq.MergeSingleQubitGates()
8795
q = cirq.NamedQubit('q')
8896
q2 = cirq.NamedQubit('q2')
8997
c = cirq.Circuit(
@@ -109,7 +117,8 @@ def test_stopped_at_2qubit():
109117

110118

111119
def test_ignores_2qubit_target():
112-
m = cirq.MergeSingleQubitGates()
120+
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
121+
m = cirq.MergeSingleQubitGates()
113122
q = cirq.NamedQubit('q')
114123
q2 = cirq.NamedQubit('q2')
115124
c = cirq.Circuit(
@@ -132,7 +141,8 @@ class UnsupportedDummy(cirq.SingleQubitGate):
132141
UnsupportedDummy()(q0),
133142
)
134143
c_orig = cirq.Circuit(circuit)
135-
cirq.MergeSingleQubitGates().optimize_circuit(circuit)
144+
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
145+
cirq.MergeSingleQubitGates().optimize_circuit(circuit)
136146

137147
assert circuit == c_orig
138148

@@ -147,9 +157,10 @@ def test_rewrite():
147157
cirq.CZ(q0, q1),
148158
cirq.Y(q1),
149159
)
150-
cirq.MergeSingleQubitGates(rewriter=lambda ops: cirq.H(ops[0].qubits[0])).optimize_circuit(
151-
circuit
152-
)
160+
with cirq.testing.assert_deprecated("Use cirq.merge_k_qubit_unitaries", deadline='v1.0'):
161+
cirq.MergeSingleQubitGates(rewriter=lambda ops: cirq.H(ops[0].qubits[0])).optimize_circuit(
162+
circuit
163+
)
153164
circuit = cirq.drop_empty_moments(circuit)
154165

155166
cirq.testing.assert_same_circuits(
@@ -180,6 +191,7 @@ def test_merge_single_qubit_gates_into_phased_x_z():
180191
(cirq.PhasedXPowGate(phase_exponent=-0.5)(a)) ** 0.5,
181192
),
182193
optimizer=cirq.merge_single_qubit_gates_into_phased_x_z,
194+
deprecated_msg="Use cirq.merge_single_qubit_gates_to_phased_x_and_z",
183195
)
184196

185197

@@ -207,4 +219,5 @@ def phxz(a, x, z):
207219
phxz(-0.5, 0.5, 0).on(a),
208220
),
209221
optimizer=cirq.merge_single_qubit_gates_into_phxz,
222+
deprecated_msg="Use cirq.merge_single_qubit_gates_to_phxz",
210223
)

cirq/transformers/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@
6060
dephase_measurements,
6161
)
6262

63+
from cirq.transformers.merge_k_qubit_gates import merge_k_qubit_unitaries
64+
65+
from cirq.transformers.merge_single_qubit_gates import (
66+
merge_single_qubit_gates_to_phased_x_and_z,
67+
merge_single_qubit_gates_to_phxz,
68+
merge_single_qubit_moments_to_phxz,
69+
)
70+
6371
from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements
6472

6573
from cirq.transformers.transformer_api import (

cirq/transformers/analytical_decompositions/two_qubit_to_cz.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323

2424
from cirq import ops, linalg, protocols, circuits
2525
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
26+
from cirq.transformers.merge_single_qubit_gates import merge_single_qubit_gates_to_phased_x_and_z
2627
from cirq.transformers.eject_z import eject_z
2728
from cirq.transformers.eject_phased_paulis import eject_phased_paulis
28-
from cirq.optimizers import merge_single_qubit_gates
2929

3030
if TYPE_CHECKING:
3131
import cirq
@@ -161,7 +161,7 @@ def _xx_yy_zz_interaction_via_full_czs(
161161

162162
def _cleanup_operations(operations: Sequence[ops.Operation]):
163163
circuit = circuits.Circuit(operations)
164-
merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit)
164+
circuit = merge_single_qubit_gates_to_phased_x_and_z(circuit)
165165
circuit = eject_phased_paulis(circuit)
166166
circuit = eject_z(circuit)
167167
circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST)

0 commit comments

Comments
 (0)