Skip to content

Commit f391c20

Browse files
authored
Preserve circuit tags in transformer_primitives.map_operations (quantumlib#6505)
Review: @NoureldinYosri
1 parent a502f68 commit f391c20

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

cirq/transformers/transformer_primitives.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,9 @@ def _to_target_circuit_type(
5353

5454

5555
def _create_target_circuit_type(ops: ops.OP_TREE, target_circuit: CIRCUIT_TYPE) -> CIRCUIT_TYPE:
56-
return cast(
57-
CIRCUIT_TYPE,
58-
circuits.Circuit(ops)
59-
if isinstance(target_circuit, circuits.Circuit)
60-
else circuits.FrozenCircuit(ops),
61-
)
56+
if isinstance(target_circuit, circuits.FrozenCircuit):
57+
return cast(CIRCUIT_TYPE, circuits.FrozenCircuit(ops).with_tags(*target_circuit.tags))
58+
return cast(CIRCUIT_TYPE, circuits.Circuit(ops))
6259

6360

6461
def map_moments(

cirq/transformers/transformer_primitives_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,33 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
205205
# pylint: enable=line-too-long
206206

207207

208+
@pytest.mark.parametrize("deep", [False, True])
209+
def test_map_operations_preserves_circuit_tags(deep: bool) -> None:
210+
tag = "should be preserved"
211+
212+
def func(op: cirq.Operation, idx: int) -> cirq.Operation:
213+
return cirq.Y(op.qubits[0]) if op.gate == cirq.X else op
214+
215+
x = cirq.X(cirq.q(0))
216+
circuit = cirq.FrozenCircuit.from_moments(x, cirq.FrozenCircuit(x)).with_tags(tag)
217+
mapped = cirq.map_operations(circuit, func, deep=deep)
218+
219+
assert mapped.tags == (tag,)
220+
221+
222+
def test_map_operations_deep_preserves_subcircuit_tags():
223+
tag = "should be preserved"
224+
225+
def func(op: cirq.Operation, idx: int) -> cirq.Operation:
226+
return cirq.Y(op.qubits[0]) if op.gate == cirq.X else op
227+
228+
x = cirq.X(cirq.q(0))
229+
circuit = cirq.FrozenCircuit.from_moments(x, cirq.FrozenCircuit(x).with_tags(tag))
230+
mapped = cirq.map_operations(circuit, func, deep=True)
231+
232+
assert mapped[1].operations[0].circuit.tags == (tag,)
233+
234+
208235
def test_map_operations_deep_respects_tags_to_ignore():
209236
q = cirq.LineQubit.range(2)
210237
c_nested = cirq.FrozenCircuit(cirq.CX(*q), cirq.CX(*q).with_tags("ignore"), cirq.CX(*q))

0 commit comments

Comments
 (0)