Skip to content

Implemented 8n T complexity decomposition of LessThanEqual gate #6156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cirq-ft/cirq_ft/algos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
ContiguousRegisterGate,
LessThanEqualGate,
LessThanGate,
SingleQubitCompare,
BiQubitsMixer,
)
from cirq_ft.algos.generic_select import GenericSelect
from cirq_ft.algos.hubbard_model import PrepareHubbard, SelectHubbard
Expand Down
271 changes: 267 additions & 4 deletions cirq-ft/cirq_ft/algos/arithmetic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Optional, Sequence, Tuple, Union
from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator

from cirq._compat import cached_property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: why can't we use functools.cached_property in cirq?

import attr
import cirq
from cirq_ft import infra
Expand Down Expand Up @@ -78,7 +79,7 @@ def _decompose_with_context_(
return
adjoint = []

[are_equal] = context.qubit_manager.qalloc(1)
(are_equal,) = context.qubit_manager.qalloc(1)

# Initially our belief is that the numbers are equal.
yield cirq.X(are_equal)
Expand Down Expand Up @@ -130,6 +131,147 @@ def _t_complexity_(self) -> infra.TComplexity:
)


@attr.frozen
class BiQubitsMixer(infra.GateWithRegisters):
"""Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf

This gates mixes the values in a way that preserves the result of comparison.
The registers being compared are 2-qubit registers where
x = 2*x_msb + x_lsb
y = 2*y_msb + y_lsb
The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb'
are the final values of x_lsb' and y_lsb'.
""" # pylint: disable=line-too-long

adjoint: bool = False

@cached_property
def registers(self) -> infra.Registers:
return infra.Registers.build(x=2, y=2, ancilla=3)

def __repr__(self) -> str:
return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
) -> cirq.OP_TREE:
x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla']
x_msb, x_lsb = x
y_msb, y_lsb = y

def _cswap(control: cirq.Qid, a: cirq.Qid, b: cirq.Qid, aux: cirq.Qid) -> cirq.OP_TREE:
"""A CSWAP with 4T complexity and whose adjoint has 0T complexity.

A controlled SWAP that swaps `a` and `b` based on `control`.
It uses an extra qubit `aux` so that its adjoint would have
a T complexity of zero.
"""
yield cirq.CNOT(a, b)
yield and_gate.And(adjoint=self.adjoint).on(control, b, aux)
yield cirq.CNOT(aux, a)
yield cirq.CNOT(a, b)

def _decomposition():
# computes the difference of x - y where
# x = 2*x_msb + x_lsb
# y = 2*y_msb + y_lsb
# And stores the result in x_lsb and y_lsb such that
# sign(x - y) = sign(x_lsb - y_lsb)
# This decomposition uses 3 ancilla qubits in order to have a
# T complexity of 8.
yield cirq.X(ancilla[0])
yield cirq.CNOT(y_msb, x_msb)
yield cirq.CNOT(y_lsb, x_lsb)
yield from _cswap(x_msb, x_lsb, ancilla[0], ancilla[1])
yield from _cswap(x_msb, y_msb, y_lsb, ancilla[2])
yield cirq.CNOT(y_lsb, x_lsb)

if self.adjoint:
yield from reversed(tuple(cirq.flatten_to_ops(_decomposition())))
else:
yield from _decomposition()

def __pow__(self, power: int) -> cirq.Gate:
if power == 1:
return self
if power == -1:
return BiQubitsMixer(adjoint=not self.adjoint)
return NotImplemented # coverage: ignore

def _t_complexity_(self) -> infra.TComplexity:
if self.adjoint:
return infra.TComplexity(clifford=18)
return infra.TComplexity(t=8, clifford=28)

def _has_unitary_(self):
return not self.adjoint


@attr.frozen
class SingleQubitCompare(infra.GateWithRegisters):
"""Applies U|a>|b>|0>|0> = |a> |a=b> |(a<b)> |(a>b)>

Source: (FIG. 3) in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
""" # pylint: disable=line-too-long

adjoint: bool = False

@cached_property
def registers(self) -> infra.Registers:
return infra.Registers.build(a=1, b=1, less_than=1, greater_than=1)

def __repr__(self) -> str:
return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
) -> cirq.OP_TREE:
a = quregs['a']
b = quregs['b']
less_than = quregs['less_than']
greater_than = quregs['greater_than']

def _decomposition() -> Iterator[cirq.Operation]:
yield and_gate.And((0, 1), adjoint=self.adjoint).on(*a, *b, *less_than)
yield cirq.CNOT(*less_than, *greater_than)
yield cirq.CNOT(*b, *greater_than)
yield cirq.CNOT(*a, *b)
yield cirq.CNOT(*a, *greater_than)
yield cirq.X(*b)

if self.adjoint:
yield from reversed(tuple(_decomposition()))
else:
yield from _decomposition()

def __pow__(self, power: int) -> cirq.Gate:
if not isinstance(power, int):
raise ValueError('SingleQubitCompare is only defined for integer powers.')
if power % 2 == 0:
return cirq.IdentityGate(4)
if power < 0:
return SingleQubitCompare(adjoint=not self.adjoint)
return self
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if power is 2*np.pi/3?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's only defined for integer powers.. raising an error otherwise


def _t_complexity_(self) -> infra.TComplexity:
if self.adjoint:
return infra.TComplexity(clifford=11)
return infra.TComplexity(t=4, clifford=16)


def _equality_with_zero(
context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid
) -> cirq.OP_TREE:
if len(qubits) == 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

describe what this subcircuit is doing

(q,) = qubits
yield cirq.X(q)
yield cirq.CNOT(q, z)
return

ancilla = context.qubit_manager.qalloc(len(qubits) - 2)
yield and_gate.And(cv=[0] * len(qubits)).on(*qubits, *ancilla, z)


@attr.frozen
class LessThanEqualGate(cirq.ArithmeticGate):
"""Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>"""
Expand Down Expand Up @@ -161,9 +303,130 @@ def __pow__(self, power: int):
def __repr__(self) -> str:
return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})'

def _decompose_via_tree(
self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid]
) -> cirq.OP_TREE:
"""Returns comparison oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf

This decomposition follows the tree structure of (FIG. 2)
""" # pylint: disable=line-too-long
if len(X) == 1:
return
if len(X) == 2:
yield BiQubitsMixer().on_registers(x=X, y=Y, ancilla=context.qubit_manager.qalloc(3))
return

m = len(X) // 2
yield self._decompose_via_tree(context, X[:m], Y[:m])
yield self._decompose_via_tree(context, X[m:], Y[m:])
yield BiQubitsMixer().on_registers(
x=(X[m - 1], X[-1]), y=(Y[m - 1], Y[-1]), ancilla=context.qubit_manager.qalloc(3)
)

def _decompose_with_context_(
self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None
) -> cirq.OP_TREE:
"""Decomposes the gate in a T-complexity optimal way.

The construction can be broken in 4 parts:
1. In case of differing bitsizes then a multicontrol And Gate
- Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether
the extra prefix is equal to zero:
- result stored in: `prefix_equality` qubit.
2. The tree structure (FIG. 2) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
followed by a SingleQubitCompare to compute the result of comparison of
the suffixes of equal length:
- result stored in: `less_than` and `greater_than` with equality in qubits[-2]
3. The results from the previous two steps are combined to update the target qubit.
4. The adjoint of the previous operations is added to restore the input qubits
to their original state and clean the ancilla qubits.
""" # pylint: disable=line-too-long

if context is None:
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())

lhs, rhs, target = qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1]

n = min(len(lhs), len(rhs))

prefix_equality = None
adjoint: List[cirq.Operation] = []

# if one of the registers is longer than the other store equality with |0--0>
# into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T.
if len(lhs) != len(rhs):
(prefix_equality,) = context.qubit_manager.qalloc(1)
if len(lhs) > len(rhs):
for op in cirq.flatten_to_ops(
_equality_with_zero(context, lhs[:-n], prefix_equality)
):
yield op
adjoint.append(cirq.inverse(op))
else:
for op in cirq.flatten_to_ops(
_equality_with_zero(context, rhs[:-n], prefix_equality)
):
yield op
adjoint.append(cirq.inverse(op))

yield cirq.X(target), cirq.CNOT(prefix_equality, target)

# compare the remaing suffix of P and Q
lhs = lhs[-n:]
rhs = rhs[-n:]
for op in cirq.flatten_to_ops(self._decompose_via_tree(context, lhs, rhs)):
yield op
adjoint.append(cirq.inverse(op))

less_than, greater_than = context.qubit_manager.qalloc(2)
yield SingleQubitCompare().on_registers(
a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than
)
adjoint.append(
SingleQubitCompare(adjoint=True).on_registers(
a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than
)
)

if prefix_equality is None:
yield cirq.X(target)
yield cirq.CNOT(greater_than, target)
else:
(less_than_or_equal,) = context.qubit_manager.qalloc(1)
yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal)
adjoint.append(
and_gate.And([1, 0], adjoint=True).on(
prefix_equality, greater_than, less_than_or_equal
)
)

yield cirq.CNOT(less_than_or_equal, target)

yield from reversed(adjoint)

def _t_complexity_(self) -> infra.TComplexity:
# TODO(#112): This is rough cost that ignores cliffords.
return infra.TComplexity(t=4 * (self.x_bitsize + self.y_bitsize))
n = min(self.x_bitsize, self.y_bitsize)
d = max(self.x_bitsize, self.y_bitsize) - n
is_second_longer = self.y_bitsize > self.x_bitsize
if d == 0:
# When both registers are of the same size the T complexity is
# 8n - 4 same as in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf. pylint: disable=line-too-long
return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17)
else:
# When the registers differ in size and `n` is the size of the smaller one and
# `d` is the difference in size. The T complexity is the sum of the tree
# decomposition as before giving 8n + O(1) and the T complexity of an `And` gate
# over `d` registers giving 4d + O(1) totaling 8n + 4d + O(1).
# From the decomposition we get that the constant is -4 as well as the clifford counts.
if d == 1:
return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer)
else:
return infra.TComplexity(
t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer
)

def _has_unitary_(self):
return True
Comment on lines +428 to +429
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to override this? shouldn't it figure it out from the fact that we have a decomposition?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bump

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the simulation tests fail if we don't. the simulators check that what it simulates is unitary with some measurements.

the cirq.has_unitary now fails on the decomposition because the decomposition uses non-unitary operations (all the And(adjoint=True)) but we don't have support for that check now except by simulation



@attr.frozen
Expand Down
Loading