From 50cf43cf7b74b16b4f0ac9d44d4b1fbd2eb25cd9 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Thu, 22 Jun 2023 16:34:08 +0100 Subject: [PATCH 01/12] Implemented 8n T complexity decomposition of LessThanEqual gate --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 163 +++++++++++++++++- .../cirq_ft/algos/arithmetic_gates_test.py | 19 +- .../cirq_ft/algos/state_preparation_test.py | 13 +- 3 files changed, 185 insertions(+), 10 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index bd68a02023e..274733232b4 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Optional, Sequence, Tuple, Union +from typing import Iterable, Optional, Sequence, Tuple, Union, Iterator, cast import attr import cirq @@ -130,6 +130,53 @@ def _t_complexity_(self) -> infra.TComplexity: ) +def mix_double_qubit_registers( + context: cirq.DecompositionContext, x: Tuple[cirq.Qid, cirq.Qid], y: Tuple[cirq.Qid, cirq.Qid] +) -> cirq.OP_TREE: + """Generates the COMPARE2 circuit from https://www.nature.com/articles/s41534-018-0071-5#Sec8""" + [ancilla] = context.qubit_manager.qalloc(1) + x_1, x_0 = x + y_1, y_0 = y + + def _cswap(c: cirq.Qid, a: cirq.Qid, b: cirq.Qid) -> Iterator[cirq.Operation]: + [q] = context.qubit_manager.qalloc(1) + yield cirq.CNOT(a, b) + yield and_gate.And().on(c, b, q) + yield cirq.CNOT(q, a) + yield cirq.CNOT(a, b) + + yield cirq.X(ancilla) + yield cirq.CNOT(y_1, x_1) + yield cirq.CNOT(y_0, x_0) + yield from _cswap(x_1, x_0, ancilla) + yield from _cswap(x_1, y_1, y_0) + yield cirq.CNOT(y_0, x_0) + + +def compare_qubits( + x: cirq.Qid, y: cirq.Qid, less_than: cirq.Qid, greater_than: cirq.Qid +) -> cirq.OP_TREE: + """Generates the comparison circuit from https://www.nature.com/articles/s41534-018-0071-5#Sec8 + + Args: + x: first qubit of the comparison and stays the same after circuit execution. + y: second qubit of the comparison. + This qubit will store equality value `x==y` after circuit execution. + less_than: Assumed to be in zero state. Will store `x < y`. + greater_than: Assumed to be in zero state. Will store `x > y`. + + Returns: + The circuit in (Fig. 3) in https://www.nature.com/articles/s41534-018-0071-5#Sec8 + """ + + yield and_gate.And([0, 1]).on(x, y, less_than) + yield cirq.CNOT(less_than, greater_than) + yield cirq.CNOT(y, greater_than) + yield cirq.CNOT(x, y) + yield cirq.CNOT(x, greater_than) + yield cirq.X(y) + + @attr.frozen class LessThanEqualGate(cirq.ArithmeticGate): """Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>""" @@ -161,9 +208,119 @@ def __pow__(self, power: int): def __repr__(self) -> str: return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})' + def _decompose_via_tree( + self, context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], Y: Tuple[cirq.Qid, ...] + ) -> Tuple[Tuple[cirq.Operation, ...], Tuple[cirq.Qid, cirq.Qid]]: + """Returns comparison oracle from https://www.nature.com/articles/s41534-018-0071-5#Sec8""" + assert len(X) == len(Y), f'{len(X)=} != {len(Y)=}' + if len(X) == 1: + return (), (X[0], Y[0]) + if len(X) == 2: + X = cast(Tuple[cirq.Qid, cirq.Qid], X) + Y = cast(Tuple[cirq.Qid, cirq.Qid], Y) + return tuple(cirq.flatten_to_ops(mix_double_qubit_registers(context, X, Y))), ( + X[1], + Y[1], + ) + + m = len(X) // 2 + op_left, ql = self._decompose_via_tree(context, X[:m], Y[:m]) + op_right, qr = self._decompose_via_tree(context, X[m:], Y[m:]) + return op_left + op_right + tuple( + cirq.flatten_to_ops(mix_double_qubit_registers(context, ql, qr)) + ), (ql[1], qr[1]) + + def _decompose_with_context_( + self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None + ) -> cirq.OP_TREE: + if context is None: + context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) + + P, Q, target = (qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1]) + + n = min(len(P), len(Q)) + + equal_so_far = None + adjoint = [] + + # if one of the registers is longer than the other compute store equality value + # into `equal_so_far` using d = |len(P) - len(Q)| And operations => 4d T. + if abs(len(P) - len(Q)) == 1: + [equal_so_far] = context.qubit_manager.qalloc(1) + yield cirq.X(equal_so_far) + adjoint.append(cirq.X(equal_so_far)) + + if len(P) > len(Q): + yield cirq.CNOT(P[0], equal_so_far) + adjoint.append(cirq.CNOT(P[0], equal_so_far)) + else: + yield cirq.CNOT(Q[0], equal_so_far) + adjoint.append(cirq.CNOT(Q[0], equal_so_far)) + + yield cirq.CNOT(Q[0], target) + elif len(P) > len(Q): + [equal_so_far] = context.qubit_manager.qalloc(1) + + m = len(P) - n + ancilla = context.qubit_manager.qalloc(m - 2) + yield and_gate.And(cv=[0] * m).on(*P[:m], *ancilla, equal_so_far) + adjoint.append( + and_gate.And(cv=[0] * m, adjoint=True).on(*P[:m], *ancilla, equal_so_far) + ) + + elif len(P) < len(Q): + [equal_so_far] = context.qubit_manager.qalloc(1) + + m = len(Q) - n + ancilla = context.qubit_manager.qalloc(m - 2) + yield and_gate.And(cv=[0] * m)(*Q[:m], *ancilla, equal_so_far) + adjoint.append(and_gate.And(cv=[0] * m, adjoint=True)(*Q[:m], *ancilla, equal_so_far)) + + yield cirq.X(target), cirq.CNOT(equal_so_far, target) + + # compare the remaing suffix of P and Q + P = P[-n:] + Q = Q[-n:] + decomposition, (x, y) = self._decompose_via_tree(context, tuple(P), tuple(Q)) + yield from decomposition + adjoint.extend(cirq.inverse(op) for op in decomposition) + + less_than, greater_than = context.qubit_manager.qalloc(2) + decomposition = tuple(cirq.flatten_to_ops(compare_qubits(x, y, less_than, greater_than))) + yield from decomposition + adjoint.extend(cirq.inverse(op) for op in decomposition) + + if equal_so_far is None: + yield cirq.CNOT(greater_than, target) + yield cirq.X(target) + else: + [less_than_or_equal] = context.qubit_manager.qalloc(1) + yield and_gate.And([1, 0]).on(equal_so_far, greater_than, less_than_or_equal) + adjoint.append( + and_gate.And([1, 0], adjoint=True).on( + equal_so_far, greater_than, less_than_or_equal + ) + ) + + yield cirq.CNOT(less_than_or_equal, target) + + yield from reversed(adjoint) + def _t_complexity_(self) -> infra.TComplexity: - # TODO(#112): This is rough cost that ignores cliffords. - return infra.TComplexity(t=4 * (self.x_bitsize + self.y_bitsize)) + n = min(self.x_bitsize, self.y_bitsize) + d = max(self.x_bitsize, self.y_bitsize) - n + is_second_longer = self.y_bitsize > self.x_bitsize + if d == 0: + return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17) + elif d == 1: + return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + is_second_longer) + else: + return infra.TComplexity( + t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer + ) + + def _has_unitary_(self): + return True @attr.frozen diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index 204f3007502..bdcdd442054 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -112,8 +112,8 @@ def test_multi_in_less_equal_than_gate(): cirq.testing.assert_equivalent_computational_basis_map(identity_map(len(qubits)), circuit) -@pytest.mark.parametrize("x_bitsize", [2, 3]) -@pytest.mark.parametrize("y_bitsize", [2, 3]) +@pytest.mark.parametrize("x_bitsize", [*range(1, 5)]) +@pytest.mark.parametrize("y_bitsize", [*range(1, 5)]) def test_less_than_equal_consistent_protocols(x_bitsize: int, y_bitsize: int): g = cirq_ft.LessThanEqualGate(x_bitsize, y_bitsize) cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) @@ -322,3 +322,18 @@ def test_add_no_decompose(a, b): assert true_out_int == int(out_bin, 2) basis_map[input_int] = output_int cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit) + + +@pytest.mark.parametrize("P,n", [(v, n) for n in range(1, 3) for v in range(1 << n)]) +@pytest.mark.parametrize("Q,m", [(v, n) for n in range(1, 3) for v in range(1 << n)]) +def test_decompose_less_than_equal_gate(P: int, n: int, Q: int, m: int): + qubit_states = list(bit_tools.iter_bits(P, n)) + list(bit_tools.iter_bits(Q, m)) + circuit = cirq.Circuit( + cirq.decompose_once(cirq_ft.LessThanEqualGate(n, m).on(*cirq.LineQubit.range(n + m + 1))) + ) + num_ancillas = len(circuit.all_qubits()) - n - m - 1 + initial_state = [0] * num_ancillas + qubit_states + [0] + output_state = [0] * num_ancillas + qubit_states + [int(P <= Q)] + cirq_ft.testing.assert_circuit_inp_out_cirqsim( + circuit, sorted(circuit.all_qubits()), initial_state, output_state + ) diff --git a/cirq-ft/cirq_ft/algos/state_preparation_test.py b/cirq-ft/cirq_ft/algos/state_preparation_test.py index 01586a4b545..d1695a16424 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation_test.py +++ b/cirq-ft/cirq_ft/algos/state_preparation_test.py @@ -30,14 +30,14 @@ def construct_gate_helper_and_qubit_order(data, eps): context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) def map_func(op: cirq.Operation, _): - gateset = cirq.Gateset(cirq_ft.And) + gateset = cirq.Gateset(cirq_ft.And, cirq_ft.LessThanEqualGate, cirq_ft.LessThanGate) return cirq.Circuit( cirq.decompose(op, on_stuck_raise=None, keep=gateset.validate, context=context) ) - # TODO: Do not decompose `cq.And` because the `cq.map_clean_and_borrowable_qubits` currently - # gets confused and is not able to re-map qubits optimally; which results in a higher number - # of ancillas and thus the tests fails due to OOO. + # TODO: Do not decompose {cq.And, cq.LessThanEqualGate, cq.LessThanGate} because the + # `cq.map_clean_and_borrowable_qubits` currently gets confused and is not able to re-map qubits + # optimally; which results in a higher number of ancillas thus the tests fails due to OOO. decomposed_circuit = cirq.map_operations_and_unroll( g.circuit, map_func, raise_if_add_qubits=False ) @@ -45,7 +45,10 @@ def map_func(op: cirq.Operation, _): decomposed_circuit = cirq_ft.map_clean_and_borrowable_qubits(decomposed_circuit, qm=greedy_mm) # We are fine decomposing the `cq.And` gates once the qubit re-mapping is complete. Ideally, # we shouldn't require this two step process. - decomposed_circuit = cirq.Circuit(cirq.decompose(decomposed_circuit)) + arithmetic_gateset = cirq.Gateset(cirq_ft.LessThanEqualGate, cirq_ft.LessThanGate) + decomposed_circuit = cirq.Circuit( + cirq.decompose(decomposed_circuit, keep=arithmetic_gateset.validate, on_stuck_raise=None) + ) ordered_input = list(itertools.chain(*g.quregs.values())) qubit_order = cirq.QubitOrder.explicit(ordered_input, fallback=cirq.QubitOrder.DEFAULT) return g, qubit_order, decomposed_circuit From d955168c864683e951b67241b1d7380d3bc67ec8 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 23 Jun 2023 17:29:00 +0100 Subject: [PATCH 02/12] fixed bug in implementation --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 124 +++++++++--------- .../cirq_ft/algos/arithmetic_gates_test.py | 6 +- 2 files changed, 66 insertions(+), 64 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 274733232b4..100531abd9b 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Optional, Sequence, Tuple, Union, Iterator, cast +from typing import Iterable, Optional, Sequence, Tuple, Union, List import attr import cirq @@ -131,14 +131,16 @@ def _t_complexity_(self) -> infra.TComplexity: def mix_double_qubit_registers( - context: cirq.DecompositionContext, x: Tuple[cirq.Qid, cirq.Qid], y: Tuple[cirq.Qid, cirq.Qid] + context: cirq.DecompositionContext, + msb: Tuple[cirq.Qid, cirq.Qid], + lsb: Tuple[cirq.Qid, cirq.Qid], ) -> cirq.OP_TREE: """Generates the COMPARE2 circuit from https://www.nature.com/articles/s41534-018-0071-5#Sec8""" [ancilla] = context.qubit_manager.qalloc(1) - x_1, x_0 = x - y_1, y_0 = y + x_1, y_1 = msb + x_0, y_0 = lsb - def _cswap(c: cirq.Qid, a: cirq.Qid, b: cirq.Qid) -> Iterator[cirq.Operation]: + def _cswap(c: cirq.Qid, a: cirq.Qid, b: cirq.Qid) -> cirq.OP_TREE: [q] = context.qubit_manager.qalloc(1) yield cirq.CNOT(a, b) yield and_gate.And().on(c, b, q) @@ -148,8 +150,8 @@ def _cswap(c: cirq.Qid, a: cirq.Qid, b: cirq.Qid) -> Iterator[cirq.Operation]: yield cirq.X(ancilla) yield cirq.CNOT(y_1, x_1) yield cirq.CNOT(y_0, x_0) - yield from _cswap(x_1, x_0, ancilla) - yield from _cswap(x_1, y_1, y_0) + yield _cswap(x_1, x_0, ancilla) + yield _cswap(x_1, y_1, y_0) yield cirq.CNOT(y_0, x_0) @@ -177,6 +179,18 @@ def compare_qubits( yield cirq.X(y) +def _equality_with_zero( + context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], z: cirq.Qid +) -> cirq.OP_TREE: + if len(X) == 1: + yield cirq.X(X[0]) + yield cirq.CNOT(X[0], z) + return + + ancilla = context.qubit_manager.qalloc(len(X) - 2) + yield and_gate.And(cv=[0] * len(X)).on(*X, *ancilla, z) + + @attr.frozen class LessThanEqualGate(cirq.ArithmeticGate): """Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>""" @@ -210,25 +224,19 @@ def __repr__(self) -> str: def _decompose_via_tree( self, context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], Y: Tuple[cirq.Qid, ...] - ) -> Tuple[Tuple[cirq.Operation, ...], Tuple[cirq.Qid, cirq.Qid]]: + ) -> cirq.OP_TREE: """Returns comparison oracle from https://www.nature.com/articles/s41534-018-0071-5#Sec8""" assert len(X) == len(Y), f'{len(X)=} != {len(Y)=}' if len(X) == 1: - return (), (X[0], Y[0]) + return if len(X) == 2: - X = cast(Tuple[cirq.Qid, cirq.Qid], X) - Y = cast(Tuple[cirq.Qid, cirq.Qid], Y) - return tuple(cirq.flatten_to_ops(mix_double_qubit_registers(context, X, Y))), ( - X[1], - Y[1], - ) + yield mix_double_qubit_registers(context, (X[0], Y[0]), (X[1], Y[1])) + return m = len(X) // 2 - op_left, ql = self._decompose_via_tree(context, X[:m], Y[:m]) - op_right, qr = self._decompose_via_tree(context, X[m:], Y[m:]) - return op_left + op_right + tuple( - cirq.flatten_to_ops(mix_double_qubit_registers(context, ql, qr)) - ), (ql[1], qr[1]) + yield self._decompose_via_tree(context, X[:m], Y[:m]) + yield self._decompose_via_tree(context, X[m:], Y[m:]) + yield mix_double_qubit_registers(context, (X[m - 1], Y[m - 1]), (X[-1], Y[-1])) def _decompose_with_context_( self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None @@ -240,65 +248,57 @@ def _decompose_with_context_( n = min(len(P), len(Q)) - equal_so_far = None - adjoint = [] - - # if one of the registers is longer than the other compute store equality value - # into `equal_so_far` using d = |len(P) - len(Q)| And operations => 4d T. - if abs(len(P) - len(Q)) == 1: - [equal_so_far] = context.qubit_manager.qalloc(1) - yield cirq.X(equal_so_far) - adjoint.append(cirq.X(equal_so_far)) + prefix_equality = None + adjoint: List[cirq.Operation] = [] + # if one of the registers is longer than the other store equality with |0--0> + # into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T. + if len(P) != len(Q): + [prefix_equality] = context.qubit_manager.qalloc(1) if len(P) > len(Q): - yield cirq.CNOT(P[0], equal_so_far) - adjoint.append(cirq.CNOT(P[0], equal_so_far)) + decomposition = tuple( + cirq.flatten_to_ops( + _equality_with_zero(context, tuple(P[:-n]), prefix_equality) + ) + ) + yield from decomposition + adjoint.extend(cirq.inverse(op) for op in decomposition) else: - yield cirq.CNOT(Q[0], equal_so_far) - adjoint.append(cirq.CNOT(Q[0], equal_so_far)) - - yield cirq.CNOT(Q[0], target) - elif len(P) > len(Q): - [equal_so_far] = context.qubit_manager.qalloc(1) - - m = len(P) - n - ancilla = context.qubit_manager.qalloc(m - 2) - yield and_gate.And(cv=[0] * m).on(*P[:m], *ancilla, equal_so_far) - adjoint.append( - and_gate.And(cv=[0] * m, adjoint=True).on(*P[:m], *ancilla, equal_so_far) - ) - - elif len(P) < len(Q): - [equal_so_far] = context.qubit_manager.qalloc(1) - - m = len(Q) - n - ancilla = context.qubit_manager.qalloc(m - 2) - yield and_gate.And(cv=[0] * m)(*Q[:m], *ancilla, equal_so_far) - adjoint.append(and_gate.And(cv=[0] * m, adjoint=True)(*Q[:m], *ancilla, equal_so_far)) + decomposition = tuple( + cirq.flatten_to_ops( + _equality_with_zero(context, tuple(Q[:-n]), prefix_equality) + ) + ) + yield from decomposition + adjoint.extend(cirq.inverse(op) for op in decomposition) - yield cirq.X(target), cirq.CNOT(equal_so_far, target) + yield cirq.X(target), cirq.CNOT(prefix_equality, target) # compare the remaing suffix of P and Q P = P[-n:] Q = Q[-n:] - decomposition, (x, y) = self._decompose_via_tree(context, tuple(P), tuple(Q)) - yield from decomposition + decomposition = tuple( + cirq.flatten_to_ops(self._decompose_via_tree(context, tuple(P), tuple(Q))) + ) adjoint.extend(cirq.inverse(op) for op in decomposition) + yield from decomposition less_than, greater_than = context.qubit_manager.qalloc(2) - decomposition = tuple(cirq.flatten_to_ops(compare_qubits(x, y, less_than, greater_than))) - yield from decomposition + decomposition = tuple( + cirq.flatten_to_ops(compare_qubits(P[-1], Q[-1], less_than, greater_than)) + ) adjoint.extend(cirq.inverse(op) for op in decomposition) + yield from decomposition - if equal_so_far is None: - yield cirq.CNOT(greater_than, target) + if prefix_equality is None: yield cirq.X(target) + yield cirq.CNOT(greater_than, target) else: [less_than_or_equal] = context.qubit_manager.qalloc(1) - yield and_gate.And([1, 0]).on(equal_so_far, greater_than, less_than_or_equal) + yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal) adjoint.append( and_gate.And([1, 0], adjoint=True).on( - equal_so_far, greater_than, less_than_or_equal + prefix_equality, greater_than, less_than_or_equal ) ) @@ -313,7 +313,7 @@ def _t_complexity_(self) -> infra.TComplexity: if d == 0: return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17) elif d == 1: - return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + is_second_longer) + return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer) else: return infra.TComplexity( t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index bdcdd442054..9a3d9c2e1dd 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -324,8 +324,8 @@ def test_add_no_decompose(a, b): cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit) -@pytest.mark.parametrize("P,n", [(v, n) for n in range(1, 3) for v in range(1 << n)]) -@pytest.mark.parametrize("Q,m", [(v, n) for n in range(1, 3) for v in range(1 << n)]) +@pytest.mark.parametrize("P,n", [(v, n) for n in range(1, 4) for v in range(1 << n)]) +@pytest.mark.parametrize("Q,m", [(v, n) for n in range(1, 4) for v in range(1 << n)]) def test_decompose_less_than_equal_gate(P: int, n: int, Q: int, m: int): qubit_states = list(bit_tools.iter_bits(P, n)) + list(bit_tools.iter_bits(Q, m)) circuit = cirq.Circuit( @@ -334,6 +334,8 @@ def test_decompose_less_than_equal_gate(P: int, n: int, Q: int, m: int): num_ancillas = len(circuit.all_qubits()) - n - m - 1 initial_state = [0] * num_ancillas + qubit_states + [0] output_state = [0] * num_ancillas + qubit_states + [int(P <= Q)] + print(*sorted(circuit.all_qubits())) + print(output_state) cirq_ft.testing.assert_circuit_inp_out_cirqsim( circuit, sorted(circuit.all_qubits()), initial_state, output_state ) From e88d9cbcf245db0ecbca28cf2d52f70abc47bfa8 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Mon, 26 Jun 2023 18:20:06 +0100 Subject: [PATCH 03/12] be compatible with py3.7 --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 100531abd9b..177adc2f523 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -226,7 +226,7 @@ def _decompose_via_tree( self, context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], Y: Tuple[cirq.Qid, ...] ) -> cirq.OP_TREE: """Returns comparison oracle from https://www.nature.com/articles/s41534-018-0071-5#Sec8""" - assert len(X) == len(Y), f'{len(X)=} != {len(Y)=}' + assert len(X) == len(Y), '{} != {}'.format(len(X), len(Y)) if len(X) == 1: return if len(X) == 2: From 4e120892774013e31b933822061c314707b47e25 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Mon, 26 Jun 2023 18:21:25 +0100 Subject: [PATCH 04/12] remove debug prints --- cirq-ft/cirq_ft/algos/arithmetic_gates_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index 9a3d9c2e1dd..263e2570c68 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -334,8 +334,6 @@ def test_decompose_less_than_equal_gate(P: int, n: int, Q: int, m: int): num_ancillas = len(circuit.all_qubits()) - n - m - 1 initial_state = [0] * num_ancillas + qubit_states + [0] output_state = [0] * num_ancillas + qubit_states + [int(P <= Q)] - print(*sorted(circuit.all_qubits())) - print(output_state) cirq_ft.testing.assert_circuit_inp_out_cirqsim( circuit, sorted(circuit.all_qubits()), initial_state, output_state ) From d766bfac1254b2d17fdaa03de7f01ea58d585b15 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Mon, 26 Jun 2023 18:23:40 +0100 Subject: [PATCH 05/12] be compatible with py3.7 --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 177adc2f523..4ad83d8792c 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -226,7 +226,6 @@ def _decompose_via_tree( self, context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], Y: Tuple[cirq.Qid, ...] ) -> cirq.OP_TREE: """Returns comparison oracle from https://www.nature.com/articles/s41534-018-0071-5#Sec8""" - assert len(X) == len(Y), '{} != {}'.format(len(X), len(Y)) if len(X) == 1: return if len(X) == 2: From c9c213d4af1bff6a1f91cdb140aa3947644b0677 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Wed, 28 Jun 2023 16:25:38 +0100 Subject: [PATCH 06/12] addressed comments --- cirq-ft/cirq_ft/algos/__init__.py | 2 + cirq-ft/cirq_ft/algos/arithmetic_gates.py | 245 +++++++++++------- .../cirq_ft/algos/arithmetic_gates_test.py | 70 ++++- 3 files changed, 225 insertions(+), 92 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/__init__.py b/cirq-ft/cirq_ft/algos/__init__.py index 5acec04136d..b0cc284434d 100644 --- a/cirq-ft/cirq_ft/algos/__init__.py +++ b/cirq-ft/cirq_ft/algos/__init__.py @@ -20,6 +20,8 @@ ContiguousRegisterGate, LessThanEqualGate, LessThanGate, + SingleQubitCompare, + BiQubitsMixer, ) from cirq_ft.algos.generic_select import GenericSelect from cirq_ft.algos.hubbard_model import PrepareHubbard, SelectHubbard diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 4ad83d8792c..775436bbf82 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Optional, Sequence, Tuple, Union, List +from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator +from cirq._compat import cached_property import attr import cirq from cirq_ft import infra @@ -78,7 +79,7 @@ def _decompose_with_context_( return adjoint = [] - [are_equal] = context.qubit_manager.qalloc(1) + (are_equal,) = context.qubit_manager.qalloc(1) # Initially our belief is that the numbers are equal. yield cirq.X(are_equal) @@ -130,65 +131,133 @@ def _t_complexity_(self) -> infra.TComplexity: ) -def mix_double_qubit_registers( - context: cirq.DecompositionContext, - msb: Tuple[cirq.Qid, cirq.Qid], - lsb: Tuple[cirq.Qid, cirq.Qid], -) -> cirq.OP_TREE: - """Generates the COMPARE2 circuit from https://www.nature.com/articles/s41534-018-0071-5#Sec8""" - [ancilla] = context.qubit_manager.qalloc(1) - x_1, y_1 = msb - x_0, y_0 = lsb - - def _cswap(c: cirq.Qid, a: cirq.Qid, b: cirq.Qid) -> cirq.OP_TREE: - [q] = context.qubit_manager.qalloc(1) - yield cirq.CNOT(a, b) - yield and_gate.And().on(c, b, q) - yield cirq.CNOT(q, a) - yield cirq.CNOT(a, b) - - yield cirq.X(ancilla) - yield cirq.CNOT(y_1, x_1) - yield cirq.CNOT(y_0, x_0) - yield _cswap(x_1, x_0, ancilla) - yield _cswap(x_1, y_1, y_0) - yield cirq.CNOT(y_0, x_0) - - -def compare_qubits( - x: cirq.Qid, y: cirq.Qid, less_than: cirq.Qid, greater_than: cirq.Qid -) -> cirq.OP_TREE: - """Generates the comparison circuit from https://www.nature.com/articles/s41534-018-0071-5#Sec8 +@attr.frozen +class BiQubitsMixer(infra.GateWithRegisters): + """Implements the COMPARE2 (Fig. 1) https://www.nature.com/articles/s41534-018-0071-5#Sec8 + + This gates mixes the values in away that preserves the result of comparison. + The registers being compared are 2-qubit registers where + x = 2*x_msb + x_lsb + y = 2*y_msb + y_lsb + The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb' + are the final values of x_lsb' and y_lsb'. + """ - Args: - x: first qubit of the comparison and stays the same after circuit execution. - y: second qubit of the comparison. - This qubit will store equality value `x==y` after circuit execution. - less_than: Assumed to be in zero state. Will store `x < y`. - greater_than: Assumed to be in zero state. Will store `x > y`. - - Returns: - The circuit in (Fig. 3) in https://www.nature.com/articles/s41534-018-0071-5#Sec8 + adjoint: bool = False + + @cached_property + def registers(self) -> infra.Registers: + return infra.Registers.build(x=2, y=2, ancilla=3) + + def _has_unitary_(self): + return not self.adjoint + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + ) -> cirq.OP_TREE: + + x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla'] + x_msb, x_lsb = x + y_msb, y_lsb = y + + def _cswap(c: cirq.Qid, a: cirq.Qid, b: cirq.Qid, q: cirq.Qid) -> cirq.OP_TREE: + yield cirq.CNOT(a, b) + yield and_gate.And(adjoint=self.adjoint).on(c, b, q) + yield cirq.CNOT(q, a) + yield cirq.CNOT(a, b) + + def _decomposition(): + yield cirq.X(ancilla[0]) + yield cirq.CNOT(y_msb, x_msb) + yield cirq.CNOT(y_lsb, x_lsb) + yield from _cswap(x_msb, x_lsb, ancilla[0], ancilla[1]) + yield from _cswap(x_msb, y_msb, y_lsb, ancilla[2]) + yield cirq.CNOT(y_lsb, x_lsb) + + if self.adjoint: + yield from reversed(tuple(cirq.flatten_to_ops(_decomposition()))) + else: + yield from _decomposition() + + def __repr__(self) -> str: + return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})' + + def __pow__(self, power: int) -> cirq.Gate: + if power == 1: + return self + if power == -1: + return BiQubitsMixer(adjoint=True) + return NotImplemented # coverage: ignore + + def _t_complexity_(self) -> infra.TComplexity: + if self.adjoint: + return infra.TComplexity(clifford=18) + return infra.TComplexity(t=8, clifford=28) + + +@attr.frozen +class SingleQubitCompare(infra.GateWithRegisters): + """Applies U|a>|b>|0>|0> = |a> |a=b> |(a |(a>b)> + + Source: (FIG. 3) in https://www.nature.com/articles/s41534-018-0071-5#Sec8 """ - yield and_gate.And([0, 1]).on(x, y, less_than) - yield cirq.CNOT(less_than, greater_than) - yield cirq.CNOT(y, greater_than) - yield cirq.CNOT(x, y) - yield cirq.CNOT(x, greater_than) - yield cirq.X(y) + adjoint: bool = False + + @cached_property + def registers(self) -> infra.Registers: + return infra.Registers.build(a=1, b=1, less_than=1, greater_than=1) + + def __repr__(self) -> str: + return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})' + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + ) -> cirq.OP_TREE: + def _decomposition() -> Iterator[cirq.Operation]: + (a,), (b,), (less_than,), (greater_than,) = ( + quregs['a'], + quregs['b'], + quregs['less_than'], + quregs['greater_than'], + ) + yield and_gate.And([0, 1], adjoint=self.adjoint).on(a, b, less_than) + yield cirq.CNOT(less_than, greater_than) + yield cirq.CNOT(b, greater_than) + yield cirq.CNOT(a, b) + yield cirq.CNOT(a, greater_than) + yield cirq.X(b) + + if self.adjoint: + yield from reversed(tuple(_decomposition())) + else: + yield from _decomposition() + + def __pow__(self, power: int) -> cirq.Gate: + if power % 2 == 0: + return cirq.IdentityGate(4) + adjoint = power < 0 + if adjoint: + return SingleQubitCompare(adjoint=True) + return self + + def _t_complexity_(self) -> infra.TComplexity: + if self.adjoint: + return infra.TComplexity(clifford=11) + return infra.TComplexity(t=4, clifford=16) def _equality_with_zero( - context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], z: cirq.Qid + context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid ) -> cirq.OP_TREE: - if len(X) == 1: - yield cirq.X(X[0]) - yield cirq.CNOT(X[0], z) + if len(qubits) == 1: + (q,) = qubits + yield cirq.X(q) + yield cirq.CNOT(q, z) return - ancilla = context.qubit_manager.qalloc(len(X) - 2) - yield and_gate.And(cv=[0] * len(X)).on(*X, *ancilla, z) + ancilla = context.qubit_manager.qalloc(len(qubits) - 2) + yield and_gate.And(cv=[0] * len(qubits)).on(*qubits, *ancilla, z) @attr.frozen @@ -223,19 +292,24 @@ def __repr__(self) -> str: return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})' def _decompose_via_tree( - self, context: cirq.DecompositionContext, X: Tuple[cirq.Qid, ...], Y: Tuple[cirq.Qid, ...] + self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid] ) -> cirq.OP_TREE: - """Returns comparison oracle from https://www.nature.com/articles/s41534-018-0071-5#Sec8""" + """Returns comparison oracle from https://www.nature.com/articles/s41534-018-0071-5#Sec8 + + This decomposition follows the tree structure of (FIG. 2) + """ if len(X) == 1: return if len(X) == 2: - yield mix_double_qubit_registers(context, (X[0], Y[0]), (X[1], Y[1])) + yield BiQubitsMixer().on_registers(x=X, y=Y, ancilla=context.qubit_manager.qalloc(3)) return m = len(X) // 2 yield self._decompose_via_tree(context, X[:m], Y[:m]) yield self._decompose_via_tree(context, X[m:], Y[m:]) - yield mix_double_qubit_registers(context, (X[m - 1], Y[m - 1]), (X[-1], Y[-1])) + yield BiQubitsMixer().on_registers( + x=(X[m - 1], X[-1]), y=(Y[m - 1], Y[-1]), ancilla=context.qubit_manager.qalloc(3) + ) def _decompose_with_context_( self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None @@ -243,57 +317,54 @@ def _decompose_with_context_( if context is None: context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) - P, Q, target = (qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1]) + lhs, rhs, target = qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1] - n = min(len(P), len(Q)) + n = min(len(lhs), len(rhs)) prefix_equality = None adjoint: List[cirq.Operation] = [] # if one of the registers is longer than the other store equality with |0--0> # into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T. - if len(P) != len(Q): - [prefix_equality] = context.qubit_manager.qalloc(1) - if len(P) > len(Q): - decomposition = tuple( - cirq.flatten_to_ops( - _equality_with_zero(context, tuple(P[:-n]), prefix_equality) - ) - ) - yield from decomposition - adjoint.extend(cirq.inverse(op) for op in decomposition) + if len(lhs) != len(rhs): + (prefix_equality,) = context.qubit_manager.qalloc(1) + if len(lhs) > len(rhs): + for op in cirq.flatten_to_ops( + _equality_with_zero(context, lhs[:-n], prefix_equality) + ): + yield op + adjoint.append(cirq.inverse(op)) else: - decomposition = tuple( - cirq.flatten_to_ops( - _equality_with_zero(context, tuple(Q[:-n]), prefix_equality) - ) - ) - yield from decomposition - adjoint.extend(cirq.inverse(op) for op in decomposition) + for op in cirq.flatten_to_ops( + _equality_with_zero(context, rhs[:-n], prefix_equality) + ): + yield op + adjoint.append(cirq.inverse(op)) yield cirq.X(target), cirq.CNOT(prefix_equality, target) # compare the remaing suffix of P and Q - P = P[-n:] - Q = Q[-n:] - decomposition = tuple( - cirq.flatten_to_ops(self._decompose_via_tree(context, tuple(P), tuple(Q))) - ) - adjoint.extend(cirq.inverse(op) for op in decomposition) - yield from decomposition + lhs = lhs[-n:] + rhs = rhs[-n:] + for op in cirq.flatten_to_ops(self._decompose_via_tree(context, lhs, rhs)): + yield op + adjoint.append(cirq.inverse(op)) less_than, greater_than = context.qubit_manager.qalloc(2) - decomposition = tuple( - cirq.flatten_to_ops(compare_qubits(P[-1], Q[-1], less_than, greater_than)) + yield SingleQubitCompare().on_registers( + a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than + ) + adjoint.append( + SingleQubitCompare(adjoint=True).on_registers( + a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than + ) ) - adjoint.extend(cirq.inverse(op) for op in decomposition) - yield from decomposition if prefix_equality is None: yield cirq.X(target) yield cirq.CNOT(greater_than, target) else: - [less_than_or_equal] = context.qubit_manager.qalloc(1) + (less_than_or_equal,) = context.qubit_manager.qalloc(1) yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal) adjoint.append( and_gate.And([1, 0], adjoint=True).on( diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index 263e2570c68..9e500a3d7da 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -18,7 +18,7 @@ import cirq_ft import numpy as np import pytest -from cirq_ft.infra import bit_tools +from cirq_ft.infra import bit_tools, GreedyQubitManager def identity_map(n: int): @@ -329,11 +329,71 @@ def test_add_no_decompose(a, b): def test_decompose_less_than_equal_gate(P: int, n: int, Q: int, m: int): qubit_states = list(bit_tools.iter_bits(P, n)) + list(bit_tools.iter_bits(Q, m)) circuit = cirq.Circuit( - cirq.decompose_once(cirq_ft.LessThanEqualGate(n, m).on(*cirq.LineQubit.range(n + m + 1))) + cirq.decompose_once( + cirq_ft.LessThanEqualGate(n, m).on(*cirq.LineQubit.range(n + m + 1)), + context=cirq.DecompositionContext(GreedyQubitManager(prefix='_c')), + ) ) + qubit_order = tuple(sorted(circuit.all_qubits())) num_ancillas = len(circuit.all_qubits()) - n - m - 1 - initial_state = [0] * num_ancillas + qubit_states + [0] - output_state = [0] * num_ancillas + qubit_states + [int(P <= Q)] + initial_state = qubit_states + [0] + [0] * num_ancillas + output_state = qubit_states + [int(P <= Q)] + [0] * num_ancillas cirq_ft.testing.assert_circuit_inp_out_cirqsim( - circuit, sorted(circuit.all_qubits()), initial_state, output_state + circuit, qubit_order, initial_state, output_state + ) + + +@pytest.mark.parametrize("adjoint", [False, True]) +def test_single_qubit_compare_protocols(adjoint: bool): + g = cirq_ft.algos.SingleQubitCompare(adjoint=adjoint) + cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) + cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') + + +@pytest.mark.parametrize("v1,v2", [(v1, v2) for v1 in range(2) for v2 in range(2)]) +def test_single_qubit_compare(v1: int, v2: int): + g = cirq_ft.algos.SingleQubitCompare() + qubits = cirq.LineQid.range(4, dimension=2) + c = cirq.Circuit(g.on(*qubits)) + initial_state = [v1, v2, 0, 0] + output_state = [v1, int(v1 == v2), int(v1 < v2), int(v1 > v2)] + cirq_ft.testing.assert_circuit_inp_out_cirqsim( + c, sorted(c.all_qubits()), initial_state, output_state + ) + + # Check that g**-1 restores the qubits to their original state + c = cirq.Circuit(g.on(*qubits), (g**-1).on(*qubits)) + cirq_ft.testing.assert_circuit_inp_out_cirqsim( + c, sorted(c.all_qubits()), initial_state, initial_state + ) + + +@pytest.mark.parametrize("adjoint", [False, True]) +def test_bi_qubits_mixer_protocols(adjoint: bool): + g = cirq_ft.algos.BiQubitsMixer(adjoint=adjoint) + cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) + cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') + + +@pytest.mark.parametrize("x", [*range(4)]) +@pytest.mark.parametrize("y", [*range(4)]) +def test_bi_qubits_mixer(x: int, y: int): + g = cirq_ft.algos.BiQubitsMixer() + qubits = cirq.LineQid.range(7, dimension=2) + c = cirq.Circuit(g.on(*qubits)) + x_1, x_0 = (x >> 1) & 1, x & 1 + y_1, y_0 = (y >> 1) & 1, y & 1 + initial_state = [x_1, x_0, y_1, y_0, 0, 0, 0] + result = ( + cirq.Simulator() + .simulate(c, initial_state=initial_state, qubit_order=qubits) + .dirac_notation()[1:-1] + ) + x_0, y_0 = int(result[1]), int(result[3]) + assert np.sign(x - y) == np.sign(x_0 - y_0) + + # Check that g**-1 restores the qubits to their original state + c = cirq.Circuit(g.on(*qubits), (g**-1).on(*qubits)) + cirq_ft.testing.assert_circuit_inp_out_cirqsim( + c, sorted(c.all_qubits()), initial_state, initial_state ) From e778f58d6ab9b03d622b8b51de78d55091f0e16b Mon Sep 17 00:00:00 2001 From: Noureldin Date: Wed, 28 Jun 2023 16:46:12 +0100 Subject: [PATCH 07/12] updated docstring --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 775436bbf82..c558f20e319 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -314,6 +314,22 @@ def _decompose_via_tree( def _decompose_with_context_( self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None ) -> cirq.OP_TREE: + """Decomposes the gate in a T-complexity optimal way. + + The construction can be broken in 4 parts: + 1. In case of differing bitsizes then a multicontrol And Gate + (Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether + the extra prefix is equal to zero: + result stored in: `prefix_equality` qubit. + 2. The tree structure (FIG. 2) https://www.nature.com/articles/s41534-018-0071-5#Sec8 + followed by a SingleQubitCompare to compute the result of comparison of + the suffixes of equal length: + result stored in: `less_than`, `greater_than` qubits with equality in qubits[-2] + 3. The results from the previous two steps are combined to update the target qubit. + 4. The adjoint of the previous operations is added to restore the input qubits + to their original state and clean the ancilla qubits. + """ + if context is None: context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) From f6410130d79833bce0057cdb97073bef0b664686 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Wed, 28 Jun 2023 17:28:00 +0100 Subject: [PATCH 08/12] fix coverage --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 4 ++-- cirq-ft/cirq_ft/algos/arithmetic_gates_test.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index c558f20e319..b0ebe3f9cba 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -186,7 +186,7 @@ def __pow__(self, power: int) -> cirq.Gate: if power == 1: return self if power == -1: - return BiQubitsMixer(adjoint=True) + return BiQubitsMixer(adjoint=self.adjoint ^ True) return NotImplemented # coverage: ignore def _t_complexity_(self) -> infra.TComplexity: @@ -238,7 +238,7 @@ def __pow__(self, power: int) -> cirq.Gate: return cirq.IdentityGate(4) adjoint = power < 0 if adjoint: - return SingleQubitCompare(adjoint=True) + return SingleQubitCompare(adjoint=self.adjoint ^ True) return self def _t_complexity_(self) -> infra.TComplexity: diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index 9e500a3d7da..e7e808e442c 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -117,6 +117,13 @@ def test_multi_in_less_equal_than_gate(): def test_less_than_equal_consistent_protocols(x_bitsize: int, y_bitsize: int): g = cirq_ft.LessThanEqualGate(x_bitsize, y_bitsize) cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) + + # Decomposition works even when context is None. + qubits = cirq.LineQid.range(x_bitsize + y_bitsize + 1, dimension=2) + assert cirq.Circuit(g._decompose_with_context_(qubits=qubits)) == cirq.Circuit( + cirq.decompose_once(g.on(*qubits)) + ) + cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') # Test the unitary is self-inverse assert g**-1 is g @@ -349,6 +356,10 @@ def test_single_qubit_compare_protocols(adjoint: bool): cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') + assert g**2 == cirq.IdentityGate(4) + assert g**1 is g + assert g**-1 == cirq_ft.algos.SingleQubitCompare(adjoint=adjoint ^ True) + @pytest.mark.parametrize("v1,v2", [(v1, v2) for v1 in range(2) for v2 in range(2)]) def test_single_qubit_compare(v1: int, v2: int): @@ -374,6 +385,9 @@ def test_bi_qubits_mixer_protocols(adjoint: bool): cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') + assert g**1 is g + assert g**-1 == cirq_ft.algos.BiQubitsMixer(adjoint=adjoint ^ True) + @pytest.mark.parametrize("x", [*range(4)]) @pytest.mark.parametrize("y", [*range(4)]) From 916ac5f04e18092e63f35923f102cb1b081d5619 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 14 Jul 2023 15:35:15 +0100 Subject: [PATCH 09/12] addressing comments --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 73 ++++++++++++------- .../cirq_ft/algos/arithmetic_gates_test.py | 4 +- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index b0ebe3f9cba..04520bd3f80 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -135,7 +135,7 @@ def _t_complexity_(self) -> infra.TComplexity: class BiQubitsMixer(infra.GateWithRegisters): """Implements the COMPARE2 (Fig. 1) https://www.nature.com/articles/s41534-018-0071-5#Sec8 - This gates mixes the values in away that preserves the result of comparison. + This gates mixes the values in a way that preserves the result of comparison. The registers being compared are 2-qubit registers where x = 2*x_msb + x_lsb y = 2*y_msb + y_lsb @@ -155,18 +155,30 @@ def _has_unitary_(self): def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] ) -> cirq.OP_TREE: - x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla'] x_msb, x_lsb = x y_msb, y_lsb = y - def _cswap(c: cirq.Qid, a: cirq.Qid, b: cirq.Qid, q: cirq.Qid) -> cirq.OP_TREE: + def _cswap(control: cirq.Qid, a: cirq.Qid, b: cirq.Qid, aux: cirq.Qid) -> cirq.OP_TREE: + """A CSWAP with 4T complexity and whose adjoint has 0T complexity. + + A controlled SWAP that swaps `a` and `b` based on `control`. + It uses an extra qubit `aux` so that its adjoint would have + a T complexity of zero. + """ yield cirq.CNOT(a, b) - yield and_gate.And(adjoint=self.adjoint).on(c, b, q) - yield cirq.CNOT(q, a) + yield and_gate.And(adjoint=self.adjoint).on(control, b, aux) + yield cirq.CNOT(aux, a) yield cirq.CNOT(a, b) def _decomposition(): + # computes the difference of x - y where + # x = 2*x_msb + x_lsb + # y = 2*y_msb + y_lsb + # And stores the result in x_lsb and y_lsb such that + # sign(x - y) = sign(x_lsb - y_lsb) + # This decomposition uses 3 ancilla qubits in order to have a + # T complexity of 8. yield cirq.X(ancilla[0]) yield cirq.CNOT(y_msb, x_msb) yield cirq.CNOT(y_lsb, x_lsb) @@ -186,7 +198,7 @@ def __pow__(self, power: int) -> cirq.Gate: if power == 1: return self if power == -1: - return BiQubitsMixer(adjoint=self.adjoint ^ True) + return BiQubitsMixer(adjoint=not self.adjoint) return NotImplemented # coverage: ignore def _t_complexity_(self) -> infra.TComplexity: @@ -214,19 +226,18 @@ def __repr__(self) -> str: def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] ) -> cirq.OP_TREE: + a = quregs['a'] + b = quregs['b'] + less_than = quregs['less_than'] + greater_than = quregs['greater_than'] + def _decomposition() -> Iterator[cirq.Operation]: - (a,), (b,), (less_than,), (greater_than,) = ( - quregs['a'], - quregs['b'], - quregs['less_than'], - quregs['greater_than'], - ) - yield and_gate.And([0, 1], adjoint=self.adjoint).on(a, b, less_than) - yield cirq.CNOT(less_than, greater_than) - yield cirq.CNOT(b, greater_than) - yield cirq.CNOT(a, b) - yield cirq.CNOT(a, greater_than) - yield cirq.X(b) + yield and_gate.And([0, 1], adjoint=self.adjoint).on(*a, *b, *less_than) + yield cirq.CNOT(*less_than, *greater_than) + yield cirq.CNOT(*b, *greater_than) + yield cirq.CNOT(*a, *b) + yield cirq.CNOT(*a, *greater_than) + yield cirq.X(*b) if self.adjoint: yield from reversed(tuple(_decomposition())) @@ -238,7 +249,7 @@ def __pow__(self, power: int) -> cirq.Gate: return cirq.IdentityGate(4) adjoint = power < 0 if adjoint: - return SingleQubitCompare(adjoint=self.adjoint ^ True) + return SingleQubitCompare(adjoint=not self.adjoint) return self def _t_complexity_(self) -> infra.TComplexity: @@ -318,13 +329,13 @@ def _decompose_with_context_( The construction can be broken in 4 parts: 1. In case of differing bitsizes then a multicontrol And Gate - (Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether + - Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether the extra prefix is equal to zero: - result stored in: `prefix_equality` qubit. + - result stored in: `prefix_equality` qubit. 2. The tree structure (FIG. 2) https://www.nature.com/articles/s41534-018-0071-5#Sec8 followed by a SingleQubitCompare to compute the result of comparison of the suffixes of equal length: - result stored in: `less_than`, `greater_than` qubits with equality in qubits[-2] + - result stored in: `less_than` and `greater_than` with equality in qubits[-2] 3. The results from the previous two steps are combined to update the target qubit. 4. The adjoint of the previous operations is added to restore the input qubits to their original state and clean the ancilla qubits. @@ -397,13 +408,21 @@ def _t_complexity_(self) -> infra.TComplexity: d = max(self.x_bitsize, self.y_bitsize) - n is_second_longer = self.y_bitsize > self.x_bitsize if d == 0: + # When both registers are of the same size the T complexity is + # 8n - 4 same as in https://www.nature.com/articles/s41534-018-0071-5#Sec8. return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17) - elif d == 1: - return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer) else: - return infra.TComplexity( - t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer - ) + # When the registers differ in size and `n` is the size of the smaller one and + # `d` is the difference in size. The T complexity is the sum of the tree + # decomposition as before giving 8n + O(1) and the T complexity of an `And` gate + # over `d` registers giving 4d + O(1) totaling 8n + 4d + O(1). + # From the decomposition we get that the constant is -4 as well as the clifford counts. + if d == 1: + return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer) + else: + return infra.TComplexity( + t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer + ) def _has_unitary_(self): return True diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index e7e808e442c..296afb50237 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -121,7 +121,9 @@ def test_less_than_equal_consistent_protocols(x_bitsize: int, y_bitsize: int): # Decomposition works even when context is None. qubits = cirq.LineQid.range(x_bitsize + y_bitsize + 1, dimension=2) assert cirq.Circuit(g._decompose_with_context_(qubits=qubits)) == cirq.Circuit( - cirq.decompose_once(g.on(*qubits)) + cirq.decompose_once( + g.on(*qubits), context=cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) + ) ) cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') From 07f4277e078b784666605b599b2d00f80393fb91 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 14 Jul 2023 15:52:12 +0100 Subject: [PATCH 10/12] nits --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 5 +++-- cirq-ft/cirq_ft/algos/arithmetic_gates_test.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 04520bd3f80..56338a05b2f 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -245,10 +245,11 @@ def _decomposition() -> Iterator[cirq.Operation]: yield from _decomposition() def __pow__(self, power: int) -> cirq.Gate: + if not isinstance(power, int): + raise ValueError('SingleQubitCompare is only defined for integer powers.') if power % 2 == 0: return cirq.IdentityGate(4) - adjoint = power < 0 - if adjoint: + if power < 0: return SingleQubitCompare(adjoint=not self.adjoint) return self diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index 296afb50237..f62892a1528 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -358,6 +358,9 @@ def test_single_qubit_compare_protocols(adjoint: bool): cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') + with pytest.raises(ValueError): + _ = g**0.5 + assert g**2 == cirq.IdentityGate(4) assert g**1 is g assert g**-1 == cirq_ft.algos.SingleQubitCompare(adjoint=adjoint ^ True) From accf212d892e109fccb0489f1e74bb436f8882d2 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 14 Jul 2023 15:57:09 +0100 Subject: [PATCH 11/12] fix mypy --- cirq-ft/cirq_ft/algos/arithmetic_gates_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index f62892a1528..b33e9f4c7ce 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -359,7 +359,7 @@ def test_single_qubit_compare_protocols(adjoint: bool): cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') with pytest.raises(ValueError): - _ = g**0.5 + _ = g**0.5 # type: ignore assert g**2 == cirq.IdentityGate(4) assert g**1 is g From 3d51b3e1b177c455b9c4102f22f7a62e553ce1b3 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Mon, 17 Jul 2023 13:39:08 +0100 Subject: [PATCH 12/12] fix paper links --- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 56338a05b2f..8f51d019794 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -133,7 +133,7 @@ def _t_complexity_(self) -> infra.TComplexity: @attr.frozen class BiQubitsMixer(infra.GateWithRegisters): - """Implements the COMPARE2 (Fig. 1) https://www.nature.com/articles/s41534-018-0071-5#Sec8 + """Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf This gates mixes the values in a way that preserves the result of comparison. The registers being compared are 2-qubit registers where @@ -141,7 +141,7 @@ class BiQubitsMixer(infra.GateWithRegisters): y = 2*y_msb + y_lsb The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb' are the final values of x_lsb' and y_lsb'. - """ + """ # pylint: disable=line-too-long adjoint: bool = False @@ -149,8 +149,8 @@ class BiQubitsMixer(infra.GateWithRegisters): def registers(self) -> infra.Registers: return infra.Registers.build(x=2, y=2, ancilla=3) - def _has_unitary_(self): - return not self.adjoint + def __repr__(self) -> str: + return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})' def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] @@ -191,9 +191,6 @@ def _decomposition(): else: yield from _decomposition() - def __repr__(self) -> str: - return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})' - def __pow__(self, power: int) -> cirq.Gate: if power == 1: return self @@ -206,13 +203,16 @@ def _t_complexity_(self) -> infra.TComplexity: return infra.TComplexity(clifford=18) return infra.TComplexity(t=8, clifford=28) + def _has_unitary_(self): + return not self.adjoint + @attr.frozen class SingleQubitCompare(infra.GateWithRegisters): """Applies U|a>|b>|0>|0> = |a> |a=b> |(a |(a>b)> - Source: (FIG. 3) in https://www.nature.com/articles/s41534-018-0071-5#Sec8 - """ + Source: (FIG. 3) in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + """ # pylint: disable=line-too-long adjoint: bool = False @@ -232,7 +232,7 @@ def decompose_from_registers( greater_than = quregs['greater_than'] def _decomposition() -> Iterator[cirq.Operation]: - yield and_gate.And([0, 1], adjoint=self.adjoint).on(*a, *b, *less_than) + yield and_gate.And((0, 1), adjoint=self.adjoint).on(*a, *b, *less_than) yield cirq.CNOT(*less_than, *greater_than) yield cirq.CNOT(*b, *greater_than) yield cirq.CNOT(*a, *b) @@ -306,10 +306,10 @@ def __repr__(self) -> str: def _decompose_via_tree( self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid] ) -> cirq.OP_TREE: - """Returns comparison oracle from https://www.nature.com/articles/s41534-018-0071-5#Sec8 + """Returns comparison oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf This decomposition follows the tree structure of (FIG. 2) - """ + """ # pylint: disable=line-too-long if len(X) == 1: return if len(X) == 2: @@ -333,14 +333,14 @@ def _decompose_with_context_( - Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether the extra prefix is equal to zero: - result stored in: `prefix_equality` qubit. - 2. The tree structure (FIG. 2) https://www.nature.com/articles/s41534-018-0071-5#Sec8 + 2. The tree structure (FIG. 2) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf followed by a SingleQubitCompare to compute the result of comparison of the suffixes of equal length: - result stored in: `less_than` and `greater_than` with equality in qubits[-2] 3. The results from the previous two steps are combined to update the target qubit. 4. The adjoint of the previous operations is added to restore the input qubits to their original state and clean the ancilla qubits. - """ + """ # pylint: disable=line-too-long if context is None: context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) @@ -410,7 +410,7 @@ def _t_complexity_(self) -> infra.TComplexity: is_second_longer = self.y_bitsize > self.x_bitsize if d == 0: # When both registers are of the same size the T complexity is - # 8n - 4 same as in https://www.nature.com/articles/s41534-018-0071-5#Sec8. + # 8n - 4 same as in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf. pylint: disable=line-too-long return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17) else: # When the registers differ in size and `n` is the size of the smaller one and