Skip to content

Commit a4ba95b

Browse files
author
Shef
authored
Break intermediate measurements on 3+ qubits into single qubit measurements in RouteCQC#6293 (#6349)
1 parent 7f47e13 commit a4ba95b

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

cirq-core/cirq/transformers/routing/route_circuit_cqc.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -249,26 +249,32 @@ def _get_one_and_two_qubit_ops_as_timesteps(
249249
output routed circuit, single-qubit operations are inserted before two-qubit operations.
250250
251251
Raises:
252-
ValueError: if circuit has intermediate measurement op's that act on 3 or more qubits.
252+
ValueError: if circuit has intermediate measurements that act on three or more
253+
qubits with a custom key.
253254
"""
254255
two_qubit_circuit = circuits.Circuit()
255256
single_qubit_ops: List[List[cirq.Operation]] = []
256257

257-
if any(
258-
protocols.num_qubits(op) > 2 and protocols.is_measurement(op)
259-
for op in itertools.chain(*circuit.moments[:-1])
260-
):
261-
# There is at least one non-terminal measurement on 3+ qubits
262-
raise ValueError('Non-terminal measurements on three or more qubits are not supported')
263-
264-
for moment in circuit:
258+
for i, moment in enumerate(circuit):
265259
for op in moment:
266260
timestep = two_qubit_circuit.earliest_available_moment(op)
267261
single_qubit_ops.extend([] for _ in range(timestep + 1 - len(single_qubit_ops)))
268262
two_qubit_circuit.append(
269263
circuits.Moment() for _ in range(timestep + 1 - len(two_qubit_circuit))
270264
)
271-
if protocols.num_qubits(op) == 2:
265+
if protocols.num_qubits(op) > 2 and protocols.is_measurement(op):
266+
key = op.gate.key # type: ignore
267+
default_key = ops.measure(op.qubits).gate.key # type: ignore
268+
if len(circuit.moments) == i + 1:
269+
single_qubit_ops[timestep].append(op)
270+
elif key in ('', default_key):
271+
single_qubit_ops[timestep].extend(ops.measure(qubit) for qubit in op.qubits)
272+
else:
273+
raise ValueError(
274+
'Intermediate measurements on three or more qubits '
275+
'with a custom key are not supported'
276+
)
277+
elif protocols.num_qubits(op) == 2:
272278
two_qubit_circuit[timestep] = two_qubit_circuit[timestep].with_operation(op)
273279
else:
274280
single_qubit_ops[timestep].append(op)

cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -107,35 +107,45 @@ def test_circuit_with_measurement_gates():
107107
cirq.testing.assert_same_circuits(routed_circuit, circuit)
108108

109109

110-
def test_circuit_with_valid_intermediate_multi_qubit_measurement_gates():
111-
device = cirq.testing.construct_ring_device(3)
110+
def test_circuit_with_two_qubit_intermediate_measurement_gate():
111+
device = cirq.testing.construct_ring_device(2)
112112
device_graph = device.metadata.nx_graph
113113
router = cirq.RouteCQC(device_graph)
114-
q = cirq.LineQubit.range(2)
115-
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(2)})
116-
117-
valid_circuit = cirq.Circuit(cirq.measure_each(*q), cirq.H.on_each(q))
118-
119-
c_routed = router(
120-
valid_circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
114+
qs = cirq.LineQubit.range(2)
115+
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(2)})
116+
circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))])
117+
routed_circuit = router(
118+
circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
121119
)
122-
device.validate_circuit(c_routed)
120+
device.validate_circuit(routed_circuit)
123121

124122

125-
def test_circuit_with_invalid_intermediate_multi_qubit_measurement_gates():
123+
def test_circuit_with_multi_qubit_intermediate_measurement_gate_and_with_default_key():
126124
device = cirq.testing.construct_ring_device(3)
127125
device_graph = device.metadata.nx_graph
128126
router = cirq.RouteCQC(device_graph)
129-
q = cirq.LineQubit.range(3)
130-
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)})
127+
qs = cirq.LineQubit.range(3)
128+
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)})
129+
circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))])
130+
routed_circuit = router(
131+
circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
132+
)
133+
expected = cirq.Circuit([cirq.Moment(cirq.measure_each(qs)), cirq.Moment(cirq.H.on_each(qs))])
134+
cirq.testing.assert_same_circuits(routed_circuit, expected)
131135

132-
invalid_circuit = cirq.Circuit(cirq.MeasurementGate(3).on(*q), cirq.H.on_each(*q))
133136

137+
def test_circuit_with_multi_qubit_intermediate_measurement_gate_with_custom_key():
138+
device = cirq.testing.construct_ring_device(3)
139+
device_graph = device.metadata.nx_graph
140+
router = cirq.RouteCQC(device_graph)
141+
qs = cirq.LineQubit.range(3)
142+
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)})
143+
circuit = cirq.Circuit(
144+
[cirq.Moment(cirq.measure(qs, key="test")), cirq.Moment(cirq.H.on_each(qs))]
145+
)
134146
with pytest.raises(ValueError):
135147
_ = router(
136-
invalid_circuit,
137-
initial_mapper=hard_coded_mapper,
138-
context=cirq.TransformerContext(deep=True),
148+
circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
139149
)
140150

141151

0 commit comments

Comments
 (0)