@@ -205,6 +205,33 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
205
205
# pylint: enable=line-too-long
206
206
207
207
208
+ @pytest .mark .parametrize ("deep" , [False , True ])
209
+ def test_map_operations_preserves_circuit_tags (deep : bool ) -> None :
210
+ tag = "should be preserved"
211
+
212
+ def func (op : cirq .Operation , idx : int ) -> cirq .Operation :
213
+ return cirq .Y (op .qubits [0 ]) if op .gate == cirq .X else op
214
+
215
+ x = cirq .X (cirq .q (0 ))
216
+ circuit = cirq .FrozenCircuit .from_moments (x , cirq .FrozenCircuit (x )).with_tags (tag )
217
+ mapped = cirq .map_operations (circuit , func , deep = deep )
218
+
219
+ assert mapped .tags == (tag ,)
220
+
221
+
222
+ def test_map_operations_deep_preserves_subcircuit_tags ():
223
+ tag = "should be preserved"
224
+
225
+ def func (op : cirq .Operation , idx : int ) -> cirq .Operation :
226
+ return cirq .Y (op .qubits [0 ]) if op .gate == cirq .X else op
227
+
228
+ x = cirq .X (cirq .q (0 ))
229
+ circuit = cirq .FrozenCircuit .from_moments (x , cirq .FrozenCircuit (x ).with_tags (tag ))
230
+ mapped = cirq .map_operations (circuit , func , deep = True )
231
+
232
+ assert mapped [1 ].operations [0 ].circuit .tags == (tag ,)
233
+
234
+
208
235
def test_map_operations_deep_respects_tags_to_ignore ():
209
236
q = cirq .LineQubit .range (2 )
210
237
c_nested = cirq .FrozenCircuit (cirq .CX (* q ), cirq .CX (* q ).with_tags ("ignore" ), cirq .CX (* q ))
0 commit comments