diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index a0df6ef1b28..fc2125b39ab 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -104,6 +104,134 @@ def map_moments( ) +def _map_operations_impl( + circuit: CIRCUIT_TYPE, + map_func: Callable[[ops.Operation, int], ops.OP_TREE], + *, + deep: bool = False, + raise_if_add_qubits=True, + tags_to_ignore: Sequence[Hashable] = (), + wrap_in_circuit_op: bool = True, +) -> CIRCUIT_TYPE: + """Applies local transformations, by calling `map_func(op, moment_index)` for each operation. + + This method provides a fast, iterative implementation for the two `map_operations_*` variants + exposed as public transformer primitives. The high level idea for the iterative implementation + is to + 1) For each operation `op`, find the corresponding mapped operation(s) `mapped_ops`. The + set of mapped operations can be either wrapped in a circuit operation or not, depending + on the value of flag `wrap_in_circuit_op` and whether the mapped operations will end up + occupying more than one moment or not. + 2) Use the `get_earliest_accommodating_moment_index` infrastructure built for `cirq.Circuit` + construction to determine the index at which the mapped operations should be inserted. + This step takes care of the nuances that arise due to (a) preserving moment structure + and (b) mapped operations spanning across multiple moments (these both are trivial when + `op` is mapped to a single `mapped_op` that acts on the same set of qubits). + + By default, the function assumes `issubset(qubit_set(map_func(op, moment_index)), op.qubits)` is + True. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. If the + resulting optree spans more than 1 moment, it's either wrapped in a tagged circuit + operation and inserted in-place in the same moment (if `wrap_in_circuit_op` is True) + OR the mapped operations are inserted directly in the circuit, preserving moment + strucutre. The effect is equivalent to (but much faster) a two-step approach of first + wrapping the operations in a circuit operation and then calling `cirq.unroll_circuit_op` + to unroll the corresponding circuit ops. + deep: If true, `map_func` will be recursively applied to circuits wrapped inside + any circuit operations contained within `circuit`. + raise_if_add_qubits: Set to True by default. If True, raises ValueError if + `map_func(op, idx)` adds operations on qubits outside of `op.qubits`. + tags_to_ignore: Sequence of tags which should be ignored while applying `map_func` on + tagged operations -- i.e. `map_func(op, idx)` will be called only for operations that + satisfy `set(op.tags).isdisjoint(tags_to_ignore)`. + wrap_in_circuit_op: If True, the mapped operations will be wrapped in a tagged circuit + operation and inserted in-place if they occupy more than one moment. + + Raises: + ValueError if `issubset(qubit_set(map_func(op, idx)), op.qubits) is False` and + `raise_if_add_qubits is True`. + + Returns: + Copy of input circuit with mapped operations. + """ + tags_to_ignore_set = set(tags_to_ignore) + + def apply_map_func(op: 'cirq.Operation', idx: int) -> List['cirq.Operation']: + if tags_to_ignore_set.intersection(op.tags): + return [op] + if deep and isinstance(op.untagged, circuits.CircuitOperation): + op = op.untagged.replace( + circuit=_map_operations_impl( + op.untagged.circuit, + map_func, + deep=deep, + raise_if_add_qubits=raise_if_add_qubits, + tags_to_ignore=tags_to_ignore, + wrap_in_circuit_op=wrap_in_circuit_op, + ) + ).with_tags(*op.tags) + mapped_ops = [*ops.flatten_to_ops(map_func(op, idx))] + op_qubits = set(op.qubits) + mapped_ops_qubits: Set['cirq.Qid'] = set() + has_overlapping_ops = False + for mapped_op in mapped_ops: + if raise_if_add_qubits and not op_qubits.issuperset(mapped_op.qubits): + raise ValueError( + f"Mapped operations {mapped_ops} should act on a subset " + f"of qubits of the original operation {op}" + ) + if mapped_ops_qubits.intersection(mapped_op.qubits): + has_overlapping_ops = True + mapped_ops_qubits = mapped_ops_qubits.union(mapped_op.qubits) + if wrap_in_circuit_op and has_overlapping_ops: + # Mapped operations should be wrapped in a `CircuitOperation` only iff they occupy more + # than one moment, i.e. there are at least two operations that share a qubit. + mapped_ops = [ + circuits.CircuitOperation(circuits.FrozenCircuit(mapped_ops)).with_tags( + MAPPED_CIRCUIT_OP_TAG + ) + ] + return mapped_ops + + new_moments: List[List['cirq.Operation']] = [] + + # Keep track of the latest time index for each qubit, measurement key, and control key. + qubit_time_index: Dict['cirq.Qid', int] = {} + measurement_time_index: Dict['cirq.MeasurementKey', int] = {} + control_time_index: Dict['cirq.MeasurementKey', int] = {} + + # New mapped operations in the current moment should be inserted after `last_moment_time_index`. + last_moment_time_index = -1 + + for idx, moment in enumerate(circuit): + if wrap_in_circuit_op: + new_moments.append([]) + for op in moment: + mapped_ops = apply_map_func(op, idx) + + for mapped_op in mapped_ops: + # Identify the earliest moment that can accommodate this op. + placement_index = circuits.circuit.get_earliest_accommodating_moment_index( + mapped_op, qubit_time_index, measurement_time_index, control_time_index + ) + placement_index = max(placement_index, last_moment_time_index + 1) + new_moments.extend([[] for _ in range(placement_index - len(new_moments) + 1)]) + new_moments[placement_index].append(mapped_op) + for qubit in mapped_op.qubits: + qubit_time_index[qubit] = placement_index + for key in protocols.measurement_key_objs(mapped_op): + measurement_time_index[key] = placement_index + for key in protocols.control_keys(mapped_op): + control_time_index[key] = placement_index + + last_moment_time_index = len(new_moments) - 1 + + return _create_target_circuit_type([circuits.Moment(moment) for moment in new_moments], circuit) + + def map_operations( circuit: CIRCUIT_TYPE, map_func: Callable[[ops.Operation, int], ops.OP_TREE], @@ -139,29 +267,13 @@ def map_operations( Returns: Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation). """ - - def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE: - if not set(op.tags).isdisjoint(tags_to_ignore): - return op - c = circuits.FrozenCircuit(map_func(op, idx)) - if raise_if_add_qubits and not c.all_qubits().issubset(op.qubits): - raise ValueError( - f"Mapped operations {c.all_operations()} should act on a subset " - f"of qubits of the original operation {op}" - ) - if len(c) <= 1: - # Either empty circuit or all operations act in the same moment; - # So, we don't need to wrap them in a circuit_op. - return c[0].operations if c else [] - circuit_op = circuits.CircuitOperation(c).with_tags(MAPPED_CIRCUIT_OP_TAG) - return circuit_op - - return map_moments( + return _map_operations_impl( circuit, - lambda m, i: circuits.Circuit(apply_map(op, i) for op in m.operations).moments - or [circuits.Moment()], + map_func, deep=deep, + raise_if_add_qubits=raise_if_add_qubits, tags_to_ignore=tags_to_ignore, + wrap_in_circuit_op=True, ) @@ -191,15 +303,13 @@ def map_operations_and_unroll( Returns: Copy of input circuit with mapped operations, unrolled in a moment preserving way. """ - return unroll_circuit_op( - map_operations( - circuit, - map_func, - deep=deep, - raise_if_add_qubits=raise_if_add_qubits, - tags_to_ignore=tags_to_ignore, - ), + return _map_operations_impl( + circuit, + map_func, deep=deep, + raise_if_add_qubits=raise_if_add_qubits, + tags_to_ignore=tags_to_ignore, + wrap_in_circuit_op=False, )