Skip to content

Commit aed4eb8

Browse files
authored
Bugfixes in handling nested tags_to_ignore + deep=True in cirq.map_moments and cirq.map_operations transformer primitives (#5109)
- Fixes a few more bugs in the handling of deep=True flag and nested operations to ignore using `tags_to_ignore` in `cirq.map_operations` and `cirq.map_moments` transformer primitives. Also added more tests. - Step towards fixing #5039
1 parent e0a64dd commit aed4eb8

File tree

2 files changed

+62
-4
lines changed

2 files changed

+62
-4
lines changed

cirq-core/cirq/transformers/transformer_primitives.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ def map_moments(
8787
continue
8888
op_untagged = cast(circuits.CircuitOperation, op.untagged)
8989
mapped_op = op_untagged.replace(
90-
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
90+
circuit=map_moments(
91+
op_untagged.circuit, map_func, tags_to_ignore=tags_to_ignore, deep=deep
92+
)
9193
).with_tags(*op.tags)
9294
batch_replace.append((i, op, mapped_op))
9395
mutable_circuit = circuit.unfreeze(copy=True)
@@ -149,7 +151,10 @@ def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE:
149151
return circuit_op
150152

151153
return map_moments(
152-
circuit, lambda m, i: [circuits.Moment(apply_map(op, i) for op in m.operations)], deep=deep
154+
circuit,
155+
lambda m, i: [circuits.Moment(apply_map(op, i) for op in m.operations)],
156+
deep=deep,
157+
tags_to_ignore=tags_to_ignore,
153158
)
154159

155160

cirq-core/cirq/transformers/transformer_primitives_test.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,47 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
213213
# pylint: enable=line-too-long
214214

215215

216+
def test_map_operations_deep_respects_tags_to_ignore():
217+
q = cirq.LineQubit.range(2)
218+
c_nested = cirq.FrozenCircuit(cirq.CX(*q), cirq.CX(*q).with_tags("ignore"), cirq.CX(*q))
219+
c_nested_mapped = cirq.FrozenCircuit(cirq.CZ(*q), cirq.CX(*q).with_tags("ignore"), cirq.CZ(*q))
220+
c_orig = cirq.Circuit(
221+
c_nested,
222+
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
223+
c_nested,
224+
cirq.CircuitOperation(
225+
cirq.FrozenCircuit(
226+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
227+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
228+
cirq.CircuitOperation(c_nested).repeat(7),
229+
)
230+
),
231+
c_nested,
232+
)
233+
c_expected = cirq.Circuit(
234+
c_nested_mapped,
235+
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
236+
c_nested_mapped,
237+
cirq.CircuitOperation(
238+
cirq.FrozenCircuit(
239+
cirq.CircuitOperation(c_nested_mapped).repeat(5).with_tags("preserve_tag"),
240+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
241+
cirq.CircuitOperation(c_nested_mapped).repeat(7),
242+
)
243+
),
244+
c_nested_mapped,
245+
)
246+
cirq.testing.assert_same_circuits(
247+
cirq.map_operations(
248+
c_orig,
249+
lambda op, _: cirq.CZ(*op.qubits) if op.gate == cirq.CX else op,
250+
tags_to_ignore=["ignore"],
251+
deep=True,
252+
),
253+
c_expected,
254+
)
255+
256+
216257
def test_map_operations_respects_tags_to_ignore():
217258
q = cirq.LineQubit.range(2)
218259
c = cirq.Circuit(cirq.CNOT(*q), cirq.CNOT(*q).with_tags("ignore"), cirq.CNOT(*q))
@@ -402,17 +443,29 @@ def test_map_moments_drop_empty_moments():
402443
def test_map_moments_drop_empty_moments_deep():
403444
op = cirq.X(cirq.NamedQubit("q"))
404445
c_nested = cirq.FrozenCircuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op))
446+
circuit_op = cirq.CircuitOperation(c_nested).repeat(2)
447+
circuit_op_dropped = cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(2)
405448
c_orig = cirq.Circuit(
406449
c_nested,
407450
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
408451
c_nested,
409-
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
452+
cirq.CircuitOperation(
453+
cirq.FrozenCircuit(circuit_op, circuit_op.with_tags("ignore"), circuit_op)
454+
)
455+
.repeat(5)
456+
.with_tags("preserve_tag"),
410457
)
411458
c_expected = cirq.Circuit(
412459
[op, op],
413460
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
414461
[op, op],
415-
cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(5).with_tags("preserve_tag"),
462+
cirq.CircuitOperation(
463+
cirq.FrozenCircuit(
464+
circuit_op_dropped, circuit_op.with_tags("ignore"), circuit_op_dropped
465+
)
466+
)
467+
.repeat(5)
468+
.with_tags("preserve_tag"),
416469
)
417470
c_mapped = cirq.map_moments(
418471
c_orig, lambda m, i: [] if len(m) == 0 else [m], deep=True, tags_to_ignore=("ignore",)

0 commit comments

Comments
 (0)