|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -from typing import Iterable, Optional, Sequence, Tuple, Union |
| 15 | +from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator |
16 | 16 |
|
| 17 | +from cirq._compat import cached_property |
17 | 18 | import attr
|
18 | 19 | import cirq
|
19 | 20 | from cirq_ft import infra
|
@@ -78,7 +79,7 @@ def _decompose_with_context_(
|
78 | 79 | return
|
79 | 80 | adjoint = []
|
80 | 81 |
|
81 |
| - [are_equal] = context.qubit_manager.qalloc(1) |
| 82 | + (are_equal,) = context.qubit_manager.qalloc(1) |
82 | 83 |
|
83 | 84 | # Initially our belief is that the numbers are equal.
|
84 | 85 | yield cirq.X(are_equal)
|
@@ -130,6 +131,147 @@ def _t_complexity_(self) -> infra.TComplexity:
|
130 | 131 | )
|
131 | 132 |
|
132 | 133 |
|
| 134 | +@attr.frozen |
| 135 | +class BiQubitsMixer(infra.GateWithRegisters): |
| 136 | + """Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf |
| 137 | +
|
| 138 | + This gates mixes the values in a way that preserves the result of comparison. |
| 139 | + The registers being compared are 2-qubit registers where |
| 140 | + x = 2*x_msb + x_lsb |
| 141 | + y = 2*y_msb + y_lsb |
| 142 | + The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb' |
| 143 | + are the final values of x_lsb' and y_lsb'. |
| 144 | + """ # pylint: disable=line-too-long |
| 145 | + |
| 146 | + adjoint: bool = False |
| 147 | + |
| 148 | + @cached_property |
| 149 | + def registers(self) -> infra.Registers: |
| 150 | + return infra.Registers.build(x=2, y=2, ancilla=3) |
| 151 | + |
| 152 | + def __repr__(self) -> str: |
| 153 | + return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})' |
| 154 | + |
| 155 | + def decompose_from_registers( |
| 156 | + self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] |
| 157 | + ) -> cirq.OP_TREE: |
| 158 | + x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla'] |
| 159 | + x_msb, x_lsb = x |
| 160 | + y_msb, y_lsb = y |
| 161 | + |
| 162 | + def _cswap(control: cirq.Qid, a: cirq.Qid, b: cirq.Qid, aux: cirq.Qid) -> cirq.OP_TREE: |
| 163 | + """A CSWAP with 4T complexity and whose adjoint has 0T complexity. |
| 164 | +
|
| 165 | + A controlled SWAP that swaps `a` and `b` based on `control`. |
| 166 | + It uses an extra qubit `aux` so that its adjoint would have |
| 167 | + a T complexity of zero. |
| 168 | + """ |
| 169 | + yield cirq.CNOT(a, b) |
| 170 | + yield and_gate.And(adjoint=self.adjoint).on(control, b, aux) |
| 171 | + yield cirq.CNOT(aux, a) |
| 172 | + yield cirq.CNOT(a, b) |
| 173 | + |
| 174 | + def _decomposition(): |
| 175 | + # computes the difference of x - y where |
| 176 | + # x = 2*x_msb + x_lsb |
| 177 | + # y = 2*y_msb + y_lsb |
| 178 | + # And stores the result in x_lsb and y_lsb such that |
| 179 | + # sign(x - y) = sign(x_lsb - y_lsb) |
| 180 | + # This decomposition uses 3 ancilla qubits in order to have a |
| 181 | + # T complexity of 8. |
| 182 | + yield cirq.X(ancilla[0]) |
| 183 | + yield cirq.CNOT(y_msb, x_msb) |
| 184 | + yield cirq.CNOT(y_lsb, x_lsb) |
| 185 | + yield from _cswap(x_msb, x_lsb, ancilla[0], ancilla[1]) |
| 186 | + yield from _cswap(x_msb, y_msb, y_lsb, ancilla[2]) |
| 187 | + yield cirq.CNOT(y_lsb, x_lsb) |
| 188 | + |
| 189 | + if self.adjoint: |
| 190 | + yield from reversed(tuple(cirq.flatten_to_ops(_decomposition()))) |
| 191 | + else: |
| 192 | + yield from _decomposition() |
| 193 | + |
| 194 | + def __pow__(self, power: int) -> cirq.Gate: |
| 195 | + if power == 1: |
| 196 | + return self |
| 197 | + if power == -1: |
| 198 | + return BiQubitsMixer(adjoint=not self.adjoint) |
| 199 | + return NotImplemented # coverage: ignore |
| 200 | + |
| 201 | + def _t_complexity_(self) -> infra.TComplexity: |
| 202 | + if self.adjoint: |
| 203 | + return infra.TComplexity(clifford=18) |
| 204 | + return infra.TComplexity(t=8, clifford=28) |
| 205 | + |
| 206 | + def _has_unitary_(self): |
| 207 | + return not self.adjoint |
| 208 | + |
| 209 | + |
| 210 | +@attr.frozen |
| 211 | +class SingleQubitCompare(infra.GateWithRegisters): |
| 212 | + """Applies U|a>|b>|0>|0> = |a> |a=b> |(a<b)> |(a>b)> |
| 213 | +
|
| 214 | + Source: (FIG. 3) in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf |
| 215 | + """ # pylint: disable=line-too-long |
| 216 | + |
| 217 | + adjoint: bool = False |
| 218 | + |
| 219 | + @cached_property |
| 220 | + def registers(self) -> infra.Registers: |
| 221 | + return infra.Registers.build(a=1, b=1, less_than=1, greater_than=1) |
| 222 | + |
| 223 | + def __repr__(self) -> str: |
| 224 | + return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})' |
| 225 | + |
| 226 | + def decompose_from_registers( |
| 227 | + self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] |
| 228 | + ) -> cirq.OP_TREE: |
| 229 | + a = quregs['a'] |
| 230 | + b = quregs['b'] |
| 231 | + less_than = quregs['less_than'] |
| 232 | + greater_than = quregs['greater_than'] |
| 233 | + |
| 234 | + def _decomposition() -> Iterator[cirq.Operation]: |
| 235 | + yield and_gate.And((0, 1), adjoint=self.adjoint).on(*a, *b, *less_than) |
| 236 | + yield cirq.CNOT(*less_than, *greater_than) |
| 237 | + yield cirq.CNOT(*b, *greater_than) |
| 238 | + yield cirq.CNOT(*a, *b) |
| 239 | + yield cirq.CNOT(*a, *greater_than) |
| 240 | + yield cirq.X(*b) |
| 241 | + |
| 242 | + if self.adjoint: |
| 243 | + yield from reversed(tuple(_decomposition())) |
| 244 | + else: |
| 245 | + yield from _decomposition() |
| 246 | + |
| 247 | + def __pow__(self, power: int) -> cirq.Gate: |
| 248 | + if not isinstance(power, int): |
| 249 | + raise ValueError('SingleQubitCompare is only defined for integer powers.') |
| 250 | + if power % 2 == 0: |
| 251 | + return cirq.IdentityGate(4) |
| 252 | + if power < 0: |
| 253 | + return SingleQubitCompare(adjoint=not self.adjoint) |
| 254 | + return self |
| 255 | + |
| 256 | + def _t_complexity_(self) -> infra.TComplexity: |
| 257 | + if self.adjoint: |
| 258 | + return infra.TComplexity(clifford=11) |
| 259 | + return infra.TComplexity(t=4, clifford=16) |
| 260 | + |
| 261 | + |
| 262 | +def _equality_with_zero( |
| 263 | + context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid |
| 264 | +) -> cirq.OP_TREE: |
| 265 | + if len(qubits) == 1: |
| 266 | + (q,) = qubits |
| 267 | + yield cirq.X(q) |
| 268 | + yield cirq.CNOT(q, z) |
| 269 | + return |
| 270 | + |
| 271 | + ancilla = context.qubit_manager.qalloc(len(qubits) - 2) |
| 272 | + yield and_gate.And(cv=[0] * len(qubits)).on(*qubits, *ancilla, z) |
| 273 | + |
| 274 | + |
133 | 275 | @attr.frozen
|
134 | 276 | class LessThanEqualGate(cirq.ArithmeticGate):
|
135 | 277 | """Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>"""
|
@@ -161,9 +303,130 @@ def __pow__(self, power: int):
|
161 | 303 | def __repr__(self) -> str:
|
162 | 304 | return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})'
|
163 | 305 |
|
| 306 | + def _decompose_via_tree( |
| 307 | + self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid] |
| 308 | + ) -> cirq.OP_TREE: |
| 309 | + """Returns comparison oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf |
| 310 | +
|
| 311 | + This decomposition follows the tree structure of (FIG. 2) |
| 312 | + """ # pylint: disable=line-too-long |
| 313 | + if len(X) == 1: |
| 314 | + return |
| 315 | + if len(X) == 2: |
| 316 | + yield BiQubitsMixer().on_registers(x=X, y=Y, ancilla=context.qubit_manager.qalloc(3)) |
| 317 | + return |
| 318 | + |
| 319 | + m = len(X) // 2 |
| 320 | + yield self._decompose_via_tree(context, X[:m], Y[:m]) |
| 321 | + yield self._decompose_via_tree(context, X[m:], Y[m:]) |
| 322 | + yield BiQubitsMixer().on_registers( |
| 323 | + x=(X[m - 1], X[-1]), y=(Y[m - 1], Y[-1]), ancilla=context.qubit_manager.qalloc(3) |
| 324 | + ) |
| 325 | + |
| 326 | + def _decompose_with_context_( |
| 327 | + self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None |
| 328 | + ) -> cirq.OP_TREE: |
| 329 | + """Decomposes the gate in a T-complexity optimal way. |
| 330 | +
|
| 331 | + The construction can be broken in 4 parts: |
| 332 | + 1. In case of differing bitsizes then a multicontrol And Gate |
| 333 | + - Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether |
| 334 | + the extra prefix is equal to zero: |
| 335 | + - result stored in: `prefix_equality` qubit. |
| 336 | + 2. The tree structure (FIG. 2) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf |
| 337 | + followed by a SingleQubitCompare to compute the result of comparison of |
| 338 | + the suffixes of equal length: |
| 339 | + - result stored in: `less_than` and `greater_than` with equality in qubits[-2] |
| 340 | + 3. The results from the previous two steps are combined to update the target qubit. |
| 341 | + 4. The adjoint of the previous operations is added to restore the input qubits |
| 342 | + to their original state and clean the ancilla qubits. |
| 343 | + """ # pylint: disable=line-too-long |
| 344 | + |
| 345 | + if context is None: |
| 346 | + context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) |
| 347 | + |
| 348 | + lhs, rhs, target = qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1] |
| 349 | + |
| 350 | + n = min(len(lhs), len(rhs)) |
| 351 | + |
| 352 | + prefix_equality = None |
| 353 | + adjoint: List[cirq.Operation] = [] |
| 354 | + |
| 355 | + # if one of the registers is longer than the other store equality with |0--0> |
| 356 | + # into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T. |
| 357 | + if len(lhs) != len(rhs): |
| 358 | + (prefix_equality,) = context.qubit_manager.qalloc(1) |
| 359 | + if len(lhs) > len(rhs): |
| 360 | + for op in cirq.flatten_to_ops( |
| 361 | + _equality_with_zero(context, lhs[:-n], prefix_equality) |
| 362 | + ): |
| 363 | + yield op |
| 364 | + adjoint.append(cirq.inverse(op)) |
| 365 | + else: |
| 366 | + for op in cirq.flatten_to_ops( |
| 367 | + _equality_with_zero(context, rhs[:-n], prefix_equality) |
| 368 | + ): |
| 369 | + yield op |
| 370 | + adjoint.append(cirq.inverse(op)) |
| 371 | + |
| 372 | + yield cirq.X(target), cirq.CNOT(prefix_equality, target) |
| 373 | + |
| 374 | + # compare the remaing suffix of P and Q |
| 375 | + lhs = lhs[-n:] |
| 376 | + rhs = rhs[-n:] |
| 377 | + for op in cirq.flatten_to_ops(self._decompose_via_tree(context, lhs, rhs)): |
| 378 | + yield op |
| 379 | + adjoint.append(cirq.inverse(op)) |
| 380 | + |
| 381 | + less_than, greater_than = context.qubit_manager.qalloc(2) |
| 382 | + yield SingleQubitCompare().on_registers( |
| 383 | + a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than |
| 384 | + ) |
| 385 | + adjoint.append( |
| 386 | + SingleQubitCompare(adjoint=True).on_registers( |
| 387 | + a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than |
| 388 | + ) |
| 389 | + ) |
| 390 | + |
| 391 | + if prefix_equality is None: |
| 392 | + yield cirq.X(target) |
| 393 | + yield cirq.CNOT(greater_than, target) |
| 394 | + else: |
| 395 | + (less_than_or_equal,) = context.qubit_manager.qalloc(1) |
| 396 | + yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal) |
| 397 | + adjoint.append( |
| 398 | + and_gate.And([1, 0], adjoint=True).on( |
| 399 | + prefix_equality, greater_than, less_than_or_equal |
| 400 | + ) |
| 401 | + ) |
| 402 | + |
| 403 | + yield cirq.CNOT(less_than_or_equal, target) |
| 404 | + |
| 405 | + yield from reversed(adjoint) |
| 406 | + |
164 | 407 | def _t_complexity_(self) -> infra.TComplexity:
|
165 |
| - # TODO(#112): This is rough cost that ignores cliffords. |
166 |
| - return infra.TComplexity(t=4 * (self.x_bitsize + self.y_bitsize)) |
| 408 | + n = min(self.x_bitsize, self.y_bitsize) |
| 409 | + d = max(self.x_bitsize, self.y_bitsize) - n |
| 410 | + is_second_longer = self.y_bitsize > self.x_bitsize |
| 411 | + if d == 0: |
| 412 | + # When both registers are of the same size the T complexity is |
| 413 | + # 8n - 4 same as in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf. pylint: disable=line-too-long |
| 414 | + return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17) |
| 415 | + else: |
| 416 | + # When the registers differ in size and `n` is the size of the smaller one and |
| 417 | + # `d` is the difference in size. The T complexity is the sum of the tree |
| 418 | + # decomposition as before giving 8n + O(1) and the T complexity of an `And` gate |
| 419 | + # over `d` registers giving 4d + O(1) totaling 8n + 4d + O(1). |
| 420 | + # From the decomposition we get that the constant is -4 as well as the clifford counts. |
| 421 | + if d == 1: |
| 422 | + return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer) |
| 423 | + else: |
| 424 | + return infra.TComplexity( |
| 425 | + t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer |
| 426 | + ) |
| 427 | + |
| 428 | + def _has_unitary_(self): |
| 429 | + return True |
167 | 430 |
|
168 | 431 |
|
169 | 432 | @attr.frozen
|
|
0 commit comments