diff --git a/cirq-core/cirq/ops/boolean_hamiltonian.py b/cirq-core/cirq/ops/boolean_hamiltonian.py index 5e0ad608008..0be4a2a7241 100644 --- a/cirq-core/cirq/ops/boolean_hamiltonian.py +++ b/cirq-core/cirq/ops/boolean_hamiltonian.py @@ -19,7 +19,11 @@ by Stuart Hadfield, https://arxiv.org/pdf/1804.09130.pdf [2] https://www.youtube.com/watch?v=AOKM9BkweVU is a useful intro [3] https://github.com/rsln-s/IEEE_QW_2020/blob/master/Slides.pdf +[4] Efficient Quantum Circuits for Diagonal Unitaries Without Ancillas by Jonathan Welch, Daniel + Greenbaum, Sarah Mostame, and Alán Aspuru-Guzik, https://arxiv.org/abs/1306.3991 """ +import itertools +import functools from typing import Any, Dict, Generator, List, Sequence, Tuple @@ -112,6 +116,187 @@ def _decompose_(self): ) +def _gray_code_comparator(k1: Tuple[int, ...], k2: Tuple[int, ...], flip: bool = False) -> int: + """Compares two Gray-encoded binary numbers. + + Args: + k1: A tuple of ints, representing the bits that are one. For example, 6 would be (1, 2). + k2: The second number, represented similarly as k1. + flip: Whether to flip the comparison. + + Returns: + -1 if k1 < k2 (or +1 if flip is true) + 0 if k1 == k2 + +1 if k1 > k2 (or -1 if flip is true) + """ + max_1 = k1[-1] if k1 else -1 + max_2 = k2[-1] if k2 else -1 + if max_1 != max_2: + return -1 if (max_1 < max_2) ^ flip else 1 + if max_1 == -1: + return 0 + return _gray_code_comparator(k1[0:-1], k2[0:-1], not flip) + + +def _simplify_commuting_cnots( + cnots: List[Tuple[int, int]], flip_control_and_target: bool +) -> Tuple[bool, List[Tuple[int, int]]]: + """Attempts to commute CNOTs and remove cancelling pairs. + + Commutation relations are based on 9 (flip_control_and_target=False) or 10 + (flip_control_target=True) of [4]: + When flip_control_target=True: + + CNOT(j, i) @ CNOT(j, k) = CNOT(j, k) @ CNOT(j, i) + ───X─────── ───────X─── + │ │ + ───@───@─── = ───@───@─── + │ │ + ───────X─── ───X─────── + + When flip_control_target=False: + + CNOT(i, j) @ CNOT(k, j) = CNOT(k, j) @ CNOT(i, j) + ───@─────── ───────@─── + │ │ + ───X───X─── = ───X───X─── + │ │ + ───────@─── ───@─────── + + Args: + cnots: A list of CNOTS, encoded as integer tuples (control, target). The code does not make + any assumption as to the order of the CNOTs, but it is likely to work better if its + inputs are from Gray-sorted Hamiltonians. Regardless of the order of the CNOTs, the + code is conservative and should be robust to mis-ordered inputs with the only side + effect being a lack of simplification. + flip_control_and_target: Whether to flip control and target. + + Returns: + A tuple containing a Boolean that tells whether a simplification has been performed and the + CNOT list, potentially simplified, encoded as integer tuples (control, target). + """ + + target, control = (0, 1) if flip_control_and_target else (1, 0) + + i = 0 + qubit_to_index: Dict[int, int] = {cnots[i][control]: i} if cnots else {} + for j in range(1, len(cnots)): + if cnots[i][target] != cnots[j][target]: + # The targets (resp. control) don't match, so we reset the search. + i = j + qubit_to_index = {cnots[j][control]: j} + continue + + if cnots[j][control] in qubit_to_index: + k = qubit_to_index[cnots[j][control]] + # The controls (resp. targets) are the same, so we can simplify away. + cnots = [cnots[n] for n in range(len(cnots)) if n != j and n != k] + # TODO(#4532): Speed up code by not returning early. + return True, cnots + + qubit_to_index[cnots[j][control]] = j + + return False, cnots + + +def _simplify_cnots_triplets( + cnots: List[Tuple[int, int]], flip_control_and_target: bool +) -> Tuple[bool, List[Tuple[int, int]]]: + """Simplifies CNOT pairs according to equation 11 of [4]. + + CNOT(i, j) @ CNOT(j, k) == CNOT(j, k) @ CNOT(i, k) @ CNOT(i, j) + ───@─────── ───────@───@─── + │ │ │ + ───X───@─── = ───@───┼───X─── + │ │ │ + ───────X─── ───X───X─────── + + Args: + cnots: A list of CNOTS, encoded as integer tuples (control, target). + flip_control_and_target: Whether to flip control and target. + + Returns: + A tuple containing a Boolean that tells whether a simplification has been performed and the + CNOT list, potentially simplified, encoded as integer tuples (control, target). + """ + target, control = (0, 1) if flip_control_and_target else (1, 0) + + # We investigate potential pivots sequentially. + for j in range(1, len(cnots) - 1): + # First, we look back for as long as the controls (resp. targets) are the same. + # They all commute, so all are potential candidates for being simplified. + # prev_match_index is qubit to index in `cnots` array. + prev_match_index: Dict[int, int] = {} + for i in range(j - 1, -1, -1): + # These CNOTs have the same target (resp. control) and though they are not candidates + # for simplification, since they commute, we can keep looking for candidates. + if cnots[i][target] == cnots[j][target]: + continue + if cnots[i][control] != cnots[j][control]: + break + # We take a note of the control (resp. target). + prev_match_index[cnots[i][target]] = i + + # Next, we look forward for as long as the targets (resp. controls) are the + # same. They all commute, so all are potential candidates for being simplified. + # post_match_index is qubit to index in `cnots` array. + post_match_index: Dict[int, int] = {} + for k in range(j + 1, len(cnots)): + # These CNOTs have the same control (resp. target) and though they are not candidates + # for simplification, since they commute, we can keep looking for candidates. + if cnots[j][control] == cnots[k][control]: + continue + if cnots[j][target] != cnots[k][target]: + break + # We take a note of the target (resp. control). + post_match_index[cnots[k][control]] = k + + # Among all the candidates, find if they have a match. + keys = prev_match_index.keys() & post_match_index.keys() + for key in keys: + # We perform the swap which removes the pivot. + new_idx: List[int] = ( + # Anything strictly before the pivot that is not the CNOT to swap. + [idx for idx in range(0, j) if idx != prev_match_index[key]] + # The two swapped CNOTs. + + [post_match_index[key], prev_match_index[key]] + # Anything after the pivot that is not the CNOT to swap. + + [idx for idx in range(j + 1, len(cnots)) if idx != post_match_index[key]] + ) + # Since we removed the pivot, the length should be one fewer. + cnots = [cnots[idx] for idx in new_idx] + # TODO(#4532): Speed up code by not returning early. + return True, cnots + + return False, cnots + + +def _simplify_cnots(cnots: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + """Takes a series of CNOTs and tries to applies rule to cancel out gates. + + Algorithm based on "Efficient quantum circuits for diagonal unitaries without ancillas" by + Jonathan Welch, Daniel Greenbaum, Sarah Mostame, Alán Aspuru-Guzik + https://arxiv.org/abs/1306.3991 + + Args: + cnots: A list of CNOTs represented as tuples of integer (control, target). + + Returns: + The simplified list of CNOTs, encoded as integer tuples (control, target). + """ + + found_simplification = True + while found_simplification: + for simplify_fn, flip_control_and_target in itertools.product( + [_simplify_commuting_cnots, _simplify_cnots_triplets], [False, True] + ): + found_simplification, cnots = simplify_fn(cnots, flip_control_and_target) + if found_simplification: + break + + return cnots + + def _get_gates_from_hamiltonians( hamiltonian_polynomial_list: List['cirq.PauliSum'], qubit_map: Dict[str, 'cirq.Qid'], @@ -145,16 +330,18 @@ def _apply_cnots(prevh: Tuple[int, ...], currh: Tuple[int, ...]): cnots.extend((prevh[i], prevh[-1]) for i in range(len(prevh) - 1)) cnots.extend((currh[i], currh[-1]) for i in range(len(currh) - 1)) - # TODO(tonybruguier): At this point, some CNOT gates can be cancelled out according to: - # "Efficient quantum circuits for diagonal unitaries without ancillas" by Jonathan Welch, - # Daniel Greenbaum, Sarah Mostame, Alán Aspuru-Guzik - # https://arxiv.org/abs/1306.3991 + cnots = _simplify_cnots(cnots) for gate in (cirq.CNOT(qubits[c], qubits[t]) for c, t in cnots): yield gate + sorted_hamiltonian_keys = sorted( + hamiltonians.keys(), key=functools.cmp_to_key(_gray_code_comparator) + ) + previous_h: Tuple[int, ...] = () - for h, w in hamiltonians.items(): + for h in sorted_hamiltonian_keys: + w = hamiltonians[h] yield _apply_cnots(previous_h, h) if len(h) >= 1: diff --git a/cirq-core/cirq/ops/boolean_hamiltonian_test.py b/cirq-core/cirq/ops/boolean_hamiltonian_test.py index 610c2236c37..9671101210e 100644 --- a/cirq-core/cirq/ops/boolean_hamiltonian_test.py +++ b/cirq-core/cirq/ops/boolean_hamiltonian_test.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import itertools import math +import random import numpy as np import pytest import sympy.parsing.sympy_parser as sympy_parser import cirq +import cirq.ops.boolean_hamiltonian as bh @pytest.mark.parametrize( @@ -98,3 +101,126 @@ def test_with_custom_names(): with pytest.raises(ValueError, match='Length of replacement qubits must be the same'): original_op.with_qubits(q2) + + +@pytest.mark.parametrize( + 'n_bits,expected_hs', + [ + (1, [(), (0,)]), + (2, [(), (0,), (0, 1), (1,)]), + (3, [(), (0,), (0, 1), (1,), (1, 2), (0, 1, 2), (0, 2), (2,)]), + ], +) +def test_gray_code_sorting(n_bits, expected_hs): + hs_template = [] + for x in range(2 ** n_bits): + h = [] + for i in range(n_bits): + if x % 2 == 1: + h.append(i) + x -= 1 + x //= 2 + hs_template.append(tuple(sorted(h))) + + for seed in range(10): + random.seed(seed) + + hs = hs_template.copy() + random.shuffle(hs) + + sorted_hs = sorted(list(hs), key=functools.cmp_to_key(bh._gray_code_comparator)) + + np.testing.assert_array_equal(sorted_hs, expected_hs) + + +@pytest.mark.parametrize( + 'seq_a,seq_b,expected', + [ + ((), (), 0), + ((), (0,), -1), + ((0,), (), 1), + ((0,), (0,), 0), + ], +) +def test_gray_code_comparison(seq_a, seq_b, expected): + assert bh._gray_code_comparator(seq_a, seq_b) == expected + + +@pytest.mark.parametrize( + 'input_cnots,input_flip_control_and_target,expected_simplified,expected_output_cnots', + [ + # Empty inputs don't get simplified. + ([], False, False, []), + ([], True, False, []), + # Single CNOTs don't get simplified. + ([(0, 1)], False, False, [(0, 1)]), + ([(0, 1)], True, False, [(0, 1)]), + # Simplify away two CNOTs that are identical: + ([(0, 1), (0, 1)], False, True, []), + ([(0, 1), (0, 1)], True, True, []), + # Also simplify away if there's another CNOT in between. + ([(0, 1), (2, 1), (0, 1)], False, True, [(2, 1)]), + ([(0, 1), (0, 2), (0, 1)], True, True, [(0, 2)]), + # However, the in-between has to share the same target/control. + ([(0, 1), (0, 2), (0, 1)], False, False, [(0, 1), (0, 2), (0, 1)]), + ([(0, 1), (2, 1), (0, 1)], True, False, [(0, 1), (2, 1), (0, 1)]), + # Can simplify, but violates CNOT ordering assumption + ([(0, 1), (2, 3), (0, 1)], False, False, [(0, 1), (2, 3), (0, 1)]), + ], +) +def test_simplify_commuting_cnots( + input_cnots, input_flip_control_and_target, expected_simplified, expected_output_cnots +): + actual_simplified, actual_output_cnots = bh._simplify_commuting_cnots( + input_cnots, input_flip_control_and_target + ) + assert actual_simplified == expected_simplified + assert actual_output_cnots == expected_output_cnots + + +@pytest.mark.parametrize( + 'input_cnots,input_flip_control_and_target,expected_simplified,expected_output_cnots', + [ + # Empty inputs don't get simplified. + ([], False, False, []), + ([], True, False, []), + # Single CNOTs don't get simplified. + ([(0, 1)], False, False, [(0, 1)]), + ([(0, 1)], True, False, [(0, 1)]), + # Simplify according to equation 11 of [4]. + ([(2, 1), (2, 0), (1, 0)], False, True, [(1, 0), (2, 1)]), + ([(1, 2), (0, 2), (0, 1)], True, True, [(0, 1), (1, 2)]), + # Same as above, but with a intervening CNOTs that prevent simplifications. + ([(2, 1), (2, 0), (100, 101), (1, 0)], False, False, [(2, 1), (2, 0), (100, 101), (1, 0)]), + ([(2, 1), (100, 101), (2, 0), (1, 0)], False, False, [(2, 1), (100, 101), (2, 0), (1, 0)]), + # swap (2, 1) and (1, 0) around (2, 0) + ([(2, 1), (2, 3), (2, 0), (3, 0), (1, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]), + ([(2, 1), (2, 0), (2, 3), (3, 0), (1, 0)], False, True, [(1, 0), (2, 1), (2, 3), (3, 0)]), + ([(2, 3), (2, 1), (2, 0), (3, 0), (1, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]), + ([(2, 1), (2, 3), (3, 0), (2, 0), (1, 0)], False, True, [(2, 3), (3, 0), (1, 0), (2, 1)]), + ([(2, 1), (2, 3), (2, 0), (1, 0), (3, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]), + ], +) +def test_simplify_cnots_triplets( + input_cnots, input_flip_control_and_target, expected_simplified, expected_output_cnots +): + actual_simplified, actual_output_cnots = bh._simplify_cnots_triplets( + input_cnots, input_flip_control_and_target + ) + assert actual_simplified == expected_simplified + assert actual_output_cnots == expected_output_cnots + + # Check that the unitaries are the same. + qubit_ids = set(sum(input_cnots, ())) + qubits = {qubit_id: cirq.NamedQubit(f"{qubit_id}") for qubit_id in qubit_ids} + + target, control = (0, 1) if input_flip_control_and_target else (1, 0) + + circuit_input = cirq.Circuit() + for input_cnot in input_cnots: + circuit_input.append(cirq.CNOT(qubits[input_cnot[target]], qubits[input_cnot[control]])) + circuit_actual = cirq.Circuit() + for actual_cnot in actual_output_cnots: + circuit_actual.append(cirq.CNOT(qubits[actual_cnot[target]], qubits[actual_cnot[control]])) + + np.testing.assert_allclose(cirq.unitary(circuit_input), cirq.unitary(circuit_actual), atol=1e-6)