|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING |
| 15 | +from typing import Any, Dict, Iterable, Sequence, Tuple, TYPE_CHECKING |
16 | 16 |
|
17 | 17 | from cirq import protocols, value
|
18 | 18 | from cirq._compat import deprecated
|
19 |
| -from cirq.ops import raw_types |
| 19 | +from cirq.ops import raw_types, swap_gates |
20 | 20 |
|
21 | 21 | if TYPE_CHECKING:
|
22 | 22 | import cirq
|
@@ -74,6 +74,25 @@ def num_qubits(self):
|
74 | 74 | def _has_unitary_(self):
|
75 | 75 | return True
|
76 | 76 |
|
| 77 | + def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': |
| 78 | + n = len(qubits) |
| 79 | + qubit_ids = [*range(n)] |
| 80 | + is_sorted = False |
| 81 | + |
| 82 | + def _swap_if_out_of_order(idx: int) -> Iterable['cirq.Operation']: |
| 83 | + nonlocal is_sorted |
| 84 | + if self._permutation[qubit_ids[idx]] > self._permutation[qubit_ids[idx + 1]]: |
| 85 | + yield swap_gates.SWAP(qubits[idx], qubits[idx + 1]) |
| 86 | + qubit_ids[idx + 1], qubit_ids[idx] = qubit_ids[idx], qubit_ids[idx + 1] |
| 87 | + is_sorted = False |
| 88 | + |
| 89 | + while not is_sorted: |
| 90 | + is_sorted = True |
| 91 | + for i in range(0, n - 1, 2): |
| 92 | + yield from _swap_if_out_of_order(i) |
| 93 | + for i in range(1, n - 1, 2): |
| 94 | + yield from _swap_if_out_of_order(i) |
| 95 | + |
77 | 96 | def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
|
78 | 97 | # Compute the permutation index list.
|
79 | 98 | permuted_axes = list(range(len(args.target_tensor.shape)))
|
|
0 commit comments