-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Implemented 8n T complexity decomposition of LessThanEqual gate #6156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
50cf43c
c7a8ab5
a4a6f51
d955168
5b9b08e
c698264
e88d9cb
4e12089
d766bfa
c9c213d
e778f58
f641013
bced4fb
916ac5f
07f4277
accf212
3d51b3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,8 +12,9 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Iterable, Optional, Sequence, Tuple, Union | ||
from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator | ||
|
||
from cirq._compat import cached_property | ||
import attr | ||
import cirq | ||
from cirq_ft import infra | ||
|
@@ -78,7 +79,7 @@ def _decompose_with_context_( | |
return | ||
adjoint = [] | ||
|
||
[are_equal] = context.qubit_manager.qalloc(1) | ||
(are_equal,) = context.qubit_manager.qalloc(1) | ||
|
||
# Initially our belief is that the numbers are equal. | ||
yield cirq.X(are_equal) | ||
|
@@ -130,6 +131,147 @@ def _t_complexity_(self) -> infra.TComplexity: | |
) | ||
|
||
|
||
@attr.frozen | ||
class BiQubitsMixer(infra.GateWithRegisters): | ||
"""Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf | ||
|
||
This gates mixes the values in a way that preserves the result of comparison. | ||
The registers being compared are 2-qubit registers where | ||
x = 2*x_msb + x_lsb | ||
y = 2*y_msb + y_lsb | ||
The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb' | ||
are the final values of x_lsb' and y_lsb'. | ||
""" # pylint: disable=line-too-long | ||
|
||
adjoint: bool = False | ||
|
||
@cached_property | ||
def registers(self) -> infra.Registers: | ||
return infra.Registers.build(x=2, y=2, ancilla=3) | ||
|
||
def __repr__(self) -> str: | ||
return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})' | ||
|
||
def decompose_from_registers( | ||
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] | ||
) -> cirq.OP_TREE: | ||
x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla'] | ||
NoureldinYosri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
x_msb, x_lsb = x | ||
y_msb, y_lsb = y | ||
|
||
def _cswap(control: cirq.Qid, a: cirq.Qid, b: cirq.Qid, aux: cirq.Qid) -> cirq.OP_TREE: | ||
"""A CSWAP with 4T complexity and whose adjoint has 0T complexity. | ||
|
||
A controlled SWAP that swaps `a` and `b` based on `control`. | ||
It uses an extra qubit `aux` so that its adjoint would have | ||
a T complexity of zero. | ||
""" | ||
yield cirq.CNOT(a, b) | ||
yield and_gate.And(adjoint=self.adjoint).on(control, b, aux) | ||
yield cirq.CNOT(aux, a) | ||
yield cirq.CNOT(a, b) | ||
|
||
def _decomposition(): | ||
NoureldinYosri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# computes the difference of x - y where | ||
# x = 2*x_msb + x_lsb | ||
# y = 2*y_msb + y_lsb | ||
# And stores the result in x_lsb and y_lsb such that | ||
# sign(x - y) = sign(x_lsb - y_lsb) | ||
# This decomposition uses 3 ancilla qubits in order to have a | ||
# T complexity of 8. | ||
yield cirq.X(ancilla[0]) | ||
yield cirq.CNOT(y_msb, x_msb) | ||
yield cirq.CNOT(y_lsb, x_lsb) | ||
yield from _cswap(x_msb, x_lsb, ancilla[0], ancilla[1]) | ||
yield from _cswap(x_msb, y_msb, y_lsb, ancilla[2]) | ||
yield cirq.CNOT(y_lsb, x_lsb) | ||
|
||
if self.adjoint: | ||
yield from reversed(tuple(cirq.flatten_to_ops(_decomposition()))) | ||
else: | ||
yield from _decomposition() | ||
|
||
def __pow__(self, power: int) -> cirq.Gate: | ||
if power == 1: | ||
return self | ||
if power == -1: | ||
return BiQubitsMixer(adjoint=not self.adjoint) | ||
return NotImplemented # coverage: ignore | ||
|
||
def _t_complexity_(self) -> infra.TComplexity: | ||
if self.adjoint: | ||
return infra.TComplexity(clifford=18) | ||
return infra.TComplexity(t=8, clifford=28) | ||
|
||
def _has_unitary_(self): | ||
return not self.adjoint | ||
|
||
|
||
@attr.frozen | ||
class SingleQubitCompare(infra.GateWithRegisters): | ||
"""Applies U|a>|b>|0>|0> = |a> |a=b> |(a<b)> |(a>b)> | ||
mpharrigan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Source: (FIG. 3) in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf | ||
""" # pylint: disable=line-too-long | ||
|
||
adjoint: bool = False | ||
|
||
@cached_property | ||
def registers(self) -> infra.Registers: | ||
return infra.Registers.build(a=1, b=1, less_than=1, greater_than=1) | ||
|
||
def __repr__(self) -> str: | ||
return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})' | ||
NoureldinYosri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def decompose_from_registers( | ||
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] | ||
) -> cirq.OP_TREE: | ||
a = quregs['a'] | ||
b = quregs['b'] | ||
less_than = quregs['less_than'] | ||
greater_than = quregs['greater_than'] | ||
|
||
def _decomposition() -> Iterator[cirq.Operation]: | ||
yield and_gate.And((0, 1), adjoint=self.adjoint).on(*a, *b, *less_than) | ||
yield cirq.CNOT(*less_than, *greater_than) | ||
yield cirq.CNOT(*b, *greater_than) | ||
yield cirq.CNOT(*a, *b) | ||
yield cirq.CNOT(*a, *greater_than) | ||
yield cirq.X(*b) | ||
|
||
if self.adjoint: | ||
yield from reversed(tuple(_decomposition())) | ||
else: | ||
yield from _decomposition() | ||
|
||
def __pow__(self, power: int) -> cirq.Gate: | ||
if not isinstance(power, int): | ||
raise ValueError('SingleQubitCompare is only defined for integer powers.') | ||
if power % 2 == 0: | ||
return cirq.IdentityGate(4) | ||
if power < 0: | ||
return SingleQubitCompare(adjoint=not self.adjoint) | ||
return self | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if power is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's only defined for integer powers.. raising an error otherwise |
||
|
||
def _t_complexity_(self) -> infra.TComplexity: | ||
if self.adjoint: | ||
return infra.TComplexity(clifford=11) | ||
return infra.TComplexity(t=4, clifford=16) | ||
|
||
|
||
def _equality_with_zero( | ||
context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid | ||
) -> cirq.OP_TREE: | ||
if len(qubits) == 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. describe what this subcircuit is doing |
||
(q,) = qubits | ||
yield cirq.X(q) | ||
yield cirq.CNOT(q, z) | ||
return | ||
|
||
ancilla = context.qubit_manager.qalloc(len(qubits) - 2) | ||
yield and_gate.And(cv=[0] * len(qubits)).on(*qubits, *ancilla, z) | ||
|
||
|
||
@attr.frozen | ||
class LessThanEqualGate(cirq.ArithmeticGate): | ||
"""Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>""" | ||
|
@@ -161,9 +303,130 @@ def __pow__(self, power: int): | |
def __repr__(self) -> str: | ||
return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})' | ||
|
||
def _decompose_via_tree( | ||
self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid] | ||
NoureldinYosri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> cirq.OP_TREE: | ||
"""Returns comparison oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf | ||
|
||
This decomposition follows the tree structure of (FIG. 2) | ||
""" # pylint: disable=line-too-long | ||
if len(X) == 1: | ||
return | ||
if len(X) == 2: | ||
yield BiQubitsMixer().on_registers(x=X, y=Y, ancilla=context.qubit_manager.qalloc(3)) | ||
return | ||
|
||
m = len(X) // 2 | ||
yield self._decompose_via_tree(context, X[:m], Y[:m]) | ||
yield self._decompose_via_tree(context, X[m:], Y[m:]) | ||
yield BiQubitsMixer().on_registers( | ||
x=(X[m - 1], X[-1]), y=(Y[m - 1], Y[-1]), ancilla=context.qubit_manager.qalloc(3) | ||
) | ||
|
||
def _decompose_with_context_( | ||
NoureldinYosri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None | ||
) -> cirq.OP_TREE: | ||
"""Decomposes the gate in a T-complexity optimal way. | ||
|
||
The construction can be broken in 4 parts: | ||
1. In case of differing bitsizes then a multicontrol And Gate | ||
- Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether | ||
the extra prefix is equal to zero: | ||
- result stored in: `prefix_equality` qubit. | ||
2. The tree structure (FIG. 2) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf | ||
followed by a SingleQubitCompare to compute the result of comparison of | ||
the suffixes of equal length: | ||
- result stored in: `less_than` and `greater_than` with equality in qubits[-2] | ||
3. The results from the previous two steps are combined to update the target qubit. | ||
4. The adjoint of the previous operations is added to restore the input qubits | ||
to their original state and clean the ancilla qubits. | ||
""" # pylint: disable=line-too-long | ||
|
||
if context is None: | ||
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) | ||
|
||
lhs, rhs, target = qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1] | ||
|
||
n = min(len(lhs), len(rhs)) | ||
|
||
prefix_equality = None | ||
adjoint: List[cirq.Operation] = [] | ||
|
||
# if one of the registers is longer than the other store equality with |0--0> | ||
# into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T. | ||
if len(lhs) != len(rhs): | ||
(prefix_equality,) = context.qubit_manager.qalloc(1) | ||
if len(lhs) > len(rhs): | ||
for op in cirq.flatten_to_ops( | ||
_equality_with_zero(context, lhs[:-n], prefix_equality) | ||
): | ||
yield op | ||
adjoint.append(cirq.inverse(op)) | ||
else: | ||
for op in cirq.flatten_to_ops( | ||
_equality_with_zero(context, rhs[:-n], prefix_equality) | ||
): | ||
yield op | ||
adjoint.append(cirq.inverse(op)) | ||
|
||
yield cirq.X(target), cirq.CNOT(prefix_equality, target) | ||
|
||
# compare the remaing suffix of P and Q | ||
lhs = lhs[-n:] | ||
rhs = rhs[-n:] | ||
for op in cirq.flatten_to_ops(self._decompose_via_tree(context, lhs, rhs)): | ||
yield op | ||
adjoint.append(cirq.inverse(op)) | ||
|
||
less_than, greater_than = context.qubit_manager.qalloc(2) | ||
yield SingleQubitCompare().on_registers( | ||
a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than | ||
) | ||
adjoint.append( | ||
SingleQubitCompare(adjoint=True).on_registers( | ||
a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than | ||
) | ||
) | ||
|
||
if prefix_equality is None: | ||
yield cirq.X(target) | ||
yield cirq.CNOT(greater_than, target) | ||
else: | ||
(less_than_or_equal,) = context.qubit_manager.qalloc(1) | ||
yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal) | ||
adjoint.append( | ||
and_gate.And([1, 0], adjoint=True).on( | ||
prefix_equality, greater_than, less_than_or_equal | ||
) | ||
) | ||
|
||
yield cirq.CNOT(less_than_or_equal, target) | ||
|
||
yield from reversed(adjoint) | ||
|
||
def _t_complexity_(self) -> infra.TComplexity: | ||
# TODO(#112): This is rough cost that ignores cliffords. | ||
return infra.TComplexity(t=4 * (self.x_bitsize + self.y_bitsize)) | ||
n = min(self.x_bitsize, self.y_bitsize) | ||
d = max(self.x_bitsize, self.y_bitsize) - n | ||
is_second_longer = self.y_bitsize > self.x_bitsize | ||
if d == 0: | ||
# When both registers are of the same size the T complexity is | ||
# 8n - 4 same as in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf. pylint: disable=line-too-long | ||
return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17) | ||
else: | ||
# When the registers differ in size and `n` is the size of the smaller one and | ||
# `d` is the difference in size. The T complexity is the sum of the tree | ||
# decomposition as before giving 8n + O(1) and the T complexity of an `And` gate | ||
# over `d` registers giving 4d + O(1) totaling 8n + 4d + O(1). | ||
# From the decomposition we get that the constant is -4 as well as the clifford counts. | ||
if d == 1: | ||
return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer) | ||
else: | ||
return infra.TComplexity( | ||
t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer | ||
) | ||
|
||
def _has_unitary_(self): | ||
return True | ||
Comment on lines
+428
to
+429
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need to override this? shouldn't it figure it out from the fact that we have a decomposition? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. bump There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the simulation tests fail if we don't. the simulators check that what it simulates is unitary with some measurements. the cirq.has_unitary now fails on the decomposition because the decomposition uses non-unitary operations (all the And(adjoint=True)) but we don't have support for that check now except by simulation |
||
|
||
|
||
@attr.frozen | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question: why can't we use
functools.cached_property
in cirq?