Skip to content

Commit 2bff437

Browse files
Measurement confusion maps (#5480)
* Measurement confusion maps * format * mypy+format * Error on deferred confused measure * Also change SimulatesSamples behavior. * Test SimulatesSamples * docstring zero note
1 parent c504c38 commit 2bff437

14 files changed

+372
-59
lines changed

cirq-core/cirq/experiments/readout_confusion_matrix.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def __init__(
7676
the corresponding confusion matrix.
7777
repetitions: The number of repetitions that were used to estimate the confusion
7878
matrices.
79-
timestamp: The time the data was taken, in seconds since the epoch.
79+
timestamp: The time the data was taken, in seconds since the epoch. This will be
80+
zero for fake data (i.e. data not generated from an experiment).
8081
8182
Raises:
8283
ValueError: If length of `confusion_matrices` and `measure_qubits` is different or if
@@ -113,6 +114,34 @@ def __init__(
113114
if sum(len(q) for q in self._measure_qubits) != len(self._qubits):
114115
raise ValueError(f"Repeated qubits not allowed in measure_qubits: {measure_qubits}.")
115116

117+
@classmethod
118+
def from_measurement(
119+
cls, gate: ops.MeasurementGate, qubits: Sequence['cirq.Qid']
120+
) -> 'TensoredConfusionMatrices':
121+
"""Generates TCM for the confusion map in a MeasurementGate.
122+
123+
This ignores any invert_mask defined for the gate - it only replicates the confusion map.
124+
125+
Args:
126+
gate: the MeasurementGate to match.
127+
qubits: qubits the gate is applied to.
128+
129+
Returns:
130+
TensoredConfusionMatrices matching the confusion map of the given gate.
131+
132+
Raises:
133+
ValueError: if the gate has no confusion map.
134+
"""
135+
if not gate.confusion_map:
136+
raise ValueError(f"Measurement has no confusion matrices: {gate}")
137+
confusion_matrices = []
138+
ordered_qubits = []
139+
for indices, cm in gate.confusion_map.items():
140+
confusion_matrices.append(cm)
141+
ordered_qubits.append(tuple(qubits[idx] for idx in indices))
142+
# Use zero for reps/timestamp to mark fake data.
143+
return cls(confusion_matrices, ordered_qubits, repetitions=0, timestamp=0)
144+
116145
@property
117146
def repetitions(self) -> int:
118147
"""The number of repetitions that were used to estimate the confusion matrices."""

cirq-core/cirq/experiments/readout_confusion_matrix_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,27 @@ def l2norm(result: np.ndarray):
8383
assert l2norm(corrected_result) <= l2norm(sampled_result)
8484

8585

86+
def test_from_measurement():
87+
qubits = cirq.LineQubit.range(3)
88+
confuse_02 = np.array([[0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0]])
89+
confuse_1 = np.array([[0, 1], [1, 0]])
90+
op = cirq.measure(
91+
*qubits,
92+
key='a',
93+
invert_mask=(True, False),
94+
confusion_map={(0, 2): confuse_02, (1,): confuse_1},
95+
)
96+
tcm = cirq.TensoredConfusionMatrices.from_measurement(op.gate, op.qubits)
97+
expected_tcm = cirq.TensoredConfusionMatrices(
98+
[confuse_02, confuse_1], ((qubits[0], qubits[2]), (qubits[1],)), repetitions=0, timestamp=0
99+
)
100+
assert tcm == expected_tcm
101+
102+
no_cm_op = cirq.measure(*qubits, key='a')
103+
with pytest.raises(ValueError, match="Measurement has no confusion matrices"):
104+
_ = cirq.TensoredConfusionMatrices.from_measurement(no_cm_op.gate, no_cm_op.qubits)
105+
106+
86107
def test_readout_confusion_matrix_raises():
87108
num_qubits = 2
88109
confusion_matrix = get_expected_cm(num_qubits, 0.1, 0.2)

cirq-core/cirq/ops/measure_util.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Callable, Iterable, List, overload, Optional, Tuple, TYPE_CHECKING, Union
15+
from typing import Callable, Dict, Iterable, List, overload, Optional, Tuple, TYPE_CHECKING, Union
1616

1717
import numpy as np
1818

@@ -107,6 +107,7 @@ def measure(
107107
*target,
108108
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
109109
invert_mask: Tuple[bool, ...] = (),
110+
confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None,
110111
) -> raw_types.Operation:
111112
"""Returns a single MeasurementGate applied to all the given qubits.
112113
@@ -121,6 +122,10 @@ def measure(
121122
invert_mask: A list of Truthy or Falsey values indicating whether
122123
the corresponding qubits should be flipped. None indicates no
123124
inverting should be done.
125+
confusion_map: A map of qubit index sets (using indices in
126+
`target`) to the 2D confusion matrix for those qubits. Indices
127+
not included use the identity. Applied before invert_mask if both
128+
are provided.
124129
125130
Returns:
126131
An operation targeting the given qubits with a measurement.
@@ -146,7 +151,7 @@ def measure(
146151
if key is None:
147152
key = _default_measurement_key(targets)
148153
qid_shape = protocols.qid_shape(targets)
149-
return MeasurementGate(len(targets), key, invert_mask, qid_shape).on(*targets)
154+
return MeasurementGate(len(targets), key, invert_mask, qid_shape, confusion_map).on(*targets)
150155

151156

152157
@overload

cirq-core/cirq/ops/measure_util_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def test_measure_qubits():
4646
assert cirq.measure(cirq.LineQid.for_qid_shape((1, 2, 3)), key='a') == cirq.MeasurementGate(
4747
num_qubits=3, key='a', qid_shape=(1, 2, 3)
4848
).on(*cirq.LineQid.for_qid_shape((1, 2, 3)))
49+
cmap = {(0,): np.array([[0, 1], [1, 0]])}
50+
assert cirq.measure(a, confusion_map=cmap) == cirq.MeasurementGate(
51+
num_qubits=1, key='a', confusion_map=cmap
52+
).on(a)
4953

5054
with pytest.raises(ValueError, match='ndarray'):
5155
_ = cirq.measure(np.array([1, 0]))

cirq-core/cirq/ops/measurement_gate.py

+71-27
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import numpy as np
1818

19-
from cirq import protocols, value
19+
from cirq import _compat, protocols, value
2020
from cirq.ops import raw_types
2121

2222
if TYPE_CHECKING:
@@ -40,6 +40,7 @@ def __init__(
4040
key: Union[str, 'cirq.MeasurementKey'] = '',
4141
invert_mask: Tuple[bool, ...] = (),
4242
qid_shape: Tuple[int, ...] = None,
43+
confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None,
4344
) -> None:
4445
"""Inits MeasurementGate.
4546
@@ -52,10 +53,15 @@ def __init__(
5253
Qubits with indices past the end of the mask are not flipped.
5354
qid_shape: Specifies the dimension of each qid the measurement
5455
applies to. The default is 2 for every qubit.
56+
confusion_map: A map of qubit index sets (using indices in the
57+
operation generated from this gate) to the 2D confusion matrix
58+
for those qubits. Indices not included use the identity.
59+
Applied before invert_mask if both are provided.
5560
5661
Raises:
57-
ValueError: If the length of invert_mask is greater than num_qubits.
58-
or if the length of qid_shape doesn't equal num_qubits.
62+
ValueError: If invert_mask or confusion_map have indices
63+
greater than the available qubit indices, or if the length of
64+
qid_shape doesn't equal num_qubits.
5965
"""
6066
if qid_shape is None:
6167
if num_qubits is None:
@@ -74,6 +80,9 @@ def __init__(
7480
self._invert_mask = invert_mask or ()
7581
if self.invert_mask is not None and len(self.invert_mask) > self.num_qubits():
7682
raise ValueError('len(invert_mask) > num_qubits')
83+
self._confusion_map = confusion_map or {}
84+
if any(x >= self.num_qubits() for idx in self._confusion_map for x in idx):
85+
raise ValueError('Confusion matrices have index out of bounds.')
7786

7887
@property
7988
def key(self) -> str:
@@ -87,6 +96,10 @@ def mkey(self) -> 'cirq.MeasurementKey':
8796
def invert_mask(self) -> Tuple[bool, ...]:
8897
return self._invert_mask
8998

99+
@property
100+
def confusion_map(self) -> Dict[Tuple[int, ...], np.ndarray]:
101+
return self._confusion_map
102+
90103
def _qid_shape_(self) -> Tuple[int, ...]:
91104
return self._qid_shape
92105

@@ -98,7 +111,11 @@ def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate':
98111
if key == self.key:
99112
return self
100113
return MeasurementGate(
101-
self.num_qubits(), key=key, invert_mask=self.invert_mask, qid_shape=self._qid_shape
114+
self.num_qubits(),
115+
key=key,
116+
invert_mask=self.invert_mask,
117+
qid_shape=self._qid_shape,
118+
confusion_map=self.confusion_map,
102119
)
103120

104121
def _with_key_path_(self, path: Tuple[str, ...]):
@@ -116,14 +133,22 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
116133
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))
117134

118135
def with_bits_flipped(self, *bit_positions: int) -> 'MeasurementGate':
119-
"""Toggles whether or not the measurement inverts various outputs."""
136+
"""Toggles whether or not the measurement inverts various outputs.
137+
138+
This only affects the invert_mask, which is applied after confusion
139+
matrices if any are defined.
140+
"""
120141
old_mask = self.invert_mask or ()
121142
n = max(len(old_mask) - 1, *bit_positions) + 1
122143
new_mask = [k < len(old_mask) and old_mask[k] for k in range(n)]
123144
for b in bit_positions:
124145
new_mask[b] = not new_mask[b]
125146
return MeasurementGate(
126-
self.num_qubits(), key=self.key, invert_mask=tuple(new_mask), qid_shape=self._qid_shape
147+
self.num_qubits(),
148+
key=self.key,
149+
invert_mask=tuple(new_mask),
150+
qid_shape=self._qid_shape,
151+
confusion_map=self.confusion_map,
127152
)
128153

129154
def full_invert_mask(self) -> Tuple[bool, ...]:
@@ -166,12 +191,17 @@ def _circuit_diagram_info_(
166191
self, args: 'cirq.CircuitDiagramInfoArgs'
167192
) -> 'cirq.CircuitDiagramInfo':
168193
symbols = ['M'] * self.num_qubits()
169-
170-
# Show which output bits are negated.
171-
if self.invert_mask:
172-
for i, b in enumerate(self.invert_mask):
173-
if b:
174-
symbols[i] = '!M'
194+
flipped_indices = {i for i, x in enumerate(self.full_invert_mask()) if x}
195+
confused_indices = {x for idxs in self.confusion_map for x in idxs}
196+
197+
# Show which output bits are negated and/or confused.
198+
for i in range(self.num_qubits()):
199+
prefix = ''
200+
if i in flipped_indices:
201+
prefix += '!'
202+
if i in confused_indices:
203+
prefix += '?'
204+
symbols[i] = prefix + symbols[i]
175205

176206
# Mention the measurement key.
177207
label_map = args.label_map or {}
@@ -184,7 +214,7 @@ def _circuit_diagram_info_(
184214
return protocols.CircuitDiagramInfo(symbols)
185215

186216
def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
187-
if not all(d == 2 for d in self._qid_shape):
217+
if self.confusion_map or not all(d == 2 for d in self._qid_shape):
188218
return NotImplemented
189219
args.validate_version('2.0')
190220
invert_mask = self.invert_mask
@@ -202,7 +232,7 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
202232
def _quil_(
203233
self, qubits: Tuple['cirq.Qid', ...], formatter: 'cirq.QuilFormatter'
204234
) -> Optional[str]:
205-
if not all(d == 2 for d in self._qid_shape):
235+
if self.confusion_map or not all(d == 2 for d in self._qid_shape):
206236
return NotImplemented
207237
invert_mask = self.invert_mask
208238
if len(invert_mask) < len(qubits):
@@ -222,28 +252,39 @@ def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
222252
args.append(f'key={self.mkey!r}')
223253
if self.invert_mask:
224254
args.append(f'invert_mask={self.invert_mask!r}')
255+
if self.confusion_map:
256+
proper_map_str = ', '.join(
257+
f"{k!r}: {_compat.proper_repr(v)}" for k, v in self.confusion_map.items()
258+
)
259+
args.append(f'confusion_map={{{proper_map_str}}}')
225260
arg_list = ', '.join(args)
226261
return f'cirq.measure({arg_list})'
227262

228263
def __repr__(self):
229-
qid_shape_arg = ''
264+
args = [f'{self.num_qubits()!r}', f'{self.mkey!r}', f'{self.invert_mask}']
230265
if any(d != 2 for d in self._qid_shape):
231-
qid_shape_arg = f', {self._qid_shape!r}'
232-
return (
233-
f'cirq.MeasurementGate('
234-
f'{self.num_qubits()!r}, '
235-
f'{self.mkey!r}, '
236-
f'{self.invert_mask}'
237-
f'{qid_shape_arg})'
238-
)
266+
args.append(f'qid_shape={self._qid_shape!r}')
267+
if self.confusion_map:
268+
proper_map_str = ', '.join(
269+
f"{k!r}: {_compat.proper_repr(v)}" for k, v in self.confusion_map.items()
270+
)
271+
args.append(f'confusion_map={{{proper_map_str}}}')
272+
return f'cirq.MeasurementGate({", ".join(args)})'
239273

240274
def _value_equality_values_(self) -> Any:
241-
return self.key, self.invert_mask, self._qid_shape
275+
hashable_cmap = frozenset(
276+
(idxs, tuple(v for _, v in np.ndenumerate(cmap)))
277+
for idxs, cmap in self._confusion_map.items()
278+
)
279+
return self.key, self.invert_mask, self._qid_shape, hashable_cmap
242280

243281
def _json_dict_(self) -> Dict[str, Any]:
244-
other = {}
282+
other: Dict[str, Any] = {}
245283
if not all(d == 2 for d in self._qid_shape):
246284
other['qid_shape'] = self._qid_shape
285+
if self.confusion_map:
286+
json_cmap = [(k, v.tolist()) for k, v in self.confusion_map.items()]
287+
other['confusion_map'] = json_cmap
247288
return {
248289
'num_qubits': len(self._qid_shape),
249290
'key': self.key,
@@ -252,12 +293,15 @@ def _json_dict_(self) -> Dict[str, Any]:
252293
}
253294

254295
@classmethod
255-
def _from_json_dict_(cls, num_qubits, key, invert_mask, qid_shape=None, **kwargs):
296+
def _from_json_dict_(
297+
cls, num_qubits, key, invert_mask, qid_shape=None, confusion_map=None, **kwargs
298+
):
256299
return cls(
257300
num_qubits=num_qubits,
258301
key=value.MeasurementKey.parse_serialized(key),
259302
invert_mask=tuple(invert_mask),
260303
qid_shape=None if qid_shape is None else tuple(qid_shape),
304+
confusion_map={tuple(k): np.array(v) for k, v in confusion_map or []},
261305
)
262306

263307
def _has_stabilizer_effect_(self) -> Optional[bool]:
@@ -268,7 +312,7 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq
268312

269313
if not isinstance(sim_state, SimulationState):
270314
return NotImplemented
271-
sim_state.measure(qubits, self.key, self.full_invert_mask())
315+
sim_state.measure(qubits, self.key, self.full_invert_mask(), self.confusion_map)
272316
return True
273317

274318

0 commit comments

Comments
 (0)