Skip to content

Implemented 8n T complexity decomposition of LessThanEqual gate #6156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 159 additions & 3 deletions cirq-ft/cirq_ft/algos/arithmetic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Optional, Sequence, Tuple, Union
from typing import Iterable, Optional, Sequence, Tuple, Union, List

import attr
import cirq
Expand Down Expand Up @@ -130,6 +130,67 @@ def _t_complexity_(self) -> infra.TComplexity:
)


def mix_double_qubit_registers(
context: cirq.DecompositionContext,
msb: Tuple[cirq.Qid, cirq.Qid],
lsb: Tuple[cirq.Qid, cirq.Qid],
) -> cirq.OP_TREE:
"""Generates the COMPARE2 circuit from https://www.nature.com/articles/s41534-018-0071-5#Sec8"""
[ancilla] = context.qubit_manager.qalloc(1)
x_1, y_1 = msb
x_0, y_0 = lsb

def _cswap(c: cirq.Qid, a: cirq.Qid, b: cirq.Qid) -> cirq.OP_TREE:
[q] = context.qubit_manager.qalloc(1)
yield cirq.CNOT(a, b)
yield and_gate.And().on(c, b, q)
yield cirq.CNOT(q, a)
yield cirq.CNOT(a, b)

yield cirq.X(ancilla)
yield cirq.CNOT(y_1, x_1)
yield cirq.CNOT(y_0, x_0)
yield _cswap(x_1, x_0, ancilla)
yield _cswap(x_1, y_1, y_0)
yield cirq.CNOT(y_0, x_0)


def compare_qubits(
x: cirq.Qid, y: cirq.Qid, less_than: cirq.Qid, greater_than: cirq.Qid
) -> cirq.OP_TREE:
"""Generates the comparison circuit from https://www.nature.com/articles/s41534-018-0071-5#Sec8

Args:
x: first qubit of the comparison and stays the same after circuit execution.
y: second qubit of the comparison.
This qubit will store equality value `x==y` after circuit execution.
less_than: Assumed to be in zero state. Will store `x < y`.
greater_than: Assumed to be in zero state. Will store `x > y`.

Returns:
The circuit in (Fig. 3) in https://www.nature.com/articles/s41534-018-0071-5#Sec8
"""

yield and_gate.And([0, 1]).on(x, y, less_than)
yield cirq.CNOT(less_than, greater_than)
yield cirq.CNOT(y, greater_than)
yield cirq.CNOT(x, y)
yield cirq.CNOT(x, greater_than)
yield cirq.X(y)


def _equality_with_zero(
context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], z: cirq.Qid
) -> cirq.OP_TREE:
if len(X) == 1:
yield cirq.X(X[0])
yield cirq.CNOT(X[0], z)
return

ancilla = context.qubit_manager.qalloc(len(X) - 2)
yield and_gate.And(cv=[0] * len(X)).on(*X, *ancilla, z)


@attr.frozen
class LessThanEqualGate(cirq.ArithmeticGate):
"""Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>"""
Expand Down Expand Up @@ -161,9 +222,104 @@ def __pow__(self, power: int):
def __repr__(self) -> str:
return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})'

def _decompose_via_tree(
self, context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], Y: Tuple[cirq.Qid, ...]
) -> cirq.OP_TREE:
"""Returns comparison oracle from https://www.nature.com/articles/s41534-018-0071-5#Sec8"""
if len(X) == 1:
return
if len(X) == 2:
yield mix_double_qubit_registers(context, (X[0], Y[0]), (X[1], Y[1]))
return

m = len(X) // 2
yield self._decompose_via_tree(context, X[:m], Y[:m])
yield self._decompose_via_tree(context, X[m:], Y[m:])
yield mix_double_qubit_registers(context, (X[m - 1], Y[m - 1]), (X[-1], Y[-1]))

def _decompose_with_context_(
self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None
) -> cirq.OP_TREE:
if context is None:
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())

P, Q, target = (qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use pep8 variable names here

why are we not using GateWithRegisters?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this gate was created as a cirq.ArithmeticGate, should I change that?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be done later, but I think would work nicely here since you manually have to segregate the qubits into registers during decomposte at present.


n = min(len(P), len(Q))

prefix_equality = None
adjoint: List[cirq.Operation] = []

# if one of the registers is longer than the other store equality with |0--0>
# into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T.
if len(P) != len(Q):
[prefix_equality] = context.qubit_manager.qalloc(1)
if len(P) > len(Q):
decomposition = tuple(
cirq.flatten_to_ops(
_equality_with_zero(context, tuple(P[:-n]), prefix_equality)
)
)
yield from decomposition
adjoint.extend(cirq.inverse(op) for op in decomposition)
else:
decomposition = tuple(
cirq.flatten_to_ops(
_equality_with_zero(context, tuple(Q[:-n]), prefix_equality)
)
)
yield from decomposition
adjoint.extend(cirq.inverse(op) for op in decomposition)

yield cirq.X(target), cirq.CNOT(prefix_equality, target)

# compare the remaing suffix of P and Q
P = P[-n:]
Q = Q[-n:]
decomposition = tuple(
cirq.flatten_to_ops(self._decompose_via_tree(context, tuple(P), tuple(Q)))
)
adjoint.extend(cirq.inverse(op) for op in decomposition)
yield from decomposition

less_than, greater_than = context.qubit_manager.qalloc(2)
decomposition = tuple(
cirq.flatten_to_ops(compare_qubits(P[-1], Q[-1], less_than, greater_than))
)
adjoint.extend(cirq.inverse(op) for op in decomposition)
yield from decomposition

if prefix_equality is None:
yield cirq.X(target)
yield cirq.CNOT(greater_than, target)
else:
[less_than_or_equal] = context.qubit_manager.qalloc(1)
yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal)
adjoint.append(
and_gate.And([1, 0], adjoint=True).on(
prefix_equality, greater_than, less_than_or_equal
)
)

yield cirq.CNOT(less_than_or_equal, target)

yield from reversed(adjoint)

def _t_complexity_(self) -> infra.TComplexity:
# TODO(#112): This is rough cost that ignores cliffords.
return infra.TComplexity(t=4 * (self.x_bitsize + self.y_bitsize))
n = min(self.x_bitsize, self.y_bitsize)
d = max(self.x_bitsize, self.y_bitsize) - n
is_second_longer = self.y_bitsize > self.x_bitsize
if d == 0:
return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17)
elif d == 1:
return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer)
else:
return infra.TComplexity(
t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please provide a short reference to where these numbers come from. Are they in the paper? Or is it based on the decomposition?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bump

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


def _has_unitary_(self):
return True
Comment on lines +428 to +429
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to override this? shouldn't it figure it out from the fact that we have a decomposition?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bump

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the simulation tests fail if we don't. the simulators check that what it simulates is unitary with some measurements.

the cirq.has_unitary now fails on the decomposition because the decomposition uses non-unitary operations (all the And(adjoint=True)) but we don't have support for that check now except by simulation



@attr.frozen
Expand Down
19 changes: 17 additions & 2 deletions cirq-ft/cirq_ft/algos/arithmetic_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def test_multi_in_less_equal_than_gate():
cirq.testing.assert_equivalent_computational_basis_map(identity_map(len(qubits)), circuit)


@pytest.mark.parametrize("x_bitsize", [2, 3])
@pytest.mark.parametrize("y_bitsize", [2, 3])
@pytest.mark.parametrize("x_bitsize", [*range(1, 5)])
@pytest.mark.parametrize("y_bitsize", [*range(1, 5)])
def test_less_than_equal_consistent_protocols(x_bitsize: int, y_bitsize: int):
g = cirq_ft.LessThanEqualGate(x_bitsize, y_bitsize)
cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g)
Expand Down Expand Up @@ -322,3 +322,18 @@ def test_add_no_decompose(a, b):
assert true_out_int == int(out_bin, 2)
basis_map[input_int] = output_int
cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit)


@pytest.mark.parametrize("P,n", [(v, n) for n in range(1, 4) for v in range(1 << n)])
@pytest.mark.parametrize("Q,m", [(v, n) for n in range(1, 4) for v in range(1 << n)])
def test_decompose_less_than_equal_gate(P: int, n: int, Q: int, m: int):
qubit_states = list(bit_tools.iter_bits(P, n)) + list(bit_tools.iter_bits(Q, m))
circuit = cirq.Circuit(
cirq.decompose_once(cirq_ft.LessThanEqualGate(n, m).on(*cirq.LineQubit.range(n + m + 1)))
)
num_ancillas = len(circuit.all_qubits()) - n - m - 1
initial_state = [0] * num_ancillas + qubit_states + [0]
output_state = [0] * num_ancillas + qubit_states + [int(P <= Q)]
cirq_ft.testing.assert_circuit_inp_out_cirqsim(
circuit, sorted(circuit.all_qubits()), initial_state, output_state
)
13 changes: 8 additions & 5 deletions cirq-ft/cirq_ft/algos/state_preparation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,25 @@ def construct_gate_helper_and_qubit_order(data, eps):
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())

def map_func(op: cirq.Operation, _):
gateset = cirq.Gateset(cirq_ft.And)
gateset = cirq.Gateset(cirq_ft.And, cirq_ft.LessThanEqualGate, cirq_ft.LessThanGate)
return cirq.Circuit(
cirq.decompose(op, on_stuck_raise=None, keep=gateset.validate, context=context)
)

# TODO: Do not decompose `cq.And` because the `cq.map_clean_and_borrowable_qubits` currently
# gets confused and is not able to re-map qubits optimally; which results in a higher number
# of ancillas and thus the tests fails due to OOO.
# TODO: Do not decompose {cq.And, cq.LessThanEqualGate, cq.LessThanGate} because the
# `cq.map_clean_and_borrowable_qubits` currently gets confused and is not able to re-map qubits
# optimally; which results in a higher number of ancillas thus the tests fails due to OOO.
Comment on lines +38 to +40
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there an issue open for this? we definitely should have one.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bump

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created #6197

decomposed_circuit = cirq.map_operations_and_unroll(
g.circuit, map_func, raise_if_add_qubits=False
)
greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", size=25, maximize_reuse=True)
decomposed_circuit = cirq_ft.map_clean_and_borrowable_qubits(decomposed_circuit, qm=greedy_mm)
# We are fine decomposing the `cq.And` gates once the qubit re-mapping is complete. Ideally,
# we shouldn't require this two step process.
decomposed_circuit = cirq.Circuit(cirq.decompose(decomposed_circuit))
arithmetic_gateset = cirq.Gateset(cirq_ft.LessThanEqualGate, cirq_ft.LessThanGate)
decomposed_circuit = cirq.Circuit(
cirq.decompose(decomposed_circuit, keep=arithmetic_gateset.validate, on_stuck_raise=None)
)
ordered_input = list(itertools.chain(*g.quregs.values()))
qubit_order = cirq.QubitOrder.explicit(ordered_input, fallback=cirq.QubitOrder.DEFAULT)
return g, qubit_order, decomposed_circuit
Expand Down