Skip to content

Commit 7ed95aa

Browse files
authored
Speed up cirq.map_operations and cirq.map_operations_and_unroll (#6250)
* Speed up cirq.map_operations and cirq.map_operations_and_unroll * Mypy typing and minor bug fixes * Fix pylint * Revert unrelated change to mypy script * Address nits
1 parent be6218e commit 7ed95aa

File tree

1 file changed

+138
-28
lines changed

1 file changed

+138
-28
lines changed

cirq-core/cirq/transformers/transformer_primitives.py

+138-28
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,134 @@ def map_moments(
104104
)
105105

106106

107+
def _map_operations_impl(
108+
circuit: CIRCUIT_TYPE,
109+
map_func: Callable[[ops.Operation, int], ops.OP_TREE],
110+
*,
111+
deep: bool = False,
112+
raise_if_add_qubits=True,
113+
tags_to_ignore: Sequence[Hashable] = (),
114+
wrap_in_circuit_op: bool = True,
115+
) -> CIRCUIT_TYPE:
116+
"""Applies local transformations, by calling `map_func(op, moment_index)` for each operation.
117+
118+
This method provides a fast, iterative implementation for the two `map_operations_*` variants
119+
exposed as public transformer primitives. The high level idea for the iterative implementation
120+
is to
121+
1) For each operation `op`, find the corresponding mapped operation(s) `mapped_ops`. The
122+
set of mapped operations can be either wrapped in a circuit operation or not, depending
123+
on the value of flag `wrap_in_circuit_op` and whether the mapped operations will end up
124+
occupying more than one moment or not.
125+
2) Use the `get_earliest_accommodating_moment_index` infrastructure built for `cirq.Circuit`
126+
construction to determine the index at which the mapped operations should be inserted.
127+
This step takes care of the nuances that arise due to (a) preserving moment structure
128+
and (b) mapped operations spanning across multiple moments (these both are trivial when
129+
`op` is mapped to a single `mapped_op` that acts on the same set of qubits).
130+
131+
By default, the function assumes `issubset(qubit_set(map_func(op, moment_index)), op.qubits)` is
132+
True.
133+
134+
Args:
135+
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
136+
map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. If the
137+
resulting optree spans more than 1 moment, it's either wrapped in a tagged circuit
138+
operation and inserted in-place in the same moment (if `wrap_in_circuit_op` is True)
139+
OR the mapped operations are inserted directly in the circuit, preserving moment
140+
strucutre. The effect is equivalent to (but much faster) a two-step approach of first
141+
wrapping the operations in a circuit operation and then calling `cirq.unroll_circuit_op`
142+
to unroll the corresponding circuit ops.
143+
deep: If true, `map_func` will be recursively applied to circuits wrapped inside
144+
any circuit operations contained within `circuit`.
145+
raise_if_add_qubits: Set to True by default. If True, raises ValueError if
146+
`map_func(op, idx)` adds operations on qubits outside of `op.qubits`.
147+
tags_to_ignore: Sequence of tags which should be ignored while applying `map_func` on
148+
tagged operations -- i.e. `map_func(op, idx)` will be called only for operations that
149+
satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
150+
wrap_in_circuit_op: If True, the mapped operations will be wrapped in a tagged circuit
151+
operation and inserted in-place if they occupy more than one moment.
152+
153+
Raises:
154+
ValueError if `issubset(qubit_set(map_func(op, idx)), op.qubits) is False` and
155+
`raise_if_add_qubits is True`.
156+
157+
Returns:
158+
Copy of input circuit with mapped operations.
159+
"""
160+
tags_to_ignore_set = set(tags_to_ignore)
161+
162+
def apply_map_func(op: 'cirq.Operation', idx: int) -> List['cirq.Operation']:
163+
if tags_to_ignore_set.intersection(op.tags):
164+
return [op]
165+
if deep and isinstance(op.untagged, circuits.CircuitOperation):
166+
op = op.untagged.replace(
167+
circuit=_map_operations_impl(
168+
op.untagged.circuit,
169+
map_func,
170+
deep=deep,
171+
raise_if_add_qubits=raise_if_add_qubits,
172+
tags_to_ignore=tags_to_ignore,
173+
wrap_in_circuit_op=wrap_in_circuit_op,
174+
)
175+
).with_tags(*op.tags)
176+
mapped_ops = [*ops.flatten_to_ops(map_func(op, idx))]
177+
op_qubits = set(op.qubits)
178+
mapped_ops_qubits: Set['cirq.Qid'] = set()
179+
has_overlapping_ops = False
180+
for mapped_op in mapped_ops:
181+
if raise_if_add_qubits and not op_qubits.issuperset(mapped_op.qubits):
182+
raise ValueError(
183+
f"Mapped operations {mapped_ops} should act on a subset "
184+
f"of qubits of the original operation {op}"
185+
)
186+
if mapped_ops_qubits.intersection(mapped_op.qubits):
187+
has_overlapping_ops = True
188+
mapped_ops_qubits = mapped_ops_qubits.union(mapped_op.qubits)
189+
if wrap_in_circuit_op and has_overlapping_ops:
190+
# Mapped operations should be wrapped in a `CircuitOperation` only iff they occupy more
191+
# than one moment, i.e. there are at least two operations that share a qubit.
192+
mapped_ops = [
193+
circuits.CircuitOperation(circuits.FrozenCircuit(mapped_ops)).with_tags(
194+
MAPPED_CIRCUIT_OP_TAG
195+
)
196+
]
197+
return mapped_ops
198+
199+
new_moments: List[List['cirq.Operation']] = []
200+
201+
# Keep track of the latest time index for each qubit, measurement key, and control key.
202+
qubit_time_index: Dict['cirq.Qid', int] = {}
203+
measurement_time_index: Dict['cirq.MeasurementKey', int] = {}
204+
control_time_index: Dict['cirq.MeasurementKey', int] = {}
205+
206+
# New mapped operations in the current moment should be inserted after `last_moment_time_index`.
207+
last_moment_time_index = -1
208+
209+
for idx, moment in enumerate(circuit):
210+
if wrap_in_circuit_op:
211+
new_moments.append([])
212+
for op in moment:
213+
mapped_ops = apply_map_func(op, idx)
214+
215+
for mapped_op in mapped_ops:
216+
# Identify the earliest moment that can accommodate this op.
217+
placement_index = circuits.circuit.get_earliest_accommodating_moment_index(
218+
mapped_op, qubit_time_index, measurement_time_index, control_time_index
219+
)
220+
placement_index = max(placement_index, last_moment_time_index + 1)
221+
new_moments.extend([[] for _ in range(placement_index - len(new_moments) + 1)])
222+
new_moments[placement_index].append(mapped_op)
223+
for qubit in mapped_op.qubits:
224+
qubit_time_index[qubit] = placement_index
225+
for key in protocols.measurement_key_objs(mapped_op):
226+
measurement_time_index[key] = placement_index
227+
for key in protocols.control_keys(mapped_op):
228+
control_time_index[key] = placement_index
229+
230+
last_moment_time_index = len(new_moments) - 1
231+
232+
return _create_target_circuit_type([circuits.Moment(moment) for moment in new_moments], circuit)
233+
234+
107235
def map_operations(
108236
circuit: CIRCUIT_TYPE,
109237
map_func: Callable[[ops.Operation, int], ops.OP_TREE],
@@ -139,29 +267,13 @@ def map_operations(
139267
Returns:
140268
Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation).
141269
"""
142-
143-
def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE:
144-
if not set(op.tags).isdisjoint(tags_to_ignore):
145-
return op
146-
c = circuits.FrozenCircuit(map_func(op, idx))
147-
if raise_if_add_qubits and not c.all_qubits().issubset(op.qubits):
148-
raise ValueError(
149-
f"Mapped operations {c.all_operations()} should act on a subset "
150-
f"of qubits of the original operation {op}"
151-
)
152-
if len(c) <= 1:
153-
# Either empty circuit or all operations act in the same moment;
154-
# So, we don't need to wrap them in a circuit_op.
155-
return c[0].operations if c else []
156-
circuit_op = circuits.CircuitOperation(c).with_tags(MAPPED_CIRCUIT_OP_TAG)
157-
return circuit_op
158-
159-
return map_moments(
270+
return _map_operations_impl(
160271
circuit,
161-
lambda m, i: circuits.Circuit(apply_map(op, i) for op in m.operations).moments
162-
or [circuits.Moment()],
272+
map_func,
163273
deep=deep,
274+
raise_if_add_qubits=raise_if_add_qubits,
164275
tags_to_ignore=tags_to_ignore,
276+
wrap_in_circuit_op=True,
165277
)
166278

167279

@@ -191,15 +303,13 @@ def map_operations_and_unroll(
191303
Returns:
192304
Copy of input circuit with mapped operations, unrolled in a moment preserving way.
193305
"""
194-
return unroll_circuit_op(
195-
map_operations(
196-
circuit,
197-
map_func,
198-
deep=deep,
199-
raise_if_add_qubits=raise_if_add_qubits,
200-
tags_to_ignore=tags_to_ignore,
201-
),
306+
return _map_operations_impl(
307+
circuit,
308+
map_func,
202309
deep=deep,
310+
raise_if_add_qubits=raise_if_add_qubits,
311+
tags_to_ignore=tags_to_ignore,
312+
wrap_in_circuit_op=False,
203313
)
204314

205315

0 commit comments

Comments
 (0)