Skip to content

Commit af1267d

Browse files
Allow repeated measurements in deferred transformer (#5857)
* Add handling for sympy conditions in deferred measurement transformer * docstring * mypy * mypy * cover * Make this more generic, covers all kinds of conditions. * Better docs * Sympy can also be CX * docs * docs * Allow repeated measurements in deferred transformer * Coverage * Add mixed tests, simplify loop, add simplification in ControlledGate * Fix error message * Simplify error message * Inline variable * fix merge * qudit sympy test * fix build * Fix test * Fix test * nits * mypy * mypy * mypy * Add some code comments * Add test for repeated measurement diagram * change test back Co-authored-by: Tanuj Khattar <[email protected]>
1 parent 7019adc commit af1267d

File tree

2 files changed

+105
-52
lines changed

2 files changed

+105
-52
lines changed

cirq-core/cirq/transformers/measurement_transformers.py

+54-35
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,8 @@
1313
# limitations under the License.
1414

1515
import itertools
16-
from typing import (
17-
Any,
18-
Dict,
19-
Iterable,
20-
List,
21-
Mapping,
22-
Optional,
23-
Sequence,
24-
Tuple,
25-
TYPE_CHECKING,
26-
Union,
27-
)
16+
from collections import defaultdict
17+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
2818

2919
import numpy as np
3020

@@ -43,30 +33,32 @@ class _MeasurementQid(ops.Qid):
4333
Exactly one qubit will be created per qubit in the measurement gate.
4434
"""
4535

46-
def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid'):
36+
def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid', index: int = 0):
4737
"""Initializes the qubit.
4838
4939
Args:
5040
key: The key of the measurement gate being deferred.
5141
qid: One qubit that is being measured. Each deferred measurement
5242
should create one new _MeasurementQid per qubit being measured
5343
by that gate.
44+
index: For repeated measurement keys, this represents the index of that measurement.
5445
"""
5546
self._key = value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key
5647
self._qid = qid
48+
self._index = index
5749

5850
@property
5951
def dimension(self) -> int:
6052
return self._qid.dimension
6153

6254
def _comparison_key(self) -> Any:
63-
return str(self._key), self._qid._comparison_key()
55+
return str(self._key), self._index, self._qid._comparison_key()
6456

6557
def __str__(self) -> str:
66-
return f"M('{self._key}', q={self._qid})"
58+
return f"M('{self._key}[{self._index}]', q={self._qid})"
6759

6860
def __repr__(self) -> str:
69-
return f'_MeasurementQid({self._key!r}, {self._qid!r})'
61+
return f'_MeasurementQid({self._key!r}, {self._qid!r}, {self._index})'
7062

7163

7264
@transformer_api.transformer
@@ -102,16 +94,18 @@ def defer_measurements(
10294

10395
circuit = transformer_primitives.unroll_circuit_op(circuit, deep=True, tags_to_check=None)
10496
terminal_measurements = {op for _, op in find_terminal_measurements(circuit)}
105-
measurement_qubits: Dict['cirq.MeasurementKey', List['_MeasurementQid']] = {}
97+
measurement_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]] = defaultdict(
98+
list
99+
)
106100

107101
def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
108102
if op in terminal_measurements:
109103
return op
110104
gate = op.gate
111105
if isinstance(gate, ops.MeasurementGate):
112106
key = value.MeasurementKey.parse_serialized(gate.key)
113-
targets = [_MeasurementQid(key, q) for q in op.qubits]
114-
measurement_qubits[key] = targets
107+
targets = [_MeasurementQid(key, q, len(measurement_qubits[key])) for q in op.qubits]
108+
measurement_qubits[key].append(tuple(targets))
115109
cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)]
116110
confusions = [
117111
_ConfusionChannel(m, [op.qubits[i].dimension for i in indexes]).on(
@@ -125,10 +119,24 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
125119
return [defer(op, None) for op in protocols.decompose_once(op)]
126120
elif op.classical_controls:
127121
# Convert to a quantum control
128-
keys = sorted(set(key for c in op.classical_controls for key in c.keys))
129-
for key in keys:
122+
123+
# First create a sorted set of the indexed keys for this control.
124+
keys = sorted(
125+
set(
126+
indexed_key
127+
for condition in op.classical_controls
128+
for indexed_key in (
129+
[(condition.key, condition.index)]
130+
if isinstance(condition, value.KeyCondition)
131+
else [(k, -1) for k in condition.keys]
132+
)
133+
)
134+
)
135+
for key, index in keys:
130136
if key not in measurement_qubits:
131137
raise ValueError(f'Deferred measurement for key={key} not found.')
138+
if index >= len(measurement_qubits[key]) or index < -len(measurement_qubits[key]):
139+
raise ValueError(f'Invalid index for {key}')
132140

133141
# Try every possible datastore state (exponential in the number of keys) against the
134142
# condition, and the ones that work are the control values for the new op.
@@ -140,12 +148,11 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
140148

141149
# Rearrange these into the format expected by SumOfProducts
142150
products = [
143-
[i for key in keys for i in store.records[key][0]]
151+
[val for k, i in keys for val in store.records[k][i]]
144152
for store in compatible_datastores
145153
]
146-
147154
control_values = ops.SumOfProducts(products)
148-
qs = [q for key in keys for q in measurement_qubits[key]]
155+
qs = [q for k, i in keys for q in measurement_qubits[k][i]]
149156
return op.without_classical_controls().controlled_by(*qs, control_values=control_values)
150157
return op
151158

@@ -155,14 +162,15 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
155162
tags_to_ignore=context.tags_to_ignore if context else (),
156163
raise_if_add_qubits=False,
157164
).unfreeze()
158-
for k, qubits in measurement_qubits.items():
159-
circuit.append(ops.measure(*qubits, key=k))
165+
for k, qubits_list in measurement_qubits.items():
166+
for qubits in qubits_list:
167+
circuit.append(ops.measure(*qubits, key=k))
160168
return circuit
161169

162170

163171
def _all_possible_datastore_states(
164-
keys: Iterable['cirq.MeasurementKey'],
165-
measurement_qubits: Mapping['cirq.MeasurementKey', Iterable['cirq.Qid']],
172+
keys: Iterable[Tuple['cirq.MeasurementKey', int]],
173+
measurement_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]],
166174
) -> Iterable['cirq.ClassicalDataStoreReader']:
167175
"""The cartesian product of all possible DataStore states for the given keys."""
168176
# First we get the list of all possible values. So if we have a key mapped to qubits of shape
@@ -179,17 +187,28 @@ def _all_possible_datastore_states(
179187
# ((1, 1), (0,)),
180188
# ((1, 1), (1,)),
181189
# ((1, 1), (2,))]
182-
all_values = itertools.product(
190+
all_possible_measurements = itertools.product(
183191
*[
184-
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k]]))
185-
for k in keys
192+
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k][i]]))
193+
for k, i in keys
186194
]
187195
)
188-
# Then we create the ClassicalDataDictionaryStore for each of the above.
189-
for sequences in all_values:
190-
lookup = {k: [sequence] for k, sequence in zip(keys, sequences)}
196+
# Then we create the ClassicalDataDictionaryStore for each of the above. A `measurement_list`
197+
# is a single row of the above example, and can be zipped with `keys`.
198+
for measurement_list in all_possible_measurements:
199+
# Initialize a set of measurement records for this iteration. This will have the same shape
200+
# as `measurement_qubits` but zeros for all measurements.
201+
records = {
202+
key: [(0,) * len(qubits) for qubits in qubits_list]
203+
for key, qubits_list in measurement_qubits.items()
204+
}
205+
# Set the measurement values from the current row of the above, for each key/index we care
206+
# about.
207+
for (k, i), measurement in zip(keys, measurement_list):
208+
records[k][i] = measurement
209+
# Finally yield this sample to the consumer.
191210
yield value.ClassicalDataDictionaryStore(
192-
_records=lookup, _measured_qubits={k: [tuple(measurement_qubits[k])] for k in keys}
211+
_records=records, _measured_qubits=measurement_qubits
193212
)
194213

195214

cirq-core/cirq/transformers/measurement_transformers_test.py

+51-17
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,40 @@ def test_multi_qubit_control():
445445
)
446446

447447

448+
@pytest.mark.parametrize('index', [-3, -2, -1, 0, 1, 2])
449+
def test_repeated(index: int):
450+
q0, q1 = cirq.LineQubit.range(2)
451+
circuit = cirq.Circuit(
452+
cirq.measure(q0, key='a'), # The control measurement when `index` is 0 or -2
453+
cirq.X(q0),
454+
cirq.measure(q0, key='a'), # The control measurement when `index` is 1 or -1
455+
cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index)),
456+
cirq.measure(q1, key='b'),
457+
)
458+
if index in [-3, 2]:
459+
with pytest.raises(ValueError, match='Invalid index'):
460+
_ = cirq.defer_measurements(circuit)
461+
return
462+
assert_equivalent_to_deferred(circuit)
463+
deferred = cirq.defer_measurements(circuit)
464+
q_ma = _MeasurementQid('a', q0) # The ancilla qubit created for the first `a` measurement
465+
q_ma1 = _MeasurementQid('a', q0, 1) # The ancilla qubit created for the second `a` measurement
466+
# The ancilla used for control should match the measurement used for control above.
467+
q_expected_control = q_ma if index in [0, -2] else q_ma1
468+
cirq.testing.assert_same_circuits(
469+
deferred,
470+
cirq.Circuit(
471+
cirq.CX(q0, q_ma),
472+
cirq.X(q0),
473+
cirq.CX(q0, q_ma1),
474+
cirq.Moment(cirq.CX(q_expected_control, q1)),
475+
cirq.measure(q_ma, key='a'),
476+
cirq.measure(q_ma1, key='a'),
477+
cirq.measure(q1, key='b'),
478+
),
479+
)
480+
481+
448482
def test_diagram():
449483
q0, q1, q2, q3 = cirq.LineQubit.range(4)
450484
circuit = cirq.Circuit(
@@ -457,23 +491,23 @@ def test_diagram():
457491
cirq.testing.assert_has_diagram(
458492
deferred,
459493
"""
460-
┌────┐
461-
0: ─────────────────@───────X────────M('c')───
462-
│ │
463-
1: ─────────────────┼─@──────────────M────────
464-
│ │ │
465-
2: ─────────────────┼@┼──────────────M────────
466-
│││ │
467-
3: ─────────────────┼┼┼@─────────────M────────
468-
││││
469-
M('a', q=q(0)): ────X┼┼┼────M('a')────────────
470-
│││ │
471-
M('a', q=q(2)): ─────X┼┼────M─────────────────
472-
││
473-
M('b', q=q(1)): ──────X┼────M('b')────────────
474-
│ │
475-
M('b', q=q(3)): ───────X────M─────────────────
476-
└────┘
494+
┌────┐
495+
0: ────────────────────@───────X────────M('c')───
496+
│ │
497+
1: ────────────────────┼─@──────────────M────────
498+
│ │ │
499+
2: ────────────────────┼@┼──────────────M────────
500+
│││ │
501+
3: ────────────────────┼┼┼@─────────────M────────
502+
││││
503+
M('a[0]', q=q(0)): ────X┼┼┼────M('a')────────────
504+
│││ │
505+
M('a[0]', q=q(2)): ─────X┼┼────M─────────────────
506+
││
507+
M('b[0]', q=q(1)): ──────X┼────M('b')────────────
508+
│ │
509+
M('b[0]', q=q(3)): ───────X────M─────────────────
510+
└────┘
477511
""",
478512
use_unicode_characters=True,
479513
)

0 commit comments

Comments
 (0)