Skip to content

Commit 92d19f6

Browse files
authored
Add support for deep=True to cirq.merge_k_qubit_unitaries transformer (#5122)
1 parent 64a6723 commit 92d19f6

File tree

2 files changed

+114
-15
lines changed

2 files changed

+114
-15
lines changed

cirq-core/cirq/transformers/merge_k_qubit_gates.py

+49-15
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,47 @@
2323
import cirq
2424

2525

26+
def _rewrite_merged_k_qubit_unitaries(
27+
circuit: 'cirq.AbstractCircuit',
28+
*,
29+
context: Optional['cirq.TransformerContext'] = None,
30+
k: int = 0,
31+
rewriter: Optional[Callable[['cirq.CircuitOperation'], 'cirq.OP_TREE']] = None,
32+
merged_circuit_op_tag: str = "_merged_k_qubit_unitaries_component",
33+
) -> 'cirq.Circuit':
34+
deep = context.deep if context else False
35+
36+
def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
37+
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
38+
return op
39+
op_untagged = op.untagged
40+
if (
41+
deep
42+
and isinstance(op_untagged, circuits.CircuitOperation)
43+
and merged_circuit_op_tag not in op.tags
44+
):
45+
return op_untagged.replace(
46+
circuit=_rewrite_merged_k_qubit_unitaries(
47+
op_untagged.circuit,
48+
context=context,
49+
k=k,
50+
rewriter=rewriter,
51+
merged_circuit_op_tag=merged_circuit_op_tag,
52+
).freeze()
53+
).with_tags(*op.tags)
54+
if rewriter:
55+
return rewriter(
56+
cast(circuits.CircuitOperation, op_untagged)
57+
if merged_circuit_op_tag in op.tags
58+
else circuits.CircuitOperation(circuits.FrozenCircuit(op))
59+
)
60+
return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits)
61+
62+
return transformer_primitives.map_operations_and_unroll(
63+
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
64+
).unfreeze(copy=False)
65+
66+
2667
@transformer_api.transformer
2768
def merge_k_qubit_unitaries(
2869
circuit: 'cirq.AbstractCircuit',
@@ -54,24 +95,17 @@ def merge_k_qubit_unitaries(
5495
if k <= 0:
5596
raise ValueError(f"k should be greater than or equal to 1. Found {k}.")
5697
merged_circuit_op_tag = "_merged_k_qubit_unitaries_component"
57-
58-
def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
59-
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
60-
return op
61-
if rewriter:
62-
return rewriter(
63-
cast(circuits.CircuitOperation, op.untagged)
64-
if merged_circuit_op_tag in op.tags
65-
else circuits.CircuitOperation(circuits.FrozenCircuit(op))
66-
)
67-
return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits)
68-
6998
circuit = transformer_primitives.merge_k_qubit_unitaries_to_circuit_op(
7099
circuit,
71100
k=k,
72101
tags_to_ignore=context.tags_to_ignore if context else (),
73102
merged_circuit_op_tag=merged_circuit_op_tag,
103+
deep=context.deep if context else False,
104+
)
105+
return _rewrite_merged_k_qubit_unitaries(
106+
circuit,
107+
context=context,
108+
k=k,
109+
rewriter=rewriter,
110+
merged_circuit_op_tag=merged_circuit_op_tag,
74111
)
75-
return transformer_primitives.map_operations_and_unroll(
76-
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
77-
).unfreeze(copy=False)

cirq-core/cirq/transformers/merge_k_qubit_gates_test.py

+65
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,68 @@ def rewriter_replace_with_decomp(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
188188
║ ║
189189
a: ═════════════════════════════════════════════════════════════════════════════════════════════@══════════════════════════════^═══''',
190190
)
191+
192+
193+
def test_merge_k_qubit_unitaries_deep():
194+
q = cirq.LineQubit.range(2)
195+
h_cz_y = [cirq.H(q[0]), cirq.CZ(*q), cirq.Y(q[1])]
196+
c_orig = cirq.Circuit(
197+
h_cz_y,
198+
cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])),
199+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
200+
[cirq.CNOT(*q), cirq.CNOT(*q)],
201+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(4),
202+
[cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)],
203+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(5).with_tags("preserve_tag"),
204+
)
205+
206+
def _wrap_in_cop(ops: cirq.OP_TREE, tag: str):
207+
return cirq.CircuitOperation(cirq.FrozenCircuit(ops)).with_tags(tag)
208+
209+
c_expected = cirq.Circuit(
210+
_wrap_in_cop([h_cz_y, cirq.Y(q[1])], '1'),
211+
cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
212+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
213+
_wrap_in_cop([cirq.CNOT(*q), cirq.CNOT(*q)], '2'),
214+
cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_cop(h_cz_y, '3'))).repeat(4),
215+
_wrap_in_cop([cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)], '4'),
216+
cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_cop(h_cz_y, '5')))
217+
.repeat(5)
218+
.with_tags("preserve_tag"),
219+
strategy=cirq.InsertStrategy.NEW,
220+
)
221+
222+
component_id = 0
223+
224+
def rewriter_merge_to_circuit_op(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
225+
nonlocal component_id
226+
component_id = component_id + 1
227+
return op.with_tags(f'{component_id}')
228+
229+
context = cirq.TransformerContext(tags_to_ignore=("ignore",), deep=True)
230+
c_new = cirq.merge_k_qubit_unitaries(
231+
c_orig,
232+
k=2,
233+
context=context,
234+
rewriter=rewriter_merge_to_circuit_op,
235+
)
236+
cirq.testing.assert_same_circuits(c_new, c_expected)
237+
238+
def _wrap_in_matrix_gate(ops: cirq.OP_TREE):
239+
op = _wrap_in_cop(ops, 'temp')
240+
return cirq.MatrixGate(cirq.unitary(op)).on(*op.qubits)
241+
242+
c_expected_matrix = cirq.Circuit(
243+
_wrap_in_matrix_gate([h_cz_y, cirq.Y(q[1])]),
244+
cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
245+
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
246+
_wrap_in_matrix_gate([cirq.CNOT(*q), cirq.CNOT(*q)]),
247+
cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_matrix_gate(h_cz_y))).repeat(4),
248+
_wrap_in_matrix_gate([cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)]),
249+
cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_matrix_gate(h_cz_y)))
250+
.repeat(5)
251+
.with_tags("preserve_tag"),
252+
strategy=cirq.InsertStrategy.NEW,
253+
)
254+
c_new_matrix = cirq.merge_k_qubit_unitaries(c_orig, k=2, context=context)
255+
cirq.testing.assert_same_circuits(c_new_matrix, c_expected_matrix)

0 commit comments

Comments
 (0)