diff --git a/cirq-ft/cirq_ft/algos/__init__.py b/cirq-ft/cirq_ft/algos/__init__.py index 5acec04136d..b0cc284434d 100644 --- a/cirq-ft/cirq_ft/algos/__init__.py +++ b/cirq-ft/cirq_ft/algos/__init__.py @@ -20,6 +20,8 @@ ContiguousRegisterGate, LessThanEqualGate, LessThanGate, + SingleQubitCompare, + BiQubitsMixer, ) from cirq_ft.algos.generic_select import GenericSelect from cirq_ft.algos.hubbard_model import PrepareHubbard, SelectHubbard diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index bd68a02023e..8f51d019794 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Optional, Sequence, Tuple, Union +from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator +from cirq._compat import cached_property import attr import cirq from cirq_ft import infra @@ -78,7 +79,7 @@ def _decompose_with_context_( return adjoint = [] - [are_equal] = context.qubit_manager.qalloc(1) + (are_equal,) = context.qubit_manager.qalloc(1) # Initially our belief is that the numbers are equal. yield cirq.X(are_equal) @@ -130,6 +131,147 @@ def _t_complexity_(self) -> infra.TComplexity: ) +@attr.frozen +class BiQubitsMixer(infra.GateWithRegisters): + """Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + + This gates mixes the values in a way that preserves the result of comparison. + The registers being compared are 2-qubit registers where + x = 2*x_msb + x_lsb + y = 2*y_msb + y_lsb + The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb' + are the final values of x_lsb' and y_lsb'. + """ # pylint: disable=line-too-long + + adjoint: bool = False + + @cached_property + def registers(self) -> infra.Registers: + return infra.Registers.build(x=2, y=2, ancilla=3) + + def __repr__(self) -> str: + return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})' + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + ) -> cirq.OP_TREE: + x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla'] + x_msb, x_lsb = x + y_msb, y_lsb = y + + def _cswap(control: cirq.Qid, a: cirq.Qid, b: cirq.Qid, aux: cirq.Qid) -> cirq.OP_TREE: + """A CSWAP with 4T complexity and whose adjoint has 0T complexity. + + A controlled SWAP that swaps `a` and `b` based on `control`. + It uses an extra qubit `aux` so that its adjoint would have + a T complexity of zero. + """ + yield cirq.CNOT(a, b) + yield and_gate.And(adjoint=self.adjoint).on(control, b, aux) + yield cirq.CNOT(aux, a) + yield cirq.CNOT(a, b) + + def _decomposition(): + # computes the difference of x - y where + # x = 2*x_msb + x_lsb + # y = 2*y_msb + y_lsb + # And stores the result in x_lsb and y_lsb such that + # sign(x - y) = sign(x_lsb - y_lsb) + # This decomposition uses 3 ancilla qubits in order to have a + # T complexity of 8. + yield cirq.X(ancilla[0]) + yield cirq.CNOT(y_msb, x_msb) + yield cirq.CNOT(y_lsb, x_lsb) + yield from _cswap(x_msb, x_lsb, ancilla[0], ancilla[1]) + yield from _cswap(x_msb, y_msb, y_lsb, ancilla[2]) + yield cirq.CNOT(y_lsb, x_lsb) + + if self.adjoint: + yield from reversed(tuple(cirq.flatten_to_ops(_decomposition()))) + else: + yield from _decomposition() + + def __pow__(self, power: int) -> cirq.Gate: + if power == 1: + return self + if power == -1: + return BiQubitsMixer(adjoint=not self.adjoint) + return NotImplemented # coverage: ignore + + def _t_complexity_(self) -> infra.TComplexity: + if self.adjoint: + return infra.TComplexity(clifford=18) + return infra.TComplexity(t=8, clifford=28) + + def _has_unitary_(self): + return not self.adjoint + + +@attr.frozen +class SingleQubitCompare(infra.GateWithRegisters): + """Applies U|a>|b>|0>|0> = |a> |a=b> |(a |(a>b)> + + Source: (FIG. 3) in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + """ # pylint: disable=line-too-long + + adjoint: bool = False + + @cached_property + def registers(self) -> infra.Registers: + return infra.Registers.build(a=1, b=1, less_than=1, greater_than=1) + + def __repr__(self) -> str: + return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})' + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + ) -> cirq.OP_TREE: + a = quregs['a'] + b = quregs['b'] + less_than = quregs['less_than'] + greater_than = quregs['greater_than'] + + def _decomposition() -> Iterator[cirq.Operation]: + yield and_gate.And((0, 1), adjoint=self.adjoint).on(*a, *b, *less_than) + yield cirq.CNOT(*less_than, *greater_than) + yield cirq.CNOT(*b, *greater_than) + yield cirq.CNOT(*a, *b) + yield cirq.CNOT(*a, *greater_than) + yield cirq.X(*b) + + if self.adjoint: + yield from reversed(tuple(_decomposition())) + else: + yield from _decomposition() + + def __pow__(self, power: int) -> cirq.Gate: + if not isinstance(power, int): + raise ValueError('SingleQubitCompare is only defined for integer powers.') + if power % 2 == 0: + return cirq.IdentityGate(4) + if power < 0: + return SingleQubitCompare(adjoint=not self.adjoint) + return self + + def _t_complexity_(self) -> infra.TComplexity: + if self.adjoint: + return infra.TComplexity(clifford=11) + return infra.TComplexity(t=4, clifford=16) + + +def _equality_with_zero( + context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid +) -> cirq.OP_TREE: + if len(qubits) == 1: + (q,) = qubits + yield cirq.X(q) + yield cirq.CNOT(q, z) + return + + ancilla = context.qubit_manager.qalloc(len(qubits) - 2) + yield and_gate.And(cv=[0] * len(qubits)).on(*qubits, *ancilla, z) + + @attr.frozen class LessThanEqualGate(cirq.ArithmeticGate): """Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>""" @@ -161,9 +303,130 @@ def __pow__(self, power: int): def __repr__(self) -> str: return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})' + def _decompose_via_tree( + self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid] + ) -> cirq.OP_TREE: + """Returns comparison oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + + This decomposition follows the tree structure of (FIG. 2) + """ # pylint: disable=line-too-long + if len(X) == 1: + return + if len(X) == 2: + yield BiQubitsMixer().on_registers(x=X, y=Y, ancilla=context.qubit_manager.qalloc(3)) + return + + m = len(X) // 2 + yield self._decompose_via_tree(context, X[:m], Y[:m]) + yield self._decompose_via_tree(context, X[m:], Y[m:]) + yield BiQubitsMixer().on_registers( + x=(X[m - 1], X[-1]), y=(Y[m - 1], Y[-1]), ancilla=context.qubit_manager.qalloc(3) + ) + + def _decompose_with_context_( + self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None + ) -> cirq.OP_TREE: + """Decomposes the gate in a T-complexity optimal way. + + The construction can be broken in 4 parts: + 1. In case of differing bitsizes then a multicontrol And Gate + - Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether + the extra prefix is equal to zero: + - result stored in: `prefix_equality` qubit. + 2. The tree structure (FIG. 2) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + followed by a SingleQubitCompare to compute the result of comparison of + the suffixes of equal length: + - result stored in: `less_than` and `greater_than` with equality in qubits[-2] + 3. The results from the previous two steps are combined to update the target qubit. + 4. The adjoint of the previous operations is added to restore the input qubits + to their original state and clean the ancilla qubits. + """ # pylint: disable=line-too-long + + if context is None: + context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) + + lhs, rhs, target = qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1] + + n = min(len(lhs), len(rhs)) + + prefix_equality = None + adjoint: List[cirq.Operation] = [] + + # if one of the registers is longer than the other store equality with |0--0> + # into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T. + if len(lhs) != len(rhs): + (prefix_equality,) = context.qubit_manager.qalloc(1) + if len(lhs) > len(rhs): + for op in cirq.flatten_to_ops( + _equality_with_zero(context, lhs[:-n], prefix_equality) + ): + yield op + adjoint.append(cirq.inverse(op)) + else: + for op in cirq.flatten_to_ops( + _equality_with_zero(context, rhs[:-n], prefix_equality) + ): + yield op + adjoint.append(cirq.inverse(op)) + + yield cirq.X(target), cirq.CNOT(prefix_equality, target) + + # compare the remaing suffix of P and Q + lhs = lhs[-n:] + rhs = rhs[-n:] + for op in cirq.flatten_to_ops(self._decompose_via_tree(context, lhs, rhs)): + yield op + adjoint.append(cirq.inverse(op)) + + less_than, greater_than = context.qubit_manager.qalloc(2) + yield SingleQubitCompare().on_registers( + a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than + ) + adjoint.append( + SingleQubitCompare(adjoint=True).on_registers( + a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than + ) + ) + + if prefix_equality is None: + yield cirq.X(target) + yield cirq.CNOT(greater_than, target) + else: + (less_than_or_equal,) = context.qubit_manager.qalloc(1) + yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal) + adjoint.append( + and_gate.And([1, 0], adjoint=True).on( + prefix_equality, greater_than, less_than_or_equal + ) + ) + + yield cirq.CNOT(less_than_or_equal, target) + + yield from reversed(adjoint) + def _t_complexity_(self) -> infra.TComplexity: - # TODO(#112): This is rough cost that ignores cliffords. - return infra.TComplexity(t=4 * (self.x_bitsize + self.y_bitsize)) + n = min(self.x_bitsize, self.y_bitsize) + d = max(self.x_bitsize, self.y_bitsize) - n + is_second_longer = self.y_bitsize > self.x_bitsize + if d == 0: + # When both registers are of the same size the T complexity is + # 8n - 4 same as in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf. pylint: disable=line-too-long + return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17) + else: + # When the registers differ in size and `n` is the size of the smaller one and + # `d` is the difference in size. The T complexity is the sum of the tree + # decomposition as before giving 8n + O(1) and the T complexity of an `And` gate + # over `d` registers giving 4d + O(1) totaling 8n + 4d + O(1). + # From the decomposition we get that the constant is -4 as well as the clifford counts. + if d == 1: + return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer) + else: + return infra.TComplexity( + t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer + ) + + def _has_unitary_(self): + return True @attr.frozen diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py index 204f3007502..b33e9f4c7ce 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates_test.py @@ -18,7 +18,7 @@ import cirq_ft import numpy as np import pytest -from cirq_ft.infra import bit_tools +from cirq_ft.infra import bit_tools, GreedyQubitManager def identity_map(n: int): @@ -112,11 +112,20 @@ def test_multi_in_less_equal_than_gate(): cirq.testing.assert_equivalent_computational_basis_map(identity_map(len(qubits)), circuit) -@pytest.mark.parametrize("x_bitsize", [2, 3]) -@pytest.mark.parametrize("y_bitsize", [2, 3]) +@pytest.mark.parametrize("x_bitsize", [*range(1, 5)]) +@pytest.mark.parametrize("y_bitsize", [*range(1, 5)]) def test_less_than_equal_consistent_protocols(x_bitsize: int, y_bitsize: int): g = cirq_ft.LessThanEqualGate(x_bitsize, y_bitsize) cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) + + # Decomposition works even when context is None. + qubits = cirq.LineQid.range(x_bitsize + y_bitsize + 1, dimension=2) + assert cirq.Circuit(g._decompose_with_context_(qubits=qubits)) == cirq.Circuit( + cirq.decompose_once( + g.on(*qubits), context=cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) + ) + ) + cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') # Test the unitary is self-inverse assert g**-1 is g @@ -322,3 +331,88 @@ def test_add_no_decompose(a, b): assert true_out_int == int(out_bin, 2) basis_map[input_int] = output_int cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit) + + +@pytest.mark.parametrize("P,n", [(v, n) for n in range(1, 4) for v in range(1 << n)]) +@pytest.mark.parametrize("Q,m", [(v, n) for n in range(1, 4) for v in range(1 << n)]) +def test_decompose_less_than_equal_gate(P: int, n: int, Q: int, m: int): + qubit_states = list(bit_tools.iter_bits(P, n)) + list(bit_tools.iter_bits(Q, m)) + circuit = cirq.Circuit( + cirq.decompose_once( + cirq_ft.LessThanEqualGate(n, m).on(*cirq.LineQubit.range(n + m + 1)), + context=cirq.DecompositionContext(GreedyQubitManager(prefix='_c')), + ) + ) + qubit_order = tuple(sorted(circuit.all_qubits())) + num_ancillas = len(circuit.all_qubits()) - n - m - 1 + initial_state = qubit_states + [0] + [0] * num_ancillas + output_state = qubit_states + [int(P <= Q)] + [0] * num_ancillas + cirq_ft.testing.assert_circuit_inp_out_cirqsim( + circuit, qubit_order, initial_state, output_state + ) + + +@pytest.mark.parametrize("adjoint", [False, True]) +def test_single_qubit_compare_protocols(adjoint: bool): + g = cirq_ft.algos.SingleQubitCompare(adjoint=adjoint) + cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) + cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') + + with pytest.raises(ValueError): + _ = g**0.5 # type: ignore + + assert g**2 == cirq.IdentityGate(4) + assert g**1 is g + assert g**-1 == cirq_ft.algos.SingleQubitCompare(adjoint=adjoint ^ True) + + +@pytest.mark.parametrize("v1,v2", [(v1, v2) for v1 in range(2) for v2 in range(2)]) +def test_single_qubit_compare(v1: int, v2: int): + g = cirq_ft.algos.SingleQubitCompare() + qubits = cirq.LineQid.range(4, dimension=2) + c = cirq.Circuit(g.on(*qubits)) + initial_state = [v1, v2, 0, 0] + output_state = [v1, int(v1 == v2), int(v1 < v2), int(v1 > v2)] + cirq_ft.testing.assert_circuit_inp_out_cirqsim( + c, sorted(c.all_qubits()), initial_state, output_state + ) + + # Check that g**-1 restores the qubits to their original state + c = cirq.Circuit(g.on(*qubits), (g**-1).on(*qubits)) + cirq_ft.testing.assert_circuit_inp_out_cirqsim( + c, sorted(c.all_qubits()), initial_state, initial_state + ) + + +@pytest.mark.parametrize("adjoint", [False, True]) +def test_bi_qubits_mixer_protocols(adjoint: bool): + g = cirq_ft.algos.BiQubitsMixer(adjoint=adjoint) + cirq_ft.testing.assert_decompose_is_consistent_with_t_complexity(g) + cirq.testing.assert_equivalent_repr(g, setup_code='import cirq_ft') + + assert g**1 is g + assert g**-1 == cirq_ft.algos.BiQubitsMixer(adjoint=adjoint ^ True) + + +@pytest.mark.parametrize("x", [*range(4)]) +@pytest.mark.parametrize("y", [*range(4)]) +def test_bi_qubits_mixer(x: int, y: int): + g = cirq_ft.algos.BiQubitsMixer() + qubits = cirq.LineQid.range(7, dimension=2) + c = cirq.Circuit(g.on(*qubits)) + x_1, x_0 = (x >> 1) & 1, x & 1 + y_1, y_0 = (y >> 1) & 1, y & 1 + initial_state = [x_1, x_0, y_1, y_0, 0, 0, 0] + result = ( + cirq.Simulator() + .simulate(c, initial_state=initial_state, qubit_order=qubits) + .dirac_notation()[1:-1] + ) + x_0, y_0 = int(result[1]), int(result[3]) + assert np.sign(x - y) == np.sign(x_0 - y_0) + + # Check that g**-1 restores the qubits to their original state + c = cirq.Circuit(g.on(*qubits), (g**-1).on(*qubits)) + cirq_ft.testing.assert_circuit_inp_out_cirqsim( + c, sorted(c.all_qubits()), initial_state, initial_state + ) diff --git a/cirq-ft/cirq_ft/algos/state_preparation_test.py b/cirq-ft/cirq_ft/algos/state_preparation_test.py index 01586a4b545..d1695a16424 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation_test.py +++ b/cirq-ft/cirq_ft/algos/state_preparation_test.py @@ -30,14 +30,14 @@ def construct_gate_helper_and_qubit_order(data, eps): context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) def map_func(op: cirq.Operation, _): - gateset = cirq.Gateset(cirq_ft.And) + gateset = cirq.Gateset(cirq_ft.And, cirq_ft.LessThanEqualGate, cirq_ft.LessThanGate) return cirq.Circuit( cirq.decompose(op, on_stuck_raise=None, keep=gateset.validate, context=context) ) - # TODO: Do not decompose `cq.And` because the `cq.map_clean_and_borrowable_qubits` currently - # gets confused and is not able to re-map qubits optimally; which results in a higher number - # of ancillas and thus the tests fails due to OOO. + # TODO: Do not decompose {cq.And, cq.LessThanEqualGate, cq.LessThanGate} because the + # `cq.map_clean_and_borrowable_qubits` currently gets confused and is not able to re-map qubits + # optimally; which results in a higher number of ancillas thus the tests fails due to OOO. decomposed_circuit = cirq.map_operations_and_unroll( g.circuit, map_func, raise_if_add_qubits=False ) @@ -45,7 +45,10 @@ def map_func(op: cirq.Operation, _): decomposed_circuit = cirq_ft.map_clean_and_borrowable_qubits(decomposed_circuit, qm=greedy_mm) # We are fine decomposing the `cq.And` gates once the qubit re-mapping is complete. Ideally, # we shouldn't require this two step process. - decomposed_circuit = cirq.Circuit(cirq.decompose(decomposed_circuit)) + arithmetic_gateset = cirq.Gateset(cirq_ft.LessThanEqualGate, cirq_ft.LessThanGate) + decomposed_circuit = cirq.Circuit( + cirq.decompose(decomposed_circuit, keep=arithmetic_gateset.validate, on_stuck_raise=None) + ) ordered_input = list(itertools.chain(*g.quregs.values())) qubit_order = cirq.QubitOrder.explicit(ordered_input, fallback=cirq.QubitOrder.DEFAULT) return g, qubit_order, decomposed_circuit