Skip to content

Commit c2a153b

Browse files
authored
Support multi-qubit measurements in deferred measurement transformer (quantumlib#5787)
* Support multi-qubit measurements in deferred measurement transformer * mypy * invert if branch * docstring
1 parent 47e7d22 commit c2a153b

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

cirq/transformers/measurement_transformers.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import itertools
1516
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
1617

1718
from cirq import ops, protocols, value
@@ -81,9 +82,7 @@ def defer_measurements(
8182
A circuit with equivalent logic, but all measurements at the end of the
8283
circuit.
8384
Raises:
84-
ValueError: If sympy-based classical conditions are used, or if
85-
conditions based on multi-qubit measurements exist. (The latter of
86-
these is planned to be implemented soon).
85+
ValueError: If sympy-based classical conditions are used.
8786
NotImplementedError: When attempting to defer a measurement with a
8887
confusion map. (https://github.com/quantumlib/Cirq/issues/5482)
8988
"""
@@ -111,23 +110,22 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
111110
elif protocols.is_measurement(op):
112111
return [defer(op, None) for op in protocols.decompose_once(op)]
113112
elif op.classical_controls:
114-
controls = []
113+
new_op = op.without_classical_controls()
115114
for c in op.classical_controls:
116115
if isinstance(c, value.KeyCondition):
117116
if c.key not in measurement_qubits:
118117
raise ValueError(f'Deferred measurement for key={c.key} not found.')
119-
qubits = measurement_qubits[c.key]
120-
if len(qubits) != 1:
121-
# TODO: Multi-qubit conditions require
122-
# https://github.com/quantumlib/Cirq/issues/4512
123-
# Remember to update docstring above once this works.
124-
raise ValueError('Only single qubit conditions are allowed.')
125-
controls.extend(qubits)
118+
qs = measurement_qubits[c.key]
119+
if len(qs) == 1:
120+
control_values: Any = range(1, qs[0].dimension)
121+
else:
122+
all_values = itertools.product(*[range(q.dimension) for q in qs])
123+
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
124+
control_values = ops.SumOfProducts(anything_but_all_zeros)
125+
new_op = new_op.controlled_by(*qs, control_values=control_values)
126126
else:
127127
raise ValueError('Only KeyConditions are allowed.')
128-
return op.without_classical_controls().controlled_by(
129-
*controls, control_values=[tuple(range(1, q.dimension)) for q in controls]
130-
)
128+
return new_op
131129
return op
132130

133131
circuit = transformer_primitives.map_operations_and_unroll(

cirq/transformers/measurement_transformers_test.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,31 @@ def test_multi_qubit_measurements():
225225
)
226226

227227

228+
def test_multi_qubit_control():
229+
q0, q1, q2 = cirq.LineQubit.range(3)
230+
circuit = cirq.Circuit(
231+
cirq.measure(q0, q1, key='a'),
232+
cirq.X(q2).with_classical_controls('a'),
233+
cirq.measure(q2, key='b'),
234+
)
235+
assert_equivalent_to_deferred(circuit)
236+
deferred = cirq.defer_measurements(circuit)
237+
q_ma0 = _MeasurementQid('a', q0)
238+
q_ma1 = _MeasurementQid('a', q1)
239+
cirq.testing.assert_same_circuits(
240+
deferred,
241+
cirq.Circuit(
242+
cirq.CX(q0, q_ma0),
243+
cirq.CX(q1, q_ma1),
244+
cirq.X(q2).controlled_by(
245+
q_ma0, q_ma1, control_values=cirq.SumOfProducts(((0, 1), (1, 0), (1, 1)))
246+
),
247+
cirq.measure(q_ma0, q_ma1, key='a'),
248+
cirq.measure(q2, key='b'),
249+
),
250+
)
251+
252+
228253
def test_diagram():
229254
q0, q1, q2, q3 = cirq.LineQubit.range(4)
230255
circuit = cirq.Circuit(
@@ -270,13 +295,6 @@ def test_repr(qid: _MeasurementQid):
270295
test_repr(_MeasurementQid('0:1:a', cirq.LineQid(9, 4)))
271296

272297

273-
def test_multi_qubit_control():
274-
q0, q1 = cirq.LineQubit.range(2)
275-
circuit = cirq.Circuit(cirq.measure(q0, q1, key='a'), cirq.X(q1).with_classical_controls('a'))
276-
with pytest.raises(ValueError, match='Only single qubit conditions are allowed'):
277-
_ = cirq.defer_measurements(circuit)
278-
279-
280298
def test_sympy_control():
281299
q0, q1 = cirq.LineQubit.range(2)
282300
circuit = cirq.Circuit(

0 commit comments

Comments
 (0)