Skip to content

Commit 6a97cca

Browse files
Refactor and speed up cirq.transformers.stratify (#6013)
* refactor cirq.transformers.stratify * fix one failing test * nit renaming * fix coverage * formatting fix * pylint fix * fix bug with measurements in stratification * add missing import * fix bug with finding time index for op * fix test, and nit change to keeping track of time indices * hopefully fix measurement bug * minor fix with ignored ops * fix test * only store shortest circuit found * minor bugfix * nit typing fix * one more silly bugfig * store shortest stratified circuit properly * fix bug with overlapping measurements * clean up handling of ignored ops * further clean up logic deciding where to put an op * factor out logic for finding earliest accomodating moment * fix typo * fix typo * remove unnecesaary use of defaultdict * further simplify logic in get_earliest_accommodating_moment_index * fix lint check * fix minor bug * nit docstring update * separately update qubit/mkey/ckey moments in cirq.stratify * fix bug with max only getting one argument * fix coverage check * fix typo * fix lint check --------- Co-authored-by: Tanuj Khattar <[email protected]>
1 parent 663d404 commit 6a97cca

File tree

4 files changed

+241
-121
lines changed

4 files changed

+241
-121
lines changed

cirq-core/cirq/circuits/circuit.py

+86-43
Original file line numberDiff line numberDiff line change
@@ -1776,12 +1776,10 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
17761776
Non-moment entries will be inserted according to the EARLIEST
17771777
insertion strategy.
17781778
"""
1779-
# These are dicts from the qubit/key to the greatest moment index that has it. It is safe
1780-
# to default to `-1`, as that is interpreted as meaning the zeroth index onward does not
1781-
# have this value.
1782-
qubit_indexes: Dict['cirq.Qid', int] = defaultdict(lambda: -1)
1783-
mkey_indexes: Dict['cirq.MeasurementKey', int] = defaultdict(lambda: -1)
1784-
ckey_indexes: Dict['cirq.MeasurementKey', int] = defaultdict(lambda: -1)
1779+
# These are dicts from the qubit/key to the greatest moment index that has it.
1780+
qubit_indices: Dict['cirq.Qid', int] = {}
1781+
mkey_indices: Dict['cirq.MeasurementKey', int] = {}
1782+
ckey_indices: Dict['cirq.MeasurementKey', int] = {}
17851783

17861784
# We also maintain the dict from moment index to moments/ops that go into it, for use when
17871785
# building the actual moments at the end.
@@ -1793,46 +1791,17 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
17931791

17941792
# "mop" means current moment-or-operation
17951793
for mop in ops.flatten_to_ops_or_moments(contents):
1796-
mop_qubits = mop.qubits
1797-
mop_mkeys = protocols.measurement_key_objs(mop)
1798-
mop_ckeys = protocols.control_keys(mop)
17991794

1800-
# Both branches define `i`, the moment index at which to place the mop.
1795+
# Identify the index of the moment to place this `mop` into.
1796+
placement_index = get_earliest_accommodating_moment_index(
1797+
mop, qubit_indices, mkey_indices, ckey_indices, length
1798+
)
1799+
length = max(length, placement_index + 1) # update the length of the circuit thus far
1800+
18011801
if isinstance(mop, Moment):
1802-
# We always append moment to the end, to be consistent with `self.append`
1803-
i = length
1804-
moments_by_index[i] = mop
1802+
moments_by_index[placement_index] = mop
18051803
else:
1806-
# Initially we define `i` as the greatest moment index that has a conflict. `-1` is
1807-
# the initial conflict, and we search for larger ones. Once we get the largest one,
1808-
# we increment i by 1 to set the placement index.
1809-
i = -1
1810-
1811-
# Look for the maximum conflict; i.e. a moment that has a qubit the same as one of
1812-
# this op's qubits, that has a measurement or control key the same as one of this
1813-
# op's measurement keys, or that has a measurement key the same as one of this op's
1814-
# control keys. (Control keys alone can commute past each other). The `ifs` are
1815-
# logically unnecessary but seem to make this slightly faster.
1816-
if mop_qubits:
1817-
i = max(i, *[qubit_indexes[q] for q in mop_qubits])
1818-
if mop_mkeys:
1819-
i = max(i, *[mkey_indexes[k] for k in mop_mkeys])
1820-
i = max(i, *[ckey_indexes[k] for k in mop_mkeys])
1821-
if mop_ckeys:
1822-
i = max(i, *[mkey_indexes[k] for k in mop_ckeys])
1823-
i += 1
1824-
op_lists_by_index[i].append(mop)
1825-
1826-
# Update our dicts with data from the latest mop placement. Note `i` will always be
1827-
# greater than the existing value for all of these, by construction, so there is no
1828-
# need to do a `max(i, existing)`.
1829-
for q in mop_qubits:
1830-
qubit_indexes[q] = i
1831-
for k in mop_mkeys:
1832-
mkey_indexes[k] = i
1833-
for k in mop_ckeys:
1834-
ckey_indexes[k] = i
1835-
length = max(length, i + 1)
1804+
op_lists_by_index[placement_index].append(mop)
18361805

18371806
# Finally, once everything is placed, we can construct and append the actual moments for
18381807
# each index.
@@ -2753,3 +2722,77 @@ def _group_until_different(items: Iterable[_TIn], key: Callable[[_TIn], _TKey],
27532722
Tuples containing the group key and item values.
27542723
"""
27552724
return ((k, [val(i) for i in v]) for (k, v) in itertools.groupby(items, key))
2725+
2726+
2727+
def get_earliest_accommodating_moment_index(
2728+
moment_or_operation: Union['cirq.Moment', 'cirq.Operation'],
2729+
qubit_indices: Dict['cirq.Qid', int],
2730+
mkey_indices: Dict['cirq.MeasurementKey', int],
2731+
ckey_indices: Dict['cirq.MeasurementKey', int],
2732+
length: Optional[int] = None,
2733+
) -> int:
2734+
"""Get the index of the earliest moment that can accomodate the given moment or operation.
2735+
2736+
Updates the dictionaries keeping track of the last moment index addressing a given qubit,
2737+
measurement key, and control key.
2738+
2739+
Args:
2740+
moment_or_operation: The moment operation in question.
2741+
qubit_indices: A dictionary mapping qubits to the latest moments that address them.
2742+
mkey_indices: A dictionary mapping measureent keys to the latest moments that address them.
2743+
ckey_indices: A dictionary mapping control keys to the latest moments that address them.
2744+
length: The length of the circuit that we are trying to insert a moment or operation into.
2745+
Should probably be equal to the maximum of the values in `qubit_indices`,
2746+
`mkey_indices`, and `ckey_indices`.
2747+
2748+
Returns:
2749+
The integer index of the earliest moment that can accomodate the given moment or operation.
2750+
"""
2751+
mop_qubits = moment_or_operation.qubits
2752+
mop_mkeys = protocols.measurement_key_objs(moment_or_operation)
2753+
mop_ckeys = protocols.control_keys(moment_or_operation)
2754+
2755+
if isinstance(moment_or_operation, Moment):
2756+
# For consistency with `Circuit.append`, moments always get placed at the end of a circuit.
2757+
if length is not None:
2758+
last_conflict = length - 1
2759+
else:
2760+
last_conflict = max(
2761+
[*qubit_indices.values(), *mkey_indices.values(), *ckey_indices.values(), -1]
2762+
)
2763+
2764+
else:
2765+
# We start by searching for the `latest_conflict` moment index, which we will increment by
2766+
# `1` to identify the earliest moment that *does not* conflict with the given operation.
2767+
# The `latest_conflict` is initialized to `-1` before searching for later conflicting
2768+
# moments.
2769+
last_conflict = -1
2770+
2771+
# Look for the maximum conflict; i.e. a moment that has a qubit the same as one of this op's
2772+
# qubits, that has a measurement or control key the same as one of this op's measurement
2773+
# keys, or that has a measurement key the same as one of this op's control keys. (Control
2774+
# keys alone can commute past each other). The `ifs` are logically unnecessary but seem to
2775+
# make this slightly faster.
2776+
if mop_qubits:
2777+
last_conflict = max(
2778+
last_conflict, *[qubit_indices.get(qubit, -1) for qubit in mop_qubits]
2779+
)
2780+
if mop_mkeys:
2781+
last_conflict = max(last_conflict, *[mkey_indices.get(key, -1) for key in mop_mkeys])
2782+
last_conflict = max(last_conflict, *[ckey_indices.get(key, -1) for key in mop_mkeys])
2783+
if mop_ckeys:
2784+
last_conflict = max(last_conflict, *[mkey_indices.get(key, -1) for key in mop_ckeys])
2785+
2786+
# The index of the moment to place this moment or operaton ("mop") into.
2787+
mop_index = last_conflict + 1
2788+
2789+
# Update our dicts with data from this `mop` placement. Note `mop_index` will always be greater
2790+
# than the existing value for all of these, by construction.
2791+
for qubit in mop_qubits:
2792+
qubit_indices[qubit] = mop_index
2793+
for key in mop_mkeys:
2794+
mkey_indices[key] = mop_index
2795+
for key in mop_ckeys:
2796+
ckey_indices[key] = mop_index
2797+
2798+
return mop_index

cirq-core/cirq/circuits/circuit_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,17 @@ def test_insert_moment():
834834
assert c.operation_at(qubit, actual_index) == operation[0]
835835

836836

837+
def test_circuit_length_inference():
838+
# tests that `get_earliest_accommodating_moment_index` properly computes circuit length
839+
circuit = cirq.Circuit(cirq.X(cirq.q(0)))
840+
qubit_indices = {cirq.q(0): 0}
841+
mkey_indices = {}
842+
ckey_indices = {}
843+
assert circuits.circuit.get_earliest_accommodating_moment_index(
844+
cirq.Moment(), qubit_indices, mkey_indices, ckey_indices
845+
) == len(circuit)
846+
847+
837848
def test_insert_inline_near_start():
838849
a = cirq.NamedQubit('a')
839850
b = cirq.NamedQubit('b')

cirq-core/cirq/transformers/stratify.py

+124-62
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
"""Transformer pass to repack circuits avoiding simultaneous operations with different classes."""
1616

1717
import itertools
18-
from typing import TYPE_CHECKING, Type, Callable, Optional, Union, Iterable, Sequence, List, Tuple
18+
from typing import TYPE_CHECKING, Type, Callable, Dict, Optional, Union, Iterable, Sequence, List
1919

20-
from cirq import ops, circuits, _import
21-
from cirq.transformers import transformer_api, transformer_primitives
20+
from cirq import ops, circuits, protocols, _import
21+
from cirq.transformers import transformer_api
2222

2323
drop_empty_moments = _import.LazyLoader('drop_empty_moments', globals(), 'cirq.transformers')
2424

@@ -61,38 +61,36 @@ def stratified_circuit(
6161
Returns:
6262
A copy of the original circuit, but with re-arranged operations.
6363
"""
64-
6564
# Normalize categories into classifier functions.
66-
classifiers = [_category_to_classifier(category) for category in categories]
67-
# Make the classifiers exhaustive by adding an "everything else" bucket.
68-
and_the_rest = lambda op: all(not classifier(op) for classifier in classifiers)
69-
classifiers_and_the_rest = [*classifiers, and_the_rest]
65+
classifiers = _get_classifiers(circuit, categories)
7066

7167
# Try the algorithm with each permutation of the classifiers.
72-
classifiers_permutations = list(itertools.permutations(classifiers_and_the_rest))
68+
smallest_depth = protocols.num_qubits(circuit) * len(circuit) + 1
69+
shortest_stratified_circuit = circuits.Circuit()
7370
reversed_circuit = circuit[::-1]
74-
solutions = []
75-
for c in classifiers_permutations:
76-
solutions.append(
77-
_stratify_circuit(
78-
circuit,
79-
classifiers=list(c),
80-
context=context or transformer_api.TransformerContext(),
81-
)
71+
for ordered_classifiers in itertools.permutations(classifiers):
72+
solution = _stratify_circuit(
73+
circuit,
74+
classifiers=ordered_classifiers,
75+
context=context or transformer_api.TransformerContext(),
8276
)
77+
if len(solution) < smallest_depth:
78+
shortest_stratified_circuit = solution
79+
smallest_depth = len(solution)
80+
8381
# Do the same thing, except this time in reverse. This helps for some
8482
# circuits because it inserts operations at the end instead of at the
8583
# beginning.
86-
solutions.append(
87-
_stratify_circuit(
88-
reversed_circuit,
89-
classifiers=list(c),
90-
context=context or transformer_api.TransformerContext(),
91-
)[::-1]
92-
)
84+
solution = _stratify_circuit(
85+
reversed_circuit,
86+
classifiers=ordered_classifiers,
87+
context=context or transformer_api.TransformerContext(),
88+
)[::-1]
89+
if len(solution) < smallest_depth:
90+
shortest_stratified_circuit = solution
91+
smallest_depth = len(solution)
9392

94-
# Return the shortest circuit.
95-
return min(solutions, key=lambda c: len(c))
93+
return shortest_stratified_circuit
9694

9795

9896
def _stratify_circuit(
@@ -116,43 +114,88 @@ def _stratify_circuit(
116114
Returns:
117115
The stratified circuit.
118116
"""
119-
num_categories = len(classifiers) + 1
120-
121-
def map_func(m: 'cirq.Moment', _) -> Sequence['cirq.Moment']:
122-
stratified_ops: List[List['cirq.Operation']] = [[] for _ in range(num_categories)]
123-
for op in m:
124-
if set(op.tags) & set(context.tags_to_ignore):
125-
stratified_ops[0].append(op)
126-
continue
127-
for i, classifier in enumerate(classifiers):
128-
if classifier(op):
129-
stratified_ops[i + 1].append(op)
130-
break
131-
return [circuits.Moment(op_list) for op_list in stratified_ops]
132-
133-
stratified_circuit = transformer_primitives.map_moments(circuit, map_func).unfreeze(copy=False)
134-
assert len(stratified_circuit) == len(circuit) * num_categories
135-
136-
# Try to move operations to the left to reduce circuit depth, preserving stratification.
137-
for curr_idx, moment in enumerate(stratified_circuit):
138-
curr_category = curr_idx % num_categories
139-
if curr_category == 0:
140-
# Moment containing tagged operations to be ignored.
141-
continue
142-
batch_removals: List[Tuple[int, 'cirq.Operation']] = []
143-
batch_inserts: List[Tuple[int, 'cirq.Operation']] = []
117+
num_classes = len(classifiers) + 1 # include one "extra" category for ignored operations
118+
new_moments: List[List['cirq.Operation']] = []
119+
120+
# Keep track of the the latest time index for each qubit, measurement key, and control key.
121+
qubit_time_index: Dict['cirq.Qid', int] = {}
122+
measurement_time_index: Dict['cirq.MeasurementKey', int] = {}
123+
control_time_index: Dict['cirq.MeasurementKey', int] = {}
124+
125+
# The minimum time index for operations with a tag in context.tags_to_ignore.
126+
last_ignored_ops_time_index = 0
127+
128+
for moment in circuit:
129+
# Identify the new time indices that operations should be moved into.
130+
ignored_ops = []
131+
op_time_indices = {}
144132
for op in moment:
145-
prv_idx = stratified_circuit.earliest_available_moment(op, end_moment_index=curr_idx)
146-
prv_category = prv_idx % num_categories
147-
should_move_to_next_batch = curr_category < prv_category
148-
prv_idx += curr_category - prv_category + num_categories * should_move_to_next_batch
149-
assert prv_idx <= curr_idx and prv_idx % num_categories == curr_idx % num_categories
150-
if prv_idx < curr_idx:
151-
batch_inserts.append((prv_idx, op))
152-
batch_removals.append((curr_idx, op))
153-
stratified_circuit.batch_remove(batch_removals)
154-
stratified_circuit.batch_insert_into(batch_inserts)
155-
return drop_empty_moments.drop_empty_moments(stratified_circuit)
133+
134+
# Identify the earliest moment that can accommodate this op.
135+
min_time_index_for_op = circuits.circuit.get_earliest_accommodating_moment_index(
136+
op, qubit_time_index, measurement_time_index, control_time_index
137+
)
138+
139+
# Identify the "class" of this operation (by index).
140+
ignored_op = any(tag in op.tags for tag in context.tags_to_ignore)
141+
if not ignored_op:
142+
op_class = _get_op_class(op, classifiers)
143+
else:
144+
op_class = len(classifiers)
145+
ignored_ops.append(op)
146+
min_time_index_for_op = max(min_time_index_for_op, last_ignored_ops_time_index + 1)
147+
148+
# Identify the time index to place this operation into.
149+
time_index = (min_time_index_for_op // num_classes) * num_classes + op_class
150+
if time_index < min_time_index_for_op:
151+
time_index += num_classes
152+
op_time_indices[op] = time_index
153+
154+
# Assign ignored operations to the same moment.
155+
if ignored_ops:
156+
last_ignored_ops_time_index = max(op_time_indices[op] for op in ignored_ops)
157+
for op in ignored_ops:
158+
op_time_indices[op] = last_ignored_ops_time_index
159+
160+
# Move the operations into their assigned moments.
161+
for op, time_index in op_time_indices.items():
162+
if time_index >= len(new_moments):
163+
new_moments += [[] for _ in range(num_classes)]
164+
new_moments[time_index].append(op)
165+
166+
# Update qubit, measurment key, and control key moments.
167+
for qubit in op.qubits:
168+
qubit_time_index[qubit] = time_index
169+
for key in protocols.measurement_key_objs(op):
170+
measurement_time_index[key] = time_index
171+
for key in protocols.control_keys(op):
172+
control_time_index[key] = time_index
173+
174+
return circuits.Circuit(circuits.Moment(moment) for moment in new_moments if moment)
175+
176+
177+
def _get_classifiers(
178+
circuit: circuits.AbstractCircuit, categories: Iterable[Category]
179+
) -> List[Classifier]:
180+
"""Convert a collection of categories into a list of classifiers.
181+
182+
The returned list of classifiers is:
183+
- Exhaustive, meaning every operation in the circuit is classified by at least one classifier.
184+
- Minimal, meaning unused classifiers are forgotten.
185+
"""
186+
# Convert all categories into classifiers, and make the list exhaustive by adding a dummy
187+
# classifier for otherwise unclassified ops.
188+
classifiers = [_category_to_classifier(cat) for cat in categories] + [_dummy_classifier]
189+
190+
# Figure out which classes are actually used in the circuit.
191+
class_is_used = [False for _ in classifiers]
192+
for op in circuit.all_operations():
193+
class_is_used[_get_op_class(op, classifiers)] = True
194+
if all(class_is_used):
195+
break
196+
197+
# Return only the classifiers that are used.
198+
return [classifier for classifier, is_used in zip(classifiers, class_is_used) if is_used]
156199

157200

158201
# No type for `category` because mypy does not keep the return type when
@@ -177,3 +220,22 @@ def _category_to_classifier(category) -> Classifier:
177220
f'Type[cirq.Gate], Type[cirq.Operation], '
178221
f'or Callable[[cirq.Operation], bool].'
179222
)
223+
224+
225+
def _dummy_classifier(op: 'cirq.Operation') -> bool:
226+
"""Dummy classifier, used to "complete" a collection of classifiers and make it exhaustive."""
227+
228+
229+
def _get_op_class(op: 'cirq.Operation', classifiers: Sequence[Classifier]) -> int:
230+
"""Get the "class" of an operator, by index."""
231+
for class_index, classifier in enumerate(classifiers):
232+
if classifier is _dummy_classifier:
233+
dummy_classifier_index = class_index
234+
elif classifier(op):
235+
return class_index
236+
# If we got this far, the operation did not match any "actual" classifier,
237+
# so return the index of the dummy classifer.
238+
try:
239+
return dummy_classifier_index
240+
except NameError:
241+
raise ValueError(f"Operation {op} not identified by any classifier")

0 commit comments

Comments
 (0)