Skip to content

Commit b2bf5b3

Browse files
committed
Address comments, add docstrings, add tests
1 parent a053859 commit b2bf5b3

File tree

3 files changed

+194
-19
lines changed

3 files changed

+194
-19
lines changed

cirq-core/cirq/optimizers/merge_interactions.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ def __init__(
3535
tolerance: float = 1e-8,
3636
post_clean_up: Callable[[Sequence[ops.Operation]], ops.OP_TREE] = lambda op_list: op_list,
3737
) -> None:
38+
"""
39+
Args:
40+
tolerance: A limit on the amount of absolute error introduced by the
41+
construction.
42+
post_clean_up: This function is called on each set of optimized
43+
operations before they are put into the circuit to replace the
44+
old operations.
45+
"""
3846
super().__init__(post_clean_up=post_clean_up)
3947
self.tolerance = tolerance
4048

@@ -62,7 +70,7 @@ def optimization_at(
6270
if not switch_to_new and old_interaction_count <= 1:
6371
return None
6472

65-
# Find a max-3-cz construction.
73+
# Find a (possibly ideal) decomposition of the merged operations.
6674
new_operations = self._two_qubit_matrix_to_operations(op.qubits[0], op.qubits[1], matrix)
6775
new_interaction_count = len(
6876
[new_op for new_op in new_operations if len(new_op.qubits) == 2]
@@ -207,6 +215,15 @@ def __init__(
207215
allow_partial_czs: bool = True,
208216
post_clean_up: Callable[[Sequence[ops.Operation]], ops.OP_TREE] = lambda op_list: op_list,
209217
) -> None:
218+
"""
219+
Args:
220+
tolerance: A limit on the amount of absolute error introduced by the
221+
construction.
222+
allow_partial_czs: Enables the use of Partial-CZ gates.
223+
post_clean_up: This function is called on each set of optimized
224+
operations before they are put into the circuit to replace the
225+
old operations.
226+
"""
210227
super().__init__(tolerance=tolerance, post_clean_up=post_clean_up)
211228
self.allow_partial_czs = allow_partial_czs
212229

cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap.py

+35-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2018 The Cirq Developers
1+
# Copyright 2021 The Cirq Developers
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
1515
"""An optimization pass that combines adjacent series of gates on two qubits and
1616
outputs a circuit with SQRT_ISWAP or SQRT_ISWAP_INV gates."""
1717

18-
from typing import Callable, Sequence, TYPE_CHECKING
18+
from typing import Callable, Optional, Sequence, TYPE_CHECKING
1919

2020
import numpy as np
2121

@@ -29,22 +29,51 @@
2929
class MergeInteractionsToSqrtIswap(merge_interactions.MergeInteractionsAbc):
3030
"""Combines series of adjacent one and two-qubit gates operating on a pair
3131
of qubits and replaces each series with the minimum number of SQRT_ISWAP
32-
gates."""
32+
gates.
33+
34+
See also: ``two_qubit_matrix_to_sqrt_iswap_operations``
35+
"""
3336

3437
def __init__(
3538
self,
3639
tolerance: float = 1e-8,
37-
require_three_sqrt_iswap: bool = False,
40+
*,
41+
required_sqrt_iswap_count: Optional[int] = None,
3842
use_sqrt_iswap_inv: bool = False,
3943
post_clean_up: Callable[[Sequence[ops.Operation]], ops.OP_TREE] = lambda op_list: op_list,
4044
) -> None:
45+
"""
46+
Args:
47+
tolerance: A limit on the amount of absolute error introduced by the
48+
construction.
49+
required_sqrt_iswap_count: When specified, each merged group of
50+
two-qubit gates will be decomposed into exactly this many
51+
sqrt-iSWAP gates even if fewer is possible (maximum 3). Circuit
52+
optimization will raise a ``ValueError`` if this number is 2 or
53+
lower and synthesis of any set of merged interactions requires
54+
more.
55+
use_sqrt_iswap_inv: If True, optimizes circuits using
56+
``SQRT_ISWAP_INV`` gates instead of ``SQRT_ISWAP``.
57+
post_clean_up: This function is called on each set of optimized
58+
operations before they are put into the circuit to replace the
59+
old operations.
60+
61+
Raises:
62+
ValueError:
63+
If ``required_sqrt_iswap_count`` is not one of the supported
64+
values 0, 1, 2, or 3.
65+
"""
66+
if required_sqrt_iswap_count is not None and not 0 <= required_sqrt_iswap_count <= 3:
67+
raise ValueError('the argument `required_sqrt_iswap_count` must be 0, 1, 2, or 3.')
4168
super().__init__(tolerance=tolerance, post_clean_up=post_clean_up)
42-
self.require_three_sqrt_iswap = require_three_sqrt_iswap
69+
self.required_sqrt_iswap_count = required_sqrt_iswap_count
4370
self.use_sqrt_iswap_inv = use_sqrt_iswap_inv
4471

4572
def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:
4673
"""Returns True if the old two-qubit operation may be left unchanged
4774
without decomposition."""
75+
if self.use_sqrt_iswap_inv:
76+
return isinstance(old_op.gate, ops.ISwapPowGate) and old_op.gate.exponent == -0.5
4877
return isinstance(old_op.gate, ops.ISwapPowGate) and old_op.gate.exponent == 0.5
4978

5079
def _two_qubit_matrix_to_operations(
@@ -68,7 +97,7 @@ def _two_qubit_matrix_to_operations(
6897
q0,
6998
q1,
7099
mat,
71-
required_sqrt_iswap_count=3 if self.require_three_sqrt_iswap else None,
100+
required_sqrt_iswap_count=self.required_sqrt_iswap_count,
72101
use_sqrt_iswap_inv=self.use_sqrt_iswap_inv,
73102
atol=self.tolerance,
74103
check_preconditions=False,

cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py

+141-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,38 @@
1+
# Copyright 2021 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
from typing import Callable, List
216

17+
import pytest
18+
319
import cirq
420

521

622
def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
7-
actual = cirq.Circuit(before)
23+
"""Check that optimizing the circuit ``before`` produces the circuit ``expected``.
24+
25+
The optimized circuit is cleaned up with follow up optimizations to make the
26+
comparison more robust to extra moments or extra gates nearly equal to
27+
identity that don't matter.
28+
29+
Args:
30+
before: The input circuit to optimize.
31+
expected: The expected result of optimization to compare against.
32+
kwargs: Any extra arguments to pass to the
33+
``MergeInteractionsToSqrtIswap`` constructor.
34+
"""
35+
actual = before.copy()
836
opt = cirq.MergeInteractionsToSqrtIswap(**kwargs)
937
opt.optimize_circuit(actual)
1038

@@ -23,21 +51,21 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
2351
assert actual == expected, f'ACTUAL {actual} : EXPECTED {expected}'
2452

2553

26-
def assert_optimization_not_broken(circuit: cirq.Circuit):
54+
def assert_optimization_not_broken(circuit: cirq.Circuit, **kwargs):
2755
"""Check that the unitary matrix for the input circuit is the same (up to
2856
global phase and rounding error) as the unitary matrix of the optimized
2957
circuit."""
3058
u_before = circuit.unitary()
3159
c_sqrt_iswap = circuit.copy()
32-
cirq.MergeInteractionsToSqrtIswap().optimize_circuit(c_sqrt_iswap)
33-
u_after = c_sqrt_iswap.unitary()
60+
cirq.MergeInteractionsToSqrtIswap(**kwargs).optimize_circuit(c_sqrt_iswap)
61+
u_after = c_sqrt_iswap.unitary(circuit.all_qubits())
3462

3563
cirq.testing.assert_allclose_up_to_global_phase(u_before, u_after, atol=2e-8)
3664

3765
# Also test optimization with SQRT_ISWAP_INV
3866
c_sqrt_iswap_inv = circuit.copy()
3967
cirq.MergeInteractionsToSqrtIswap(use_sqrt_iswap_inv=True).optimize_circuit(c_sqrt_iswap_inv)
40-
u_after2 = c_sqrt_iswap_inv.unitary()
68+
u_after2 = c_sqrt_iswap_inv.unitary(circuit.all_qubits())
4169

4270
cirq.testing.assert_allclose_up_to_global_phase(u_before, u_after2, atol=2e-8)
4371

@@ -106,6 +134,57 @@ def test_simplifies_sqrt_iswap_inv():
106134
)
107135

108136

137+
def test_works_with_tags():
138+
a, b = cirq.LineQubit.range(2)
139+
assert_optimizes(
140+
before=cirq.Circuit(
141+
[
142+
cirq.Moment([cirq.SQRT_ISWAP(a, b).with_tags('mytag1')]),
143+
cirq.Moment([cirq.SQRT_ISWAP(a, b).with_tags('mytag2')]),
144+
cirq.Moment([cirq.SQRT_ISWAP_INV(a, b).with_tags('mytag3')]),
145+
]
146+
),
147+
expected=cirq.Circuit(
148+
[
149+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
150+
]
151+
),
152+
)
153+
154+
155+
def test_no_touch_single_sqrt_iswap():
156+
a, b = cirq.LineQubit.range(2)
157+
assert_optimizes(
158+
before=cirq.Circuit(
159+
[
160+
cirq.Moment([cirq.SQRT_ISWAP(a, b).with_tags('mytag')]),
161+
]
162+
),
163+
expected=cirq.Circuit(
164+
[
165+
cirq.Moment([cirq.SQRT_ISWAP(a, b).with_tags('mytag')]),
166+
]
167+
),
168+
)
169+
170+
171+
def test_no_touch_single_sqrt_iswap_inv():
172+
a, b = cirq.LineQubit.range(2)
173+
assert_optimizes(
174+
use_sqrt_iswap_inv=True,
175+
before=cirq.Circuit(
176+
[
177+
cirq.Moment([cirq.SQRT_ISWAP_INV(a, b).with_tags('mytag')]),
178+
]
179+
),
180+
expected=cirq.Circuit(
181+
[
182+
cirq.Moment([cirq.SQRT_ISWAP_INV(a, b).with_tags('mytag')]),
183+
]
184+
),
185+
)
186+
187+
109188
def test_cnots_separated_by_single_gates_correct():
110189
a, b = cirq.LineQubit.range(2)
111190
assert_optimization_not_broken(
@@ -165,23 +244,73 @@ def test_optimizes_single_iswap():
165244

166245
def test_optimizes_single_inv_sqrt_iswap():
167246
a, b = cirq.LineQubit.range(2)
168-
c = cirq.Circuit(cirq.SQRT_ISWAP(a, b) ** -1)
247+
c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b))
169248
assert_optimization_not_broken(c)
170249
cirq.MergeInteractionsToSqrtIswap().optimize_circuit(c)
171250
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 1
172251

173252

253+
def test_init_raises():
254+
with pytest.raises(ValueError, match='must be 0, 1, 2, or 3'):
255+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=4)
256+
257+
258+
def test_optimizes_single_iswap_require0():
259+
a, b = cirq.LineQubit.range(2)
260+
c = cirq.Circuit(cirq.CNOT(a, b), cirq.CNOT(a, b)) # Minimum 0 sqrt-iSWAP
261+
assert_optimization_not_broken(c, required_sqrt_iswap_count=0)
262+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=0).optimize_circuit(c)
263+
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 0
264+
265+
266+
def test_optimizes_single_iswap_require0_raises():
267+
a, b = cirq.LineQubit.range(2)
268+
c = cirq.Circuit(cirq.CNOT(a, b)) # Minimum 2 sqrt-iSWAP
269+
with pytest.raises(ValueError, match='cannot be decomposed into exactly 0 sqrt-iSWAP gates'):
270+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=0).optimize_circuit(c)
271+
272+
273+
def test_optimizes_single_iswap_require1():
274+
a, b = cirq.LineQubit.range(2)
275+
c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) # Minimum 1 sqrt-iSWAP
276+
assert_optimization_not_broken(c, required_sqrt_iswap_count=1)
277+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=1).optimize_circuit(c)
278+
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 1
279+
280+
281+
def test_optimizes_single_iswap_require1_raises():
282+
a, b = cirq.LineQubit.range(2)
283+
c = cirq.Circuit(cirq.CNOT(a, b)) # Minimum 2 sqrt-iSWAP
284+
with pytest.raises(ValueError, match='cannot be decomposed into exactly 1 sqrt-iSWAP gates'):
285+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=1).optimize_circuit(c)
286+
287+
288+
def test_optimizes_single_iswap_require2():
289+
a, b = cirq.LineQubit.range(2)
290+
c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) # Minimum 1 sqrt-iSWAP but 2 possible
291+
assert_optimization_not_broken(c, required_sqrt_iswap_count=2)
292+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=2).optimize_circuit(c)
293+
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2
294+
295+
296+
def test_optimizes_single_iswap_require2_raises():
297+
a, b = cirq.LineQubit.range(2)
298+
c = cirq.Circuit(cirq.SWAP(a, b)) # Minimum 3 sqrt-iSWAP
299+
with pytest.raises(ValueError, match='cannot be decomposed into exactly 2 sqrt-iSWAP gates'):
300+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=2).optimize_circuit(c)
301+
302+
174303
def test_optimizes_single_iswap_require3():
175304
a, b = cirq.LineQubit.range(2)
176-
c = cirq.Circuit(cirq.ISWAP(a, b))
177-
assert_optimization_not_broken(c)
178-
cirq.MergeInteractionsToSqrtIswap(require_three_sqrt_iswap=True).optimize_circuit(c)
305+
c = cirq.Circuit(cirq.ISWAP(a, b)) # Minimum 2 sqrt-iSWAP but 3 possible
306+
assert_optimization_not_broken(c, required_sqrt_iswap_count=3)
307+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=3).optimize_circuit(c)
179308
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 3
180309

181310

182311
def test_optimizes_single_inv_sqrt_iswap_require3():
183312
a, b = cirq.LineQubit.range(2)
184-
c = cirq.Circuit(cirq.SQRT_ISWAP(a, b) ** -1)
185-
assert_optimization_not_broken(c)
186-
cirq.MergeInteractionsToSqrtIswap(require_three_sqrt_iswap=True).optimize_circuit(c)
313+
c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b))
314+
assert_optimization_not_broken(c, required_sqrt_iswap_count=3)
315+
cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=3).optimize_circuit(c)
187316
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 3

0 commit comments

Comments
 (0)