Skip to content

Commit c93224e

Browse files
Implemented 8n T complexity decomposition of LessThanEqual gate (#6156)
This is the Comparison Oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
1 parent cb05a69 commit c93224e

File tree

4 files changed

+374
-12
lines changed

4 files changed

+374
-12
lines changed

cirq-ft/cirq_ft/algos/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
ContiguousRegisterGate,
2121
LessThanEqualGate,
2222
LessThanGate,
23+
SingleQubitCompare,
24+
BiQubitsMixer,
2325
)
2426
from cirq_ft.algos.generic_select import GenericSelect
2527
from cirq_ft.algos.hubbard_model import PrepareHubbard, SelectHubbard

cirq-ft/cirq_ft/algos/arithmetic_gates.py

+267-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Iterable, Optional, Sequence, Tuple, Union
15+
from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator
1616

17+
from cirq._compat import cached_property
1718
import attr
1819
import cirq
1920
from cirq_ft import infra
@@ -78,7 +79,7 @@ def _decompose_with_context_(
7879
return
7980
adjoint = []
8081

81-
[are_equal] = context.qubit_manager.qalloc(1)
82+
(are_equal,) = context.qubit_manager.qalloc(1)
8283

8384
# Initially our belief is that the numbers are equal.
8485
yield cirq.X(are_equal)
@@ -130,6 +131,147 @@ def _t_complexity_(self) -> infra.TComplexity:
130131
)
131132

132133

134+
@attr.frozen
135+
class BiQubitsMixer(infra.GateWithRegisters):
136+
"""Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
137+
138+
This gates mixes the values in a way that preserves the result of comparison.
139+
The registers being compared are 2-qubit registers where
140+
x = 2*x_msb + x_lsb
141+
y = 2*y_msb + y_lsb
142+
The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb'
143+
are the final values of x_lsb' and y_lsb'.
144+
""" # pylint: disable=line-too-long
145+
146+
adjoint: bool = False
147+
148+
@cached_property
149+
def registers(self) -> infra.Registers:
150+
return infra.Registers.build(x=2, y=2, ancilla=3)
151+
152+
def __repr__(self) -> str:
153+
return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})'
154+
155+
def decompose_from_registers(
156+
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
157+
) -> cirq.OP_TREE:
158+
x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla']
159+
x_msb, x_lsb = x
160+
y_msb, y_lsb = y
161+
162+
def _cswap(control: cirq.Qid, a: cirq.Qid, b: cirq.Qid, aux: cirq.Qid) -> cirq.OP_TREE:
163+
"""A CSWAP with 4T complexity and whose adjoint has 0T complexity.
164+
165+
A controlled SWAP that swaps `a` and `b` based on `control`.
166+
It uses an extra qubit `aux` so that its adjoint would have
167+
a T complexity of zero.
168+
"""
169+
yield cirq.CNOT(a, b)
170+
yield and_gate.And(adjoint=self.adjoint).on(control, b, aux)
171+
yield cirq.CNOT(aux, a)
172+
yield cirq.CNOT(a, b)
173+
174+
def _decomposition():
175+
# computes the difference of x - y where
176+
# x = 2*x_msb + x_lsb
177+
# y = 2*y_msb + y_lsb
178+
# And stores the result in x_lsb and y_lsb such that
179+
# sign(x - y) = sign(x_lsb - y_lsb)
180+
# This decomposition uses 3 ancilla qubits in order to have a
181+
# T complexity of 8.
182+
yield cirq.X(ancilla[0])
183+
yield cirq.CNOT(y_msb, x_msb)
184+
yield cirq.CNOT(y_lsb, x_lsb)
185+
yield from _cswap(x_msb, x_lsb, ancilla[0], ancilla[1])
186+
yield from _cswap(x_msb, y_msb, y_lsb, ancilla[2])
187+
yield cirq.CNOT(y_lsb, x_lsb)
188+
189+
if self.adjoint:
190+
yield from reversed(tuple(cirq.flatten_to_ops(_decomposition())))
191+
else:
192+
yield from _decomposition()
193+
194+
def __pow__(self, power: int) -> cirq.Gate:
195+
if power == 1:
196+
return self
197+
if power == -1:
198+
return BiQubitsMixer(adjoint=not self.adjoint)
199+
return NotImplemented # coverage: ignore
200+
201+
def _t_complexity_(self) -> infra.TComplexity:
202+
if self.adjoint:
203+
return infra.TComplexity(clifford=18)
204+
return infra.TComplexity(t=8, clifford=28)
205+
206+
def _has_unitary_(self):
207+
return not self.adjoint
208+
209+
210+
@attr.frozen
211+
class SingleQubitCompare(infra.GateWithRegisters):
212+
"""Applies U|a>|b>|0>|0> = |a> |a=b> |(a<b)> |(a>b)>
213+
214+
Source: (FIG. 3) in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
215+
""" # pylint: disable=line-too-long
216+
217+
adjoint: bool = False
218+
219+
@cached_property
220+
def registers(self) -> infra.Registers:
221+
return infra.Registers.build(a=1, b=1, less_than=1, greater_than=1)
222+
223+
def __repr__(self) -> str:
224+
return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})'
225+
226+
def decompose_from_registers(
227+
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
228+
) -> cirq.OP_TREE:
229+
a = quregs['a']
230+
b = quregs['b']
231+
less_than = quregs['less_than']
232+
greater_than = quregs['greater_than']
233+
234+
def _decomposition() -> Iterator[cirq.Operation]:
235+
yield and_gate.And((0, 1), adjoint=self.adjoint).on(*a, *b, *less_than)
236+
yield cirq.CNOT(*less_than, *greater_than)
237+
yield cirq.CNOT(*b, *greater_than)
238+
yield cirq.CNOT(*a, *b)
239+
yield cirq.CNOT(*a, *greater_than)
240+
yield cirq.X(*b)
241+
242+
if self.adjoint:
243+
yield from reversed(tuple(_decomposition()))
244+
else:
245+
yield from _decomposition()
246+
247+
def __pow__(self, power: int) -> cirq.Gate:
248+
if not isinstance(power, int):
249+
raise ValueError('SingleQubitCompare is only defined for integer powers.')
250+
if power % 2 == 0:
251+
return cirq.IdentityGate(4)
252+
if power < 0:
253+
return SingleQubitCompare(adjoint=not self.adjoint)
254+
return self
255+
256+
def _t_complexity_(self) -> infra.TComplexity:
257+
if self.adjoint:
258+
return infra.TComplexity(clifford=11)
259+
return infra.TComplexity(t=4, clifford=16)
260+
261+
262+
def _equality_with_zero(
263+
context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid
264+
) -> cirq.OP_TREE:
265+
if len(qubits) == 1:
266+
(q,) = qubits
267+
yield cirq.X(q)
268+
yield cirq.CNOT(q, z)
269+
return
270+
271+
ancilla = context.qubit_manager.qalloc(len(qubits) - 2)
272+
yield and_gate.And(cv=[0] * len(qubits)).on(*qubits, *ancilla, z)
273+
274+
133275
@attr.frozen
134276
class LessThanEqualGate(cirq.ArithmeticGate):
135277
"""Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>"""
@@ -161,9 +303,130 @@ def __pow__(self, power: int):
161303
def __repr__(self) -> str:
162304
return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})'
163305

306+
def _decompose_via_tree(
307+
self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid]
308+
) -> cirq.OP_TREE:
309+
"""Returns comparison oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
310+
311+
This decomposition follows the tree structure of (FIG. 2)
312+
""" # pylint: disable=line-too-long
313+
if len(X) == 1:
314+
return
315+
if len(X) == 2:
316+
yield BiQubitsMixer().on_registers(x=X, y=Y, ancilla=context.qubit_manager.qalloc(3))
317+
return
318+
319+
m = len(X) // 2
320+
yield self._decompose_via_tree(context, X[:m], Y[:m])
321+
yield self._decompose_via_tree(context, X[m:], Y[m:])
322+
yield BiQubitsMixer().on_registers(
323+
x=(X[m - 1], X[-1]), y=(Y[m - 1], Y[-1]), ancilla=context.qubit_manager.qalloc(3)
324+
)
325+
326+
def _decompose_with_context_(
327+
self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None
328+
) -> cirq.OP_TREE:
329+
"""Decomposes the gate in a T-complexity optimal way.
330+
331+
The construction can be broken in 4 parts:
332+
1. In case of differing bitsizes then a multicontrol And Gate
333+
- Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether
334+
the extra prefix is equal to zero:
335+
- result stored in: `prefix_equality` qubit.
336+
2. The tree structure (FIG. 2) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
337+
followed by a SingleQubitCompare to compute the result of comparison of
338+
the suffixes of equal length:
339+
- result stored in: `less_than` and `greater_than` with equality in qubits[-2]
340+
3. The results from the previous two steps are combined to update the target qubit.
341+
4. The adjoint of the previous operations is added to restore the input qubits
342+
to their original state and clean the ancilla qubits.
343+
""" # pylint: disable=line-too-long
344+
345+
if context is None:
346+
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())
347+
348+
lhs, rhs, target = qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1]
349+
350+
n = min(len(lhs), len(rhs))
351+
352+
prefix_equality = None
353+
adjoint: List[cirq.Operation] = []
354+
355+
# if one of the registers is longer than the other store equality with |0--0>
356+
# into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T.
357+
if len(lhs) != len(rhs):
358+
(prefix_equality,) = context.qubit_manager.qalloc(1)
359+
if len(lhs) > len(rhs):
360+
for op in cirq.flatten_to_ops(
361+
_equality_with_zero(context, lhs[:-n], prefix_equality)
362+
):
363+
yield op
364+
adjoint.append(cirq.inverse(op))
365+
else:
366+
for op in cirq.flatten_to_ops(
367+
_equality_with_zero(context, rhs[:-n], prefix_equality)
368+
):
369+
yield op
370+
adjoint.append(cirq.inverse(op))
371+
372+
yield cirq.X(target), cirq.CNOT(prefix_equality, target)
373+
374+
# compare the remaing suffix of P and Q
375+
lhs = lhs[-n:]
376+
rhs = rhs[-n:]
377+
for op in cirq.flatten_to_ops(self._decompose_via_tree(context, lhs, rhs)):
378+
yield op
379+
adjoint.append(cirq.inverse(op))
380+
381+
less_than, greater_than = context.qubit_manager.qalloc(2)
382+
yield SingleQubitCompare().on_registers(
383+
a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than
384+
)
385+
adjoint.append(
386+
SingleQubitCompare(adjoint=True).on_registers(
387+
a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than
388+
)
389+
)
390+
391+
if prefix_equality is None:
392+
yield cirq.X(target)
393+
yield cirq.CNOT(greater_than, target)
394+
else:
395+
(less_than_or_equal,) = context.qubit_manager.qalloc(1)
396+
yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal)
397+
adjoint.append(
398+
and_gate.And([1, 0], adjoint=True).on(
399+
prefix_equality, greater_than, less_than_or_equal
400+
)
401+
)
402+
403+
yield cirq.CNOT(less_than_or_equal, target)
404+
405+
yield from reversed(adjoint)
406+
164407
def _t_complexity_(self) -> infra.TComplexity:
165-
# TODO(#112): This is rough cost that ignores cliffords.
166-
return infra.TComplexity(t=4 * (self.x_bitsize + self.y_bitsize))
408+
n = min(self.x_bitsize, self.y_bitsize)
409+
d = max(self.x_bitsize, self.y_bitsize) - n
410+
is_second_longer = self.y_bitsize > self.x_bitsize
411+
if d == 0:
412+
# When both registers are of the same size the T complexity is
413+
# 8n - 4 same as in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf. pylint: disable=line-too-long
414+
return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17)
415+
else:
416+
# When the registers differ in size and `n` is the size of the smaller one and
417+
# `d` is the difference in size. The T complexity is the sum of the tree
418+
# decomposition as before giving 8n + O(1) and the T complexity of an `And` gate
419+
# over `d` registers giving 4d + O(1) totaling 8n + 4d + O(1).
420+
# From the decomposition we get that the constant is -4 as well as the clifford counts.
421+
if d == 1:
422+
return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer)
423+
else:
424+
return infra.TComplexity(
425+
t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer
426+
)
427+
428+
def _has_unitary_(self):
429+
return True
167430

168431

169432
@attr.frozen

0 commit comments

Comments
 (0)