15
15
"""Transformer pass to repack circuits avoiding simultaneous operations with different classes."""
16
16
17
17
import itertools
18
- from typing import TYPE_CHECKING , Type , Callable , Optional , Union , Iterable , Sequence , List , Tuple
18
+ from typing import TYPE_CHECKING , Type , Callable , Dict , Optional , Union , Iterable , Sequence , List
19
19
20
- from cirq import ops , circuits , _import
21
- from cirq .transformers import transformer_api , transformer_primitives
20
+ from cirq import ops , circuits , protocols , _import
21
+ from cirq .transformers import transformer_api
22
22
23
23
drop_empty_moments = _import .LazyLoader ('drop_empty_moments' , globals (), 'cirq.transformers' )
24
24
@@ -61,38 +61,36 @@ def stratified_circuit(
61
61
Returns:
62
62
A copy of the original circuit, but with re-arranged operations.
63
63
"""
64
-
65
64
# Normalize categories into classifier functions.
66
- classifiers = [_category_to_classifier (category ) for category in categories ]
67
- # Make the classifiers exhaustive by adding an "everything else" bucket.
68
- and_the_rest = lambda op : all (not classifier (op ) for classifier in classifiers )
69
- classifiers_and_the_rest = [* classifiers , and_the_rest ]
65
+ classifiers = _get_classifiers (circuit , categories )
70
66
71
67
# Try the algorithm with each permutation of the classifiers.
72
- classifiers_permutations = list (itertools .permutations (classifiers_and_the_rest ))
68
+ smallest_depth = protocols .num_qubits (circuit ) * len (circuit ) + 1
69
+ shortest_stratified_circuit = circuits .Circuit ()
73
70
reversed_circuit = circuit [::- 1 ]
74
- solutions = []
75
- for c in classifiers_permutations :
76
- solutions .append (
77
- _stratify_circuit (
78
- circuit ,
79
- classifiers = list (c ),
80
- context = context or transformer_api .TransformerContext (),
81
- )
71
+ for ordered_classifiers in itertools .permutations (classifiers ):
72
+ solution = _stratify_circuit (
73
+ circuit ,
74
+ classifiers = ordered_classifiers ,
75
+ context = context or transformer_api .TransformerContext (),
82
76
)
77
+ if len (solution ) < smallest_depth :
78
+ shortest_stratified_circuit = solution
79
+ smallest_depth = len (solution )
80
+
83
81
# Do the same thing, except this time in reverse. This helps for some
84
82
# circuits because it inserts operations at the end instead of at the
85
83
# beginning.
86
- solutions .append (
87
- _stratify_circuit (
88
- reversed_circuit ,
89
- classifiers = list (c ),
90
- context = context or transformer_api .TransformerContext (),
91
- )[::- 1 ]
92
- )
84
+ solution = _stratify_circuit (
85
+ reversed_circuit ,
86
+ classifiers = ordered_classifiers ,
87
+ context = context or transformer_api .TransformerContext (),
88
+ )[::- 1 ]
89
+ if len (solution ) < smallest_depth :
90
+ shortest_stratified_circuit = solution
91
+ smallest_depth = len (solution )
93
92
94
- # Return the shortest circuit.
95
- return min (solutions , key = lambda c : len (c ))
93
+ return shortest_stratified_circuit
96
94
97
95
98
96
def _stratify_circuit (
@@ -116,43 +114,88 @@ def _stratify_circuit(
116
114
Returns:
117
115
The stratified circuit.
118
116
"""
119
- num_categories = len (classifiers ) + 1
120
-
121
- def map_func (m : 'cirq.Moment' , _ ) -> Sequence ['cirq.Moment' ]:
122
- stratified_ops : List [List ['cirq.Operation' ]] = [[] for _ in range (num_categories )]
123
- for op in m :
124
- if set (op .tags ) & set (context .tags_to_ignore ):
125
- stratified_ops [0 ].append (op )
126
- continue
127
- for i , classifier in enumerate (classifiers ):
128
- if classifier (op ):
129
- stratified_ops [i + 1 ].append (op )
130
- break
131
- return [circuits .Moment (op_list ) for op_list in stratified_ops ]
132
-
133
- stratified_circuit = transformer_primitives .map_moments (circuit , map_func ).unfreeze (copy = False )
134
- assert len (stratified_circuit ) == len (circuit ) * num_categories
135
-
136
- # Try to move operations to the left to reduce circuit depth, preserving stratification.
137
- for curr_idx , moment in enumerate (stratified_circuit ):
138
- curr_category = curr_idx % num_categories
139
- if curr_category == 0 :
140
- # Moment containing tagged operations to be ignored.
141
- continue
142
- batch_removals : List [Tuple [int , 'cirq.Operation' ]] = []
143
- batch_inserts : List [Tuple [int , 'cirq.Operation' ]] = []
117
+ num_classes = len (classifiers ) + 1 # include one "extra" category for ignored operations
118
+ new_moments : List [List ['cirq.Operation' ]] = []
119
+
120
+ # Keep track of the the latest time index for each qubit, measurement key, and control key.
121
+ qubit_time_index : Dict ['cirq.Qid' , int ] = {}
122
+ measurement_time_index : Dict ['cirq.MeasurementKey' , int ] = {}
123
+ control_time_index : Dict ['cirq.MeasurementKey' , int ] = {}
124
+
125
+ # The minimum time index for operations with a tag in context.tags_to_ignore.
126
+ last_ignored_ops_time_index = 0
127
+
128
+ for moment in circuit :
129
+ # Identify the new time indices that operations should be moved into.
130
+ ignored_ops = []
131
+ op_time_indices = {}
144
132
for op in moment :
145
- prv_idx = stratified_circuit .earliest_available_moment (op , end_moment_index = curr_idx )
146
- prv_category = prv_idx % num_categories
147
- should_move_to_next_batch = curr_category < prv_category
148
- prv_idx += curr_category - prv_category + num_categories * should_move_to_next_batch
149
- assert prv_idx <= curr_idx and prv_idx % num_categories == curr_idx % num_categories
150
- if prv_idx < curr_idx :
151
- batch_inserts .append ((prv_idx , op ))
152
- batch_removals .append ((curr_idx , op ))
153
- stratified_circuit .batch_remove (batch_removals )
154
- stratified_circuit .batch_insert_into (batch_inserts )
155
- return drop_empty_moments .drop_empty_moments (stratified_circuit )
133
+
134
+ # Identify the earliest moment that can accommodate this op.
135
+ min_time_index_for_op = circuits .circuit .get_earliest_accommodating_moment_index (
136
+ op , qubit_time_index , measurement_time_index , control_time_index
137
+ )
138
+
139
+ # Identify the "class" of this operation (by index).
140
+ ignored_op = any (tag in op .tags for tag in context .tags_to_ignore )
141
+ if not ignored_op :
142
+ op_class = _get_op_class (op , classifiers )
143
+ else :
144
+ op_class = len (classifiers )
145
+ ignored_ops .append (op )
146
+ min_time_index_for_op = max (min_time_index_for_op , last_ignored_ops_time_index + 1 )
147
+
148
+ # Identify the time index to place this operation into.
149
+ time_index = (min_time_index_for_op // num_classes ) * num_classes + op_class
150
+ if time_index < min_time_index_for_op :
151
+ time_index += num_classes
152
+ op_time_indices [op ] = time_index
153
+
154
+ # Assign ignored operations to the same moment.
155
+ if ignored_ops :
156
+ last_ignored_ops_time_index = max (op_time_indices [op ] for op in ignored_ops )
157
+ for op in ignored_ops :
158
+ op_time_indices [op ] = last_ignored_ops_time_index
159
+
160
+ # Move the operations into their assigned moments.
161
+ for op , time_index in op_time_indices .items ():
162
+ if time_index >= len (new_moments ):
163
+ new_moments += [[] for _ in range (num_classes )]
164
+ new_moments [time_index ].append (op )
165
+
166
+ # Update qubit, measurment key, and control key moments.
167
+ for qubit in op .qubits :
168
+ qubit_time_index [qubit ] = time_index
169
+ for key in protocols .measurement_key_objs (op ):
170
+ measurement_time_index [key ] = time_index
171
+ for key in protocols .control_keys (op ):
172
+ control_time_index [key ] = time_index
173
+
174
+ return circuits .Circuit (circuits .Moment (moment ) for moment in new_moments if moment )
175
+
176
+
177
+ def _get_classifiers (
178
+ circuit : circuits .AbstractCircuit , categories : Iterable [Category ]
179
+ ) -> List [Classifier ]:
180
+ """Convert a collection of categories into a list of classifiers.
181
+
182
+ The returned list of classifiers is:
183
+ - Exhaustive, meaning every operation in the circuit is classified by at least one classifier.
184
+ - Minimal, meaning unused classifiers are forgotten.
185
+ """
186
+ # Convert all categories into classifiers, and make the list exhaustive by adding a dummy
187
+ # classifier for otherwise unclassified ops.
188
+ classifiers = [_category_to_classifier (cat ) for cat in categories ] + [_dummy_classifier ]
189
+
190
+ # Figure out which classes are actually used in the circuit.
191
+ class_is_used = [False for _ in classifiers ]
192
+ for op in circuit .all_operations ():
193
+ class_is_used [_get_op_class (op , classifiers )] = True
194
+ if all (class_is_used ):
195
+ break
196
+
197
+ # Return only the classifiers that are used.
198
+ return [classifier for classifier , is_used in zip (classifiers , class_is_used ) if is_used ]
156
199
157
200
158
201
# No type for `category` because mypy does not keep the return type when
@@ -177,3 +220,22 @@ def _category_to_classifier(category) -> Classifier:
177
220
f'Type[cirq.Gate], Type[cirq.Operation], '
178
221
f'or Callable[[cirq.Operation], bool].'
179
222
)
223
+
224
+
225
+ def _dummy_classifier (op : 'cirq.Operation' ) -> bool :
226
+ """Dummy classifier, used to "complete" a collection of classifiers and make it exhaustive."""
227
+
228
+
229
+ def _get_op_class (op : 'cirq.Operation' , classifiers : Sequence [Classifier ]) -> int :
230
+ """Get the "class" of an operator, by index."""
231
+ for class_index , classifier in enumerate (classifiers ):
232
+ if classifier is _dummy_classifier :
233
+ dummy_classifier_index = class_index
234
+ elif classifier (op ):
235
+ return class_index
236
+ # If we got this far, the operation did not match any "actual" classifier,
237
+ # so return the index of the dummy classifer.
238
+ try :
239
+ return dummy_classifier_index
240
+ except NameError :
241
+ raise ValueError (f"Operation { op } not identified by any classifier" )
0 commit comments