Skip to content

Commit 839f3b8

Browse files
tanujkhattarrht
authored andcommitted
Bugfix in handling of deep=True flag in cirq.merge_k_qubit_unitaries transformer (quantumlib#5125)
- Fixes a bug in `cirq.merge_k_qubit_unitaries` due to which the transformer was applied recursively only on circuit operations satisfying `cirq.num_qubits(op) <= k and cirq.has_unitary(op)`. Fixed the bug and added more tests. - Part of quantumlib#5039
1 parent 88917e1 commit 839f3b8

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

Diff for: cirq-core/cirq/transformers/merge_k_qubit_gates.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ def _rewrite_merged_k_qubit_unitaries(
3434
deep = context.deep if context else False
3535

3636
def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
37-
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
38-
return op
3937
op_untagged = op.untagged
4038
if (
4139
deep
@@ -51,6 +49,8 @@ def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
5149
merged_circuit_op_tag=merged_circuit_op_tag,
5250
).freeze()
5351
).with_tags(*op.tags)
52+
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
53+
return op
5454
if rewriter:
5555
return rewriter(
5656
cast(circuits.CircuitOperation, op_untagged)

Diff for: cirq-core/cirq/transformers/merge_k_qubit_gates_test.py

+24
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,27 @@ def _wrap_in_matrix_gate(ops: cirq.OP_TREE):
253253
)
254254
c_new_matrix = cirq.merge_k_qubit_unitaries(c_orig, k=2, context=context)
255255
cirq.testing.assert_same_circuits(c_new_matrix, c_expected_matrix)
256+
257+
258+
def test_merge_k_qubit_unitaries_deep_recurses_on_large_circuit_op():
259+
q = cirq.LineQubit.range(2)
260+
c_orig = cirq.Circuit(
261+
cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q[0]), cirq.H(q[0]), cirq.CNOT(*q)))
262+
)
263+
c_expected = cirq.Circuit(
264+
cirq.CircuitOperation(
265+
cirq.FrozenCircuit(
266+
cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q[0]), cirq.H(q[0]))).with_tags(
267+
"merged"
268+
),
269+
cirq.CNOT(*q),
270+
)
271+
)
272+
)
273+
c_new = cirq.merge_k_qubit_unitaries(
274+
c_orig,
275+
context=cirq.TransformerContext(deep=True),
276+
k=1,
277+
rewriter=lambda op: op.with_tags("merged"),
278+
)
279+
cirq.testing.assert_same_circuits(c_new, c_expected)

0 commit comments

Comments
 (0)