Skip to content

Commit 5bbdc22

Browse files
authored
Apply variable-spaced optimization to QROM circuits (#6257)
* Apply variable-spaced optimization to QROM circuits * Fix flaky test due to a flakiness bug in GreedyQubitManager * Fix mypy issues * More tests and update hash for QROM since T-complexity now depends upon the data * Fix typo and failing test
1 parent 6abc740 commit 5bbdc22

File tree

5 files changed

+188
-21
lines changed

5 files changed

+188
-21
lines changed

cirq-ft/cirq_ft/algos/qrom.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Callable, Sequence, Tuple
15+
from typing import Callable, Sequence, Tuple, Set
1616

1717
import attr
1818
import cirq
@@ -29,7 +29,11 @@ class QROM(unary_iteration_gate.UnaryIterationGate):
2929
"""Gate to load data[l] in the target register when the selection stores an index l.
3030
3131
In the case of multi-dimensional data[p,q,r,...] we use multiple named
32-
selection registers [p, q, r, ...] to index and load the data.
32+
selection registers [p, q, r, ...] to index and load the data. Here `p, q, r, ...`
33+
correspond to registers named `selection0`, `selection1`, `selection2`, ... etc.
34+
35+
When the input data elements contain consecutive entries of identical data elements to
36+
load, the QROM also implements the "variable-spaced" QROM optimization described in Ref[2].
3337
3438
Args:
3539
data: List of numpy ndarrays specifying the data to load. If the length
@@ -44,6 +48,15 @@ class QROM(unary_iteration_gate.UnaryIterationGate):
4448
registers. This can be deduced from the maximum element of each of the
4549
datasets. Should be of length len(data), i.e. the number of datasets.
4650
num_controls: The number of control registers.
51+
52+
References:
53+
[Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity]
54+
(https://arxiv.org/abs/1805.03662).
55+
Babbush et. al. (2018). Figure 1.
56+
57+
[Compilation of Fault-Tolerant Quantum Heuristics for Combinatorial Optimization]
58+
(https://arxiv.org/abs/2007.07391).
59+
Babbush et. al. (2020). Figure 3.
4760
"""
4861

4962
data: Sequence[NDArray]
@@ -152,11 +165,22 @@ def decompose_zero_selection(
152165
yield cirq.inverse(multi_controlled_and)
153166
context.qubit_manager.qfree(and_ancilla + [and_target])
154167

168+
def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int):
169+
global_unique_element: Set[int] = set()
170+
for data in self.data:
171+
unique_element = np.unique(data[selection_index_prefix][l:r])
172+
if len(unique_element) > 1:
173+
return False
174+
global_unique_element.add(unique_element[0])
175+
if len(global_unique_element) > 1:
176+
return False
177+
return True
178+
155179
def nth_operation(
156180
self, context: cirq.DecompositionContext, control: cirq.Qid, **kwargs
157181
) -> cirq.OP_TREE:
158182
selection_idx = tuple(kwargs[reg.name] for reg in self.selection_registers)
159-
target_regs = {k: v for k, v in kwargs.items() if k in self.target_registers}
183+
target_regs = {reg.name: kwargs[reg.name] for reg in self.target_registers}
160184
yield self._load_nth_data(selection_idx, lambda q: cirq.CNOT(control, q), **target_regs)
161185

162186
def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
@@ -172,4 +196,5 @@ def __pow__(self, power: int):
172196
return NotImplemented # pragma: no cover
173197

174198
def _value_equality_values_(self):
175-
return (self.selection_registers, self.target_registers, self.control_registers)
199+
data_tuple = tuple(tuple(d.flatten()) for d in self.data)
200+
return (self.selection_registers, self.target_registers, self.control_registers, data_tuple)

cirq-ft/cirq_ft/algos/qrom_test.py

+74
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,80 @@ def test_t_complexity(data):
116116
assert cirq_ft.t_complexity(g.gate).t == max(0, 4 * n - 8), n
117117

118118

119+
def _assert_qrom_has_diagram(qrom: cirq_ft.QROM, expected_diagram: str):
120+
gh = cirq_ft.testing.GateHelper(qrom)
121+
op = gh.operation
122+
context = cirq.DecompositionContext(qubit_manager=cirq_ft.GreedyQubitManager(prefix="anc"))
123+
circuit = cirq.Circuit(cirq.decompose_once(op, context=context))
124+
selection = [
125+
*itertools.chain.from_iterable(gh.quregs[reg.name] for reg in qrom.selection_registers)
126+
]
127+
selection = [q for q in selection if q in circuit.all_qubits()]
128+
anc = sorted(set(circuit.all_qubits()) - set(op.qubits))
129+
selection_and_anc = (selection[0],) + sum(zip(selection[1:], anc), ())
130+
qubit_order = cirq.QubitOrder.explicit(selection_and_anc, fallback=cirq.QubitOrder.DEFAULT)
131+
cirq.testing.assert_has_diagram(circuit, expected_diagram, qubit_order=qubit_order)
132+
133+
134+
def test_qrom_variable_spacing():
135+
# Tests for variable spacing optimization applied from https://arxiv.org/abs/2007.07391
136+
data = [1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8] # Figure 3a.
137+
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (8 - 2) * 4
138+
data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5] # Figure 3b.
139+
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (5 - 2) * 4
140+
data = [1, 2, 3, 4, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7] # Negative test: t count is not (g-2)*4
141+
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (8 - 2) * 4
142+
# Works as expected when multiple data arrays are to be loaded.
143+
data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5]
144+
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, data)).t == (5 - 2) * 4
145+
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, 2 * np.array(data))).t == (16 - 2) * 4
146+
# Works as expected when multidimensional input data is to be loaded
147+
qrom = cirq_ft.QROM.build(
148+
np.array(
149+
[
150+
[1, 1, 1, 1, 1, 1, 1, 1],
151+
[1, 1, 1, 1, 1, 1, 1, 1],
152+
[2, 2, 2, 2, 2, 2, 2, 2],
153+
[2, 2, 2, 2, 2, 2, 2, 2],
154+
]
155+
)
156+
)
157+
# Value to be loaded depends only the on the first bit of outer loop.
158+
_assert_qrom_has_diagram(
159+
qrom,
160+
r'''
161+
selection00: ───X───@───X───@───
162+
│ │
163+
target00: ──────────┼───────X───
164+
165+
target01: ──────────X───────────
166+
''',
167+
)
168+
# When inner loop range is not a power of 2, the inner segment tree cannot be skipped.
169+
qrom = cirq_ft.QROM.build(
170+
np.array(
171+
[[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2]],
172+
dtype=int,
173+
)
174+
)
175+
_assert_qrom_has_diagram(
176+
qrom,
177+
r'''
178+
selection00: ───X───@─────────@───────@──────X───@─────────@───────@──────
179+
│ │ │ │ │ │
180+
selection10: ───────(0)───────┼───────@──────────(0)───────┼───────@──────
181+
│ │ │ │ │ │
182+
anc_1: ─────────────And───@───X───@───And†───────And───@───X───@───And†───
183+
│ │ │ │
184+
target00: ────────────────┼───────┼────────────────────X───────X──────────
185+
│ │
186+
target01: ────────────────X───────X───────────────────────────────────────
187+
''',
188+
)
189+
# No T-gates needed if all elements to load are identical.
190+
assert cirq_ft.t_complexity(cirq_ft.QROM.build([3, 3, 3, 3])).t == 0
191+
192+
119193
@pytest.mark.parametrize(
120194
"data",
121195
[[np.arange(6).reshape(2, 3), 4 * np.arange(6).reshape(2, 3)], [np.arange(8).reshape(2, 2, 2)]],

cirq-ft/cirq_ft/algos/unary_iteration_gate.py

+75-15
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import abc
16-
from typing import Dict, Iterator, List, Sequence, Tuple
16+
from typing import Callable, Dict, Iterator, List, Sequence, Tuple
1717
from numpy.typing import NDArray
1818

1919
import cirq
@@ -34,6 +34,7 @@ def _unary_iteration_segtree(
3434
r: int,
3535
l_iter: int,
3636
r_iter: int,
37+
break_early: Callable[[int, int], bool],
3738
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
3839
"""Constructs a unary iteration circuit by iterating over nodes of an implicit Segment Tree.
3940
@@ -53,6 +54,11 @@ def _unary_iteration_segtree(
5354
r: Right index of the range represented by current node of the segment tree.
5455
l_iter: Left index of iteration range over which the segment tree should be constructed.
5556
r_iter: Right index of iteration range over which the segment tree should be constructed.
57+
break_early: For each internal node of the segment tree, `break_early(l, r)` is called to
58+
evaluate whether the unary iteration should terminate early and not recurse in the
59+
subtree of the node representing range `[l, r)`. If True, the internal node is
60+
considered equivalent to a leaf node and the method yields only one tuple
61+
`(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`.
5662
5763
Yields:
5864
One `Tuple[cirq.OP_TREE, cirq.Qid, int]` for each leaf node in the segment tree. The i'th
@@ -68,8 +74,8 @@ def _unary_iteration_segtree(
6874
if l >= r_iter or l_iter >= r:
6975
# Range corresponding to this node is completely outside of iteration range.
7076
return
71-
if l == (r - 1):
72-
# Reached a leaf node; yield the operations.
77+
if l_iter <= l < r <= r_iter and (l == (r - 1) or break_early(l, r)):
78+
# Reached a leaf node or a "special" internal node; yield the operations.
7379
yield tuple(ops), control, l
7480
ops.clear()
7581
return
@@ -78,20 +84,24 @@ def _unary_iteration_segtree(
7884
if r_iter <= m:
7985
# Yield only left sub-tree.
8086
yield from _unary_iteration_segtree(
81-
ops, control, selection, ancilla, sl + 1, l, m, l_iter, r_iter
87+
ops, control, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early
8288
)
8389
return
8490
if l_iter >= m:
8591
# Yield only right sub-tree
8692
yield from _unary_iteration_segtree(
87-
ops, control, selection, ancilla, sl + 1, m, r, l_iter, r_iter
93+
ops, control, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early
8894
)
8995
return
9096
anc, sq = ancilla[sl], selection[sl]
9197
ops.append(and_gate.And((1, 0)).on(control, sq, anc))
92-
yield from _unary_iteration_segtree(ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter)
98+
yield from _unary_iteration_segtree(
99+
ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early
100+
)
93101
ops.append(cirq.CNOT(control, anc))
94-
yield from _unary_iteration_segtree(ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter)
102+
yield from _unary_iteration_segtree(
103+
ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early
104+
)
95105
ops.append(and_gate.And(adjoint=True).on(control, sq, anc))
96106

97107

@@ -101,16 +111,17 @@ def _unary_iteration_zero_control(
101111
ancilla: Sequence[cirq.Qid],
102112
l_iter: int,
103113
r_iter: int,
114+
break_early: Callable[[int, int], bool],
104115
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
105116
sl, l, r = 0, 0, 2 ** len(selection)
106117
m = (l + r) >> 1
107118
ops.append(cirq.X(selection[0]))
108119
yield from _unary_iteration_segtree(
109-
ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter
120+
ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter, break_early
110121
)
111122
ops.append(cirq.X(selection[0]))
112123
yield from _unary_iteration_segtree(
113-
ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter
124+
ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter, break_early
114125
)
115126

116127

@@ -121,9 +132,12 @@ def _unary_iteration_single_control(
121132
ancilla: Sequence[cirq.Qid],
122133
l_iter: int,
123134
r_iter: int,
135+
break_early: Callable[[int, int], bool],
124136
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
125137
sl, l, r = 0, 0, 2 ** len(selection)
126-
yield from _unary_iteration_segtree(ops, control, selection, ancilla, sl, l, r, l_iter, r_iter)
138+
yield from _unary_iteration_segtree(
139+
ops, control, selection, ancilla, sl, l, r, l_iter, r_iter, break_early
140+
)
127141

128142

129143
def _unary_iteration_multi_controls(
@@ -133,6 +147,7 @@ def _unary_iteration_multi_controls(
133147
ancilla: Sequence[cirq.Qid],
134148
l_iter: int,
135149
r_iter: int,
150+
break_early: Callable[[int, int], bool],
136151
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
137152
num_controls = len(controls)
138153
and_ancilla = ancilla[: num_controls - 2]
@@ -142,7 +157,7 @@ def _unary_iteration_multi_controls(
142157
)
143158
ops.append(multi_controlled_and)
144159
yield from _unary_iteration_single_control(
145-
ops, and_target, selection, ancilla[num_controls - 1 :], l_iter, r_iter
160+
ops, and_target, selection, ancilla[num_controls - 1 :], l_iter, r_iter, break_early
146161
)
147162
ops.append(cirq.inverse(multi_controlled_and))
148163

@@ -154,6 +169,7 @@ def unary_iteration(
154169
controls: Sequence[cirq.Qid],
155170
selection: Sequence[cirq.Qid],
156171
qubit_manager: cirq.QubitManager,
172+
break_early: Callable[[int, int], bool] = lambda l, r: False,
157173
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
158174
"""The method performs unary iteration on `selection` integer in `range(l_iter, r_iter)`.
159175
@@ -181,6 +197,9 @@ def unary_iteration(
181197
... circuit.append(j_ops)
182198
>>> circuit.append(i_ops)
183199
200+
Note: Unary iteration circuits assume that the selection register stores integers only in the
201+
range `[l, r)` for which the corresponding unary iteration circuit should be built.
202+
184203
Args:
185204
l_iter: Starting index of the iteration range.
186205
r_iter: Ending index of the iteration range.
@@ -192,6 +211,11 @@ def unary_iteration(
192211
controls: Control register of qubits.
193212
selection: Selection register of qubits.
194213
qubit_manager: A `cirq.QubitManager` to allocate new qubits.
214+
break_early: For each internal node of the segment tree, `break_early(l, r)` is called to
215+
evaluate whether the unary iteration should terminate early and not recurse in the
216+
subtree of the node representing range `[l, r)`. If True, the internal node is
217+
considered equivalent to a leaf node and the method yields only one tuple
218+
`(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`.
195219
196220
Yields:
197221
(r_iter - l_iter) different tuples, each corresponding to an integer in range
@@ -207,14 +231,16 @@ def unary_iteration(
207231
assert len(selection) > 0
208232
ancilla = qubit_manager.qalloc(max(0, len(controls) + len(selection) - 1))
209233
if len(controls) == 0:
210-
yield from _unary_iteration_zero_control(flanking_ops, selection, ancilla, l_iter, r_iter)
234+
yield from _unary_iteration_zero_control(
235+
flanking_ops, selection, ancilla, l_iter, r_iter, break_early
236+
)
211237
elif len(controls) == 1:
212238
yield from _unary_iteration_single_control(
213-
flanking_ops, controls[0], selection, ancilla, l_iter, r_iter
239+
flanking_ops, controls[0], selection, ancilla, l_iter, r_iter, break_early
214240
)
215241
else:
216242
yield from _unary_iteration_multi_controls(
217-
flanking_ops, controls, selection, ancilla, l_iter, r_iter
243+
flanking_ops, controls, selection, ancilla, l_iter, r_iter, break_early
218244
)
219245
qubit_manager.qfree(ancilla)
220246

@@ -231,6 +257,9 @@ class UnaryIterationGate(infra.GateWithRegisters):
231257
indexed operations on a target register depending on the index value stored in a selection
232258
register.
233259
260+
Note: Unary iteration circuits assume that the selection register stores integers only in the
261+
range `[l, r)` for which the corresponding unary iteration circuit should be built.
262+
234263
References:
235264
[Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity]
236265
(https://arxiv.org/abs/1805.03662).
@@ -308,10 +337,38 @@ def decompose_zero_selection(
308337
"""
309338
raise NotImplementedError("Selection register must not be empty.")
310339

340+
def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int) -> bool:
341+
"""Derived classes should override this method to specify an early termination condition.
342+
343+
For each internal node of the unary iteration segment tree, `break_early(l, r)` is called
344+
to evaluate whether the unary iteration should not recurse in the subtree of the node
345+
representing range `[l, r)`. If True, the internal node is considered equivalent to a leaf
346+
node and thus, `self.nth_operation` will be called for only integer `l` in the range [l, r).
347+
348+
When the `UnaryIteration` class is constructed using multiple selection registers, i.e. we
349+
wish to perform nested coherent for-loops, a unary iteration segment tree is constructed
350+
corresponding to each nested coherent for-loop. For every such unary iteration segment tree,
351+
the `_break_early` condition is checked by passing the `selection_index_prefix` tuple.
352+
353+
Args:
354+
selection_index_prefix: To evaluate the early breaking condition for the i'th nested
355+
for-loop, the `selection_index_prefix` contains `i-1` integers corresponding to
356+
the loop variable values for the first `i-1` nested loops.
357+
l: Beginning of range `[l, r)` for internal node of unary iteration segment tree.
358+
r: End (exclusive) of range `[l, r)` for internal node of unary iteration segment tree.
359+
360+
Returns:
361+
True of the `len(selection_index_prefix)`'th unary iteration should terminate early for
362+
the given parameters.
363+
"""
364+
return False
365+
311366
def decompose_from_registers(
312367
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
313368
) -> cirq.OP_TREE:
314-
if self.selection_registers.total_bits() == 0:
369+
if self.selection_registers.total_bits() == 0 or self._break_early(
370+
(), 0, self.selection_registers[0].iteration_length
371+
):
315372
return self.decompose_zero_selection(context=context, **quregs)
316373

317374
num_loops = len(self.selection_registers)
@@ -354,20 +411,23 @@ def unary_iteration_loops(
354411
return
355412
# Use recursion to write `num_loops` nested loops using unary_iteration().
356413
ops: List[cirq.Operation] = []
414+
selection_index_prefix = tuple(selection_reg_name_to_val.values())
357415
ith_for_loop = unary_iteration(
358416
l_iter=0,
359417
r_iter=self.selection_registers[nested_depth].iteration_length,
360418
flanking_ops=ops,
361419
controls=controls,
362420
selection=[*quregs[self.selection_registers[nested_depth].name]],
363421
qubit_manager=context.qubit_manager,
422+
break_early=lambda l, r: self._break_early(selection_index_prefix, l, r),
364423
)
365424
for op_tree, control_qid, n in ith_for_loop:
366425
yield op_tree
367426
selection_reg_name_to_val[self.selection_registers[nested_depth].name] = n
368427
yield from unary_iteration_loops(
369428
nested_depth + 1, selection_reg_name_to_val, (control_qid,)
370429
)
430+
selection_reg_name_to_val.pop(self.selection_registers[nested_depth].name)
371431
yield ops
372432

373433
return unary_iteration_loops(0, {}, self.control_registers.merge_qubits(**quregs))

0 commit comments

Comments
 (0)