@@ -213,6 +213,47 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
213
213
# pylint: enable=line-too-long
214
214
215
215
216
+ def test_map_operations_deep_respects_tags_to_ignore ():
217
+ q = cirq .LineQubit .range (2 )
218
+ c_nested = cirq .FrozenCircuit (cirq .CX (* q ), cirq .CX (* q ).with_tags ("ignore" ), cirq .CX (* q ))
219
+ c_nested_mapped = cirq .FrozenCircuit (cirq .CZ (* q ), cirq .CX (* q ).with_tags ("ignore" ), cirq .CZ (* q ))
220
+ c_orig = cirq .Circuit (
221
+ c_nested ,
222
+ cirq .CircuitOperation (c_nested ).repeat (4 ).with_tags ("ignore" ),
223
+ c_nested ,
224
+ cirq .CircuitOperation (
225
+ cirq .FrozenCircuit (
226
+ cirq .CircuitOperation (c_nested ).repeat (5 ).with_tags ("preserve_tag" ),
227
+ cirq .CircuitOperation (c_nested ).repeat (6 ).with_tags ("ignore" ),
228
+ cirq .CircuitOperation (c_nested ).repeat (7 ),
229
+ )
230
+ ),
231
+ c_nested ,
232
+ )
233
+ c_expected = cirq .Circuit (
234
+ c_nested_mapped ,
235
+ cirq .CircuitOperation (c_nested ).repeat (4 ).with_tags ("ignore" ),
236
+ c_nested_mapped ,
237
+ cirq .CircuitOperation (
238
+ cirq .FrozenCircuit (
239
+ cirq .CircuitOperation (c_nested_mapped ).repeat (5 ).with_tags ("preserve_tag" ),
240
+ cirq .CircuitOperation (c_nested ).repeat (6 ).with_tags ("ignore" ),
241
+ cirq .CircuitOperation (c_nested_mapped ).repeat (7 ),
242
+ )
243
+ ),
244
+ c_nested_mapped ,
245
+ )
246
+ cirq .testing .assert_same_circuits (
247
+ cirq .map_operations (
248
+ c_orig ,
249
+ lambda op , _ : cirq .CZ (* op .qubits ) if op .gate == cirq .CX else op ,
250
+ tags_to_ignore = ["ignore" ],
251
+ deep = True ,
252
+ ),
253
+ c_expected ,
254
+ )
255
+
256
+
216
257
def test_map_operations_respects_tags_to_ignore ():
217
258
q = cirq .LineQubit .range (2 )
218
259
c = cirq .Circuit (cirq .CNOT (* q ), cirq .CNOT (* q ).with_tags ("ignore" ), cirq .CNOT (* q ))
@@ -402,17 +443,29 @@ def test_map_moments_drop_empty_moments():
402
443
def test_map_moments_drop_empty_moments_deep ():
403
444
op = cirq .X (cirq .NamedQubit ("q" ))
404
445
c_nested = cirq .FrozenCircuit (cirq .Moment (op ), cirq .Moment (), cirq .Moment (op ))
446
+ circuit_op = cirq .CircuitOperation (c_nested ).repeat (2 )
447
+ circuit_op_dropped = cirq .CircuitOperation (cirq .FrozenCircuit ([op , op ])).repeat (2 )
405
448
c_orig = cirq .Circuit (
406
449
c_nested ,
407
450
cirq .CircuitOperation (c_nested ).repeat (6 ).with_tags ("ignore" ),
408
451
c_nested ,
409
- cirq .CircuitOperation (c_nested ).repeat (5 ).with_tags ("preserve_tag" ),
452
+ cirq .CircuitOperation (
453
+ cirq .FrozenCircuit (circuit_op , circuit_op .with_tags ("ignore" ), circuit_op )
454
+ )
455
+ .repeat (5 )
456
+ .with_tags ("preserve_tag" ),
410
457
)
411
458
c_expected = cirq .Circuit (
412
459
[op , op ],
413
460
cirq .CircuitOperation (c_nested ).repeat (6 ).with_tags ("ignore" ),
414
461
[op , op ],
415
- cirq .CircuitOperation (cirq .FrozenCircuit ([op , op ])).repeat (5 ).with_tags ("preserve_tag" ),
462
+ cirq .CircuitOperation (
463
+ cirq .FrozenCircuit (
464
+ circuit_op_dropped , circuit_op .with_tags ("ignore" ), circuit_op_dropped
465
+ )
466
+ )
467
+ .repeat (5 )
468
+ .with_tags ("preserve_tag" ),
416
469
)
417
470
c_mapped = cirq .map_moments (
418
471
c_orig , lambda m , i : [] if len (m ) == 0 else [m ], deep = True , tags_to_ignore = ("ignore" ,)
0 commit comments