@@ -104,6 +104,134 @@ def map_moments(
104
104
)
105
105
106
106
107
+ def _map_operations_impl (
108
+ circuit : CIRCUIT_TYPE ,
109
+ map_func : Callable [[ops .Operation , int ], ops .OP_TREE ],
110
+ * ,
111
+ deep : bool = False ,
112
+ raise_if_add_qubits = True ,
113
+ tags_to_ignore : Sequence [Hashable ] = (),
114
+ wrap_in_circuit_op : bool = True ,
115
+ ) -> CIRCUIT_TYPE :
116
+ """Applies local transformations, by calling `map_func(op, moment_index)` for each operation.
117
+
118
+ This method provides a fast, iterative implementation for the two `map_operations_*` variants
119
+ exposed as public transformer primitives. The high level idea for the iterative implementation
120
+ is to
121
+ 1) For each operation `op`, find the corresponding mapped operation(s) `mapped_ops`. The
122
+ set of mapped operations can be either wrapped in a circuit operation or not, depending
123
+ on the value of flag `wrap_in_circuit_op` and whether the mapped operations will end up
124
+ occupying more than one moment or not.
125
+ 2) Use the `get_earliest_accommodating_moment_index` infrastructure built for `cirq.Circuit`
126
+ construction to determine the index at which the mapped operations should be inserted.
127
+ This step takes care of the nuances that arise due to (a) preserving moment structure
128
+ and (b) mapped operations spanning across multiple moments (these both are trivial when
129
+ `op` is mapped to a single `mapped_op` that acts on the same set of qubits).
130
+
131
+ By default, the function assumes `issubset(qubit_set(map_func(op, moment_index)), op.qubits)` is
132
+ True.
133
+
134
+ Args:
135
+ circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
136
+ map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. If the
137
+ resulting optree spans more than 1 moment, it's either wrapped in a tagged circuit
138
+ operation and inserted in-place in the same moment (if `wrap_in_circuit_op` is True)
139
+ OR the mapped operations are inserted directly in the circuit, preserving moment
140
+ strucutre. The effect is equivalent to (but much faster) a two-step approach of first
141
+ wrapping the operations in a circuit operation and then calling `cirq.unroll_circuit_op`
142
+ to unroll the corresponding circuit ops.
143
+ deep: If true, `map_func` will be recursively applied to circuits wrapped inside
144
+ any circuit operations contained within `circuit`.
145
+ raise_if_add_qubits: Set to True by default. If True, raises ValueError if
146
+ `map_func(op, idx)` adds operations on qubits outside of `op.qubits`.
147
+ tags_to_ignore: Sequence of tags which should be ignored while applying `map_func` on
148
+ tagged operations -- i.e. `map_func(op, idx)` will be called only for operations that
149
+ satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
150
+ wrap_in_circuit_op: If True, the mapped operations will be wrapped in a tagged circuit
151
+ operation and inserted in-place if they occupy more than one moment.
152
+
153
+ Raises:
154
+ ValueError if `issubset(qubit_set(map_func(op, idx)), op.qubits) is False` and
155
+ `raise_if_add_qubits is True`.
156
+
157
+ Returns:
158
+ Copy of input circuit with mapped operations.
159
+ """
160
+ tags_to_ignore_set = set (tags_to_ignore )
161
+
162
+ def apply_map_func (op : 'cirq.Operation' , idx : int ) -> List ['cirq.Operation' ]:
163
+ if tags_to_ignore_set .intersection (op .tags ):
164
+ return [op ]
165
+ if deep and isinstance (op .untagged , circuits .CircuitOperation ):
166
+ op = op .untagged .replace (
167
+ circuit = _map_operations_impl (
168
+ op .untagged .circuit ,
169
+ map_func ,
170
+ deep = deep ,
171
+ raise_if_add_qubits = raise_if_add_qubits ,
172
+ tags_to_ignore = tags_to_ignore ,
173
+ wrap_in_circuit_op = wrap_in_circuit_op ,
174
+ )
175
+ ).with_tags (* op .tags )
176
+ mapped_ops = [* ops .flatten_to_ops (map_func (op , idx ))]
177
+ op_qubits = set (op .qubits )
178
+ mapped_ops_qubits : Set ['cirq.Qid' ] = set ()
179
+ has_overlapping_ops = False
180
+ for mapped_op in mapped_ops :
181
+ if raise_if_add_qubits and not op_qubits .issuperset (mapped_op .qubits ):
182
+ raise ValueError (
183
+ f"Mapped operations { mapped_ops } should act on a subset "
184
+ f"of qubits of the original operation { op } "
185
+ )
186
+ if mapped_ops_qubits .intersection (mapped_op .qubits ):
187
+ has_overlapping_ops = True
188
+ mapped_ops_qubits = mapped_ops_qubits .union (mapped_op .qubits )
189
+ if wrap_in_circuit_op and has_overlapping_ops :
190
+ # Mapped operations should be wrapped in a `CircuitOperation` only iff they occupy more
191
+ # than one moment, i.e. there are at least two operations that share a qubit.
192
+ mapped_ops = [
193
+ circuits .CircuitOperation (circuits .FrozenCircuit (mapped_ops )).with_tags (
194
+ MAPPED_CIRCUIT_OP_TAG
195
+ )
196
+ ]
197
+ return mapped_ops
198
+
199
+ new_moments : List [List ['cirq.Operation' ]] = []
200
+
201
+ # Keep track of the latest time index for each qubit, measurement key, and control key.
202
+ qubit_time_index : Dict ['cirq.Qid' , int ] = {}
203
+ measurement_time_index : Dict ['cirq.MeasurementKey' , int ] = {}
204
+ control_time_index : Dict ['cirq.MeasurementKey' , int ] = {}
205
+
206
+ # New mapped operations in the current moment should be inserted after `last_moment_time_index`.
207
+ last_moment_time_index = - 1
208
+
209
+ for idx , moment in enumerate (circuit ):
210
+ if wrap_in_circuit_op :
211
+ new_moments .append ([])
212
+ for op in moment :
213
+ mapped_ops = apply_map_func (op , idx )
214
+
215
+ for mapped_op in mapped_ops :
216
+ # Identify the earliest moment that can accommodate this op.
217
+ placement_index = circuits .circuit .get_earliest_accommodating_moment_index (
218
+ mapped_op , qubit_time_index , measurement_time_index , control_time_index
219
+ )
220
+ placement_index = max (placement_index , last_moment_time_index + 1 )
221
+ new_moments .extend ([[] for _ in range (placement_index - len (new_moments ) + 1 )])
222
+ new_moments [placement_index ].append (mapped_op )
223
+ for qubit in mapped_op .qubits :
224
+ qubit_time_index [qubit ] = placement_index
225
+ for key in protocols .measurement_key_objs (mapped_op ):
226
+ measurement_time_index [key ] = placement_index
227
+ for key in protocols .control_keys (mapped_op ):
228
+ control_time_index [key ] = placement_index
229
+
230
+ last_moment_time_index = len (new_moments ) - 1
231
+
232
+ return _create_target_circuit_type ([circuits .Moment (moment ) for moment in new_moments ], circuit )
233
+
234
+
107
235
def map_operations (
108
236
circuit : CIRCUIT_TYPE ,
109
237
map_func : Callable [[ops .Operation , int ], ops .OP_TREE ],
@@ -139,29 +267,13 @@ def map_operations(
139
267
Returns:
140
268
Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation).
141
269
"""
142
-
143
- def apply_map (op : ops .Operation , idx : int ) -> ops .OP_TREE :
144
- if not set (op .tags ).isdisjoint (tags_to_ignore ):
145
- return op
146
- c = circuits .FrozenCircuit (map_func (op , idx ))
147
- if raise_if_add_qubits and not c .all_qubits ().issubset (op .qubits ):
148
- raise ValueError (
149
- f"Mapped operations { c .all_operations ()} should act on a subset "
150
- f"of qubits of the original operation { op } "
151
- )
152
- if len (c ) <= 1 :
153
- # Either empty circuit or all operations act in the same moment;
154
- # So, we don't need to wrap them in a circuit_op.
155
- return c [0 ].operations if c else []
156
- circuit_op = circuits .CircuitOperation (c ).with_tags (MAPPED_CIRCUIT_OP_TAG )
157
- return circuit_op
158
-
159
- return map_moments (
270
+ return _map_operations_impl (
160
271
circuit ,
161
- lambda m , i : circuits .Circuit (apply_map (op , i ) for op in m .operations ).moments
162
- or [circuits .Moment ()],
272
+ map_func ,
163
273
deep = deep ,
274
+ raise_if_add_qubits = raise_if_add_qubits ,
164
275
tags_to_ignore = tags_to_ignore ,
276
+ wrap_in_circuit_op = True ,
165
277
)
166
278
167
279
@@ -191,15 +303,13 @@ def map_operations_and_unroll(
191
303
Returns:
192
304
Copy of input circuit with mapped operations, unrolled in a moment preserving way.
193
305
"""
194
- return unroll_circuit_op (
195
- map_operations (
196
- circuit ,
197
- map_func ,
198
- deep = deep ,
199
- raise_if_add_qubits = raise_if_add_qubits ,
200
- tags_to_ignore = tags_to_ignore ,
201
- ),
306
+ return _map_operations_impl (
307
+ circuit ,
308
+ map_func ,
202
309
deep = deep ,
310
+ raise_if_add_qubits = raise_if_add_qubits ,
311
+ tags_to_ignore = tags_to_ignore ,
312
+ wrap_in_circuit_op = False ,
203
313
)
204
314
205
315
0 commit comments