@@ -188,3 +188,68 @@ def rewriter_replace_with_decomp(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
188
188
║ ║
189
189
a: ═════════════════════════════════════════════════════════════════════════════════════════════@══════════════════════════════^═══''' ,
190
190
)
191
+
192
+
193
+ def test_merge_k_qubit_unitaries_deep ():
194
+ q = cirq .LineQubit .range (2 )
195
+ h_cz_y = [cirq .H (q [0 ]), cirq .CZ (* q ), cirq .Y (q [1 ])]
196
+ c_orig = cirq .Circuit (
197
+ h_cz_y ,
198
+ cirq .Moment (cirq .X (q [0 ]).with_tags ("ignore" ), cirq .Y (q [1 ])),
199
+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (6 ).with_tags ("ignore" ),
200
+ [cirq .CNOT (* q ), cirq .CNOT (* q )],
201
+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (4 ),
202
+ [cirq .CNOT (* q ), cirq .CZ (* q ), cirq .CNOT (* q )],
203
+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (5 ).with_tags ("preserve_tag" ),
204
+ )
205
+
206
+ def _wrap_in_cop (ops : cirq .OP_TREE , tag : str ):
207
+ return cirq .CircuitOperation (cirq .FrozenCircuit (ops )).with_tags (tag )
208
+
209
+ c_expected = cirq .Circuit (
210
+ _wrap_in_cop ([h_cz_y , cirq .Y (q [1 ])], '1' ),
211
+ cirq .Moment (cirq .X (q [0 ]).with_tags ("ignore" )),
212
+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (6 ).with_tags ("ignore" ),
213
+ _wrap_in_cop ([cirq .CNOT (* q ), cirq .CNOT (* q )], '2' ),
214
+ cirq .CircuitOperation (cirq .FrozenCircuit (_wrap_in_cop (h_cz_y , '3' ))).repeat (4 ),
215
+ _wrap_in_cop ([cirq .CNOT (* q ), cirq .CZ (* q ), cirq .CNOT (* q )], '4' ),
216
+ cirq .CircuitOperation (cirq .FrozenCircuit (_wrap_in_cop (h_cz_y , '5' )))
217
+ .repeat (5 )
218
+ .with_tags ("preserve_tag" ),
219
+ strategy = cirq .InsertStrategy .NEW ,
220
+ )
221
+
222
+ component_id = 0
223
+
224
+ def rewriter_merge_to_circuit_op (op : 'cirq.CircuitOperation' ) -> 'cirq.OP_TREE' :
225
+ nonlocal component_id
226
+ component_id = component_id + 1
227
+ return op .with_tags (f'{ component_id } ' )
228
+
229
+ context = cirq .TransformerContext (tags_to_ignore = ("ignore" ,), deep = True )
230
+ c_new = cirq .merge_k_qubit_unitaries (
231
+ c_orig ,
232
+ k = 2 ,
233
+ context = context ,
234
+ rewriter = rewriter_merge_to_circuit_op ,
235
+ )
236
+ cirq .testing .assert_same_circuits (c_new , c_expected )
237
+
238
+ def _wrap_in_matrix_gate (ops : cirq .OP_TREE ):
239
+ op = _wrap_in_cop (ops , 'temp' )
240
+ return cirq .MatrixGate (cirq .unitary (op )).on (* op .qubits )
241
+
242
+ c_expected_matrix = cirq .Circuit (
243
+ _wrap_in_matrix_gate ([h_cz_y , cirq .Y (q [1 ])]),
244
+ cirq .Moment (cirq .X (q [0 ]).with_tags ("ignore" )),
245
+ cirq .CircuitOperation (cirq .FrozenCircuit (h_cz_y )).repeat (6 ).with_tags ("ignore" ),
246
+ _wrap_in_matrix_gate ([cirq .CNOT (* q ), cirq .CNOT (* q )]),
247
+ cirq .CircuitOperation (cirq .FrozenCircuit (_wrap_in_matrix_gate (h_cz_y ))).repeat (4 ),
248
+ _wrap_in_matrix_gate ([cirq .CNOT (* q ), cirq .CZ (* q ), cirq .CNOT (* q )]),
249
+ cirq .CircuitOperation (cirq .FrozenCircuit (_wrap_in_matrix_gate (h_cz_y )))
250
+ .repeat (5 )
251
+ .with_tags ("preserve_tag" ),
252
+ strategy = cirq .InsertStrategy .NEW ,
253
+ )
254
+ c_new_matrix = cirq .merge_k_qubit_unitaries (c_orig , k = 2 , context = context )
255
+ cirq .testing .assert_same_circuits (c_new_matrix , c_expected_matrix )
0 commit comments