Skip to content

Commit c0a6a9f

Browse files
committed
Update for changes in earlier PR, add tests
1 parent b455047 commit c0a6a9f

File tree

2 files changed

+68
-7
lines changed

2 files changed

+68
-7
lines changed

cirq-core/cirq/optimizers/merge_interactions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,12 @@ def __init__(
248248
self,
249249
tolerance: float = 1e-8,
250250
require_three_sqrt_iswap: bool = False,
251+
use_sqrt_iswap_inv: bool = False,
251252
post_clean_up: Callable[[Sequence[ops.Operation]], ops.OP_TREE] = lambda op_list: op_list,
252253
) -> None:
253254
super().__init__(tolerance=tolerance, post_clean_up=post_clean_up)
254255
self.require_three_sqrt_iswap = require_three_sqrt_iswap
256+
self.use_sqrt_iswap_inv = use_sqrt_iswap_inv
255257

256258
def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:
257259
"""Returns True if the old two-qubit operation may be left unchanged
@@ -280,7 +282,7 @@ def _two_qubit_matrix_to_operations(
280282
q1,
281283
mat,
282284
required_sqrt_iswap_count=3 if self.require_three_sqrt_iswap else None,
285+
use_sqrt_iswap_inv=self.use_sqrt_iswap_inv,
283286
atol=self.tolerance,
284287
check_preconditions=False,
285-
clean_operations=False,
286288
)

cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import cirq
44

55

6-
def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
6+
def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
77
actual = cirq.Circuit(before)
8-
opt = cirq.MergeInteractionsToSqrtIswap()
8+
opt = cirq.MergeInteractionsToSqrtIswap(**kwargs)
99
opt.optimize_circuit(actual)
1010

1111
# Ignore differences that would be caught by follow-up optimizations.
@@ -23,15 +23,23 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
2323
assert actual == expected, f'ACTUAL {actual} : EXPECTED {expected}'
2424

2525

26-
def assert_optimization_not_broken(circuit):
26+
def assert_optimization_not_broken(circuit: cirq.Circuit):
2727
"""Check that the unitary matrix for the input circuit is the same (up to
2828
global phase and rounding error) as the unitary matrix of the optimized
2929
circuit."""
3030
u_before = circuit.unitary()
31-
cirq.MergeInteractions().optimize_circuit(circuit)
32-
u_after = circuit.unitary()
31+
c_sqrt_iswap = circuit.copy()
32+
cirq.MergeInteractionsToSqrtIswap().optimize_circuit(c_sqrt_iswap)
33+
u_after = c_sqrt_iswap.unitary()
3334

34-
cirq.testing.assert_allclose_up_to_global_phase(u_before, u_after, atol=1e-8)
35+
cirq.testing.assert_allclose_up_to_global_phase(u_before, u_after, atol=2e-8)
36+
37+
# Also test optimization with SQRT_ISWAP_INV
38+
c_sqrt_iswap_inv = circuit.copy()
39+
cirq.MergeInteractionsToSqrtIswap(use_sqrt_iswap_inv=True).optimize_circuit(c_sqrt_iswap_inv)
40+
u_after2 = c_sqrt_iswap_inv.unitary()
41+
42+
cirq.testing.assert_allclose_up_to_global_phase(u_before, u_after2, atol=2e-8)
3543

3644

3745
def test_clears_paired_cnot():
@@ -47,6 +55,57 @@ def test_clears_paired_cnot():
4755
)
4856

4957

58+
def test_simplifies_sqrt_iswap():
59+
a, b = cirq.LineQubit.range(2)
60+
assert_optimizes(
61+
before=cirq.Circuit(
62+
[
63+
# SQRT_ISWAP**8 == Identity
64+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
65+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
66+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
67+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
68+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
69+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
70+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
71+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
72+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
73+
]
74+
),
75+
expected=cirq.Circuit(
76+
[
77+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
78+
]
79+
),
80+
)
81+
82+
83+
def test_simplifies_sqrt_iswap_inv():
84+
a, b = cirq.LineQubit.range(2)
85+
assert_optimizes(
86+
use_sqrt_iswap_inv=True,
87+
before=cirq.Circuit(
88+
[
89+
# SQRT_ISWAP**8 == Identity
90+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
91+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
92+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
93+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
94+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
95+
cirq.Moment([cirq.SQRT_ISWAP_INV(a, b)]),
96+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
97+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
98+
cirq.Moment([cirq.SQRT_ISWAP(a, b)]),
99+
]
100+
),
101+
expected=cirq.Circuit(
102+
[
103+
cirq.Moment([cirq.SQRT_ISWAP_INV(a, b)]),
104+
]
105+
),
106+
)
107+
108+
50109
def test_cnots_separated_by_single_gates_correct():
51110
a, b = cirq.LineQubit.range(2)
52111
assert_optimization_not_broken(

0 commit comments

Comments
 (0)