Skip to content

Commit 848bfde

Browse files
authored
Add default decomposition for cirq.QubitPermutationGate in terms of adjacent swaps (#5093)
- Adds decomposition to `cirq.QubitPermutationGate` in terms of minimum number of adjacent swap operations on qubits. - Part of #4858 Closes #5090
1 parent e0f7432 commit 848bfde

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

cirq-core/cirq/ops/permutation_gate.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING
15+
from typing import Any, Dict, Iterable, Sequence, Tuple, TYPE_CHECKING
1616

1717
from cirq import protocols, value
1818
from cirq._compat import deprecated
19-
from cirq.ops import raw_types
19+
from cirq.ops import raw_types, swap_gates
2020

2121
if TYPE_CHECKING:
2222
import cirq
@@ -74,6 +74,25 @@ def num_qubits(self):
7474
def _has_unitary_(self):
7575
return True
7676

77+
def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
78+
n = len(qubits)
79+
qubit_ids = [*range(n)]
80+
is_sorted = False
81+
82+
def _swap_if_out_of_order(idx: int) -> Iterable['cirq.Operation']:
83+
nonlocal is_sorted
84+
if self._permutation[qubit_ids[idx]] > self._permutation[qubit_ids[idx + 1]]:
85+
yield swap_gates.SWAP(qubits[idx], qubits[idx + 1])
86+
qubit_ids[idx + 1], qubit_ids[idx] = qubit_ids[idx], qubit_ids[idx + 1]
87+
is_sorted = False
88+
89+
while not is_sorted:
90+
is_sorted = True
91+
for i in range(0, n - 1, 2):
92+
yield from _swap_if_out_of_order(i)
93+
for i in range(1, n - 1, 2):
94+
yield from _swap_if_out_of_order(i)
95+
7796
def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
7897
# Compute the permutation index list.
7998
permuted_axes = list(range(len(args.target_tensor.shape)))

cirq-core/cirq/ops/permutation_gate_test.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616

1717
import cirq
18+
import numpy as np
1819
from cirq.ops import QubitPermutationGate
1920

2021

@@ -30,8 +31,12 @@ def test_permutation_gate_repr():
3031
cirq.testing.assert_equivalent_repr(QubitPermutationGate([0, 1]))
3132

3233

33-
def test_permutation_gate_consistent_protocols():
34-
gate = QubitPermutationGate([1, 0, 2, 3])
34+
rs = np.random.RandomState(seed=1234)
35+
36+
37+
@pytest.mark.parametrize('permutation', [rs.permutation(i) for i in range(3, 7)])
38+
def test_permutation_gate_consistent_protocols(permutation):
39+
gate = QubitPermutationGate(list(permutation))
3540
cirq.testing.assert_implements_consistent_protocols(gate)
3641

3742

@@ -98,6 +103,8 @@ def test_permutation_gate_maps(maps, permutation):
98103
permutationOp = cirq.QubitPermutationGate(permutation).on(*qs)
99104
circuit = cirq.Circuit(permutationOp)
100105
cirq.testing.assert_equivalent_computational_basis_map(maps, circuit)
106+
circuit = cirq.Circuit(cirq.I.on_each(*qs), cirq.decompose(permutationOp))
107+
cirq.testing.assert_equivalent_computational_basis_map(maps, circuit)
101108

102109

103110
def test_setters_deprecated():

0 commit comments

Comments
 (0)