16
16
17
17
import numpy as np
18
18
19
- from cirq import protocols , value
19
+ from cirq import _compat , protocols , value
20
20
from cirq .ops import raw_types
21
21
22
22
if TYPE_CHECKING :
@@ -40,6 +40,7 @@ def __init__(
40
40
key : Union [str , 'cirq.MeasurementKey' ] = '' ,
41
41
invert_mask : Tuple [bool , ...] = (),
42
42
qid_shape : Tuple [int , ...] = None ,
43
+ confusion_map : Optional [Dict [Tuple [int , ...], np .ndarray ]] = None ,
43
44
) -> None :
44
45
"""Inits MeasurementGate.
45
46
@@ -52,10 +53,15 @@ def __init__(
52
53
Qubits with indices past the end of the mask are not flipped.
53
54
qid_shape: Specifies the dimension of each qid the measurement
54
55
applies to. The default is 2 for every qubit.
56
+ confusion_map: A map of qubit index sets (using indices in the
57
+ operation generated from this gate) to the 2D confusion matrix
58
+ for those qubits. Indices not included use the identity.
59
+ Applied before invert_mask if both are provided.
55
60
56
61
Raises:
57
- ValueError: If the length of invert_mask is greater than num_qubits.
58
- or if the length of qid_shape doesn't equal num_qubits.
62
+ ValueError: If invert_mask or confusion_map have indices
63
+ greater than the available qubit indices, or if the length of
64
+ qid_shape doesn't equal num_qubits.
59
65
"""
60
66
if qid_shape is None :
61
67
if num_qubits is None :
@@ -74,6 +80,9 @@ def __init__(
74
80
self ._invert_mask = invert_mask or ()
75
81
if self .invert_mask is not None and len (self .invert_mask ) > self .num_qubits ():
76
82
raise ValueError ('len(invert_mask) > num_qubits' )
83
+ self ._confusion_map = confusion_map or {}
84
+ if any (x >= self .num_qubits () for idx in self ._confusion_map for x in idx ):
85
+ raise ValueError ('Confusion matrices have index out of bounds.' )
77
86
78
87
@property
79
88
def key (self ) -> str :
@@ -87,6 +96,10 @@ def mkey(self) -> 'cirq.MeasurementKey':
87
96
def invert_mask (self ) -> Tuple [bool , ...]:
88
97
return self ._invert_mask
89
98
99
+ @property
100
+ def confusion_map (self ) -> Dict [Tuple [int , ...], np .ndarray ]:
101
+ return self ._confusion_map
102
+
90
103
def _qid_shape_ (self ) -> Tuple [int , ...]:
91
104
return self ._qid_shape
92
105
@@ -98,7 +111,11 @@ def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate':
98
111
if key == self .key :
99
112
return self
100
113
return MeasurementGate (
101
- self .num_qubits (), key = key , invert_mask = self .invert_mask , qid_shape = self ._qid_shape
114
+ self .num_qubits (),
115
+ key = key ,
116
+ invert_mask = self .invert_mask ,
117
+ qid_shape = self ._qid_shape ,
118
+ confusion_map = self .confusion_map ,
102
119
)
103
120
104
121
def _with_key_path_ (self , path : Tuple [str , ...]):
@@ -116,14 +133,22 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
116
133
return self .with_key (protocols .with_measurement_key_mapping (self .mkey , key_map ))
117
134
118
135
def with_bits_flipped (self , * bit_positions : int ) -> 'MeasurementGate' :
119
- """Toggles whether or not the measurement inverts various outputs."""
136
+ """Toggles whether or not the measurement inverts various outputs.
137
+
138
+ This only affects the invert_mask, which is applied after confusion
139
+ matrices if any are defined.
140
+ """
120
141
old_mask = self .invert_mask or ()
121
142
n = max (len (old_mask ) - 1 , * bit_positions ) + 1
122
143
new_mask = [k < len (old_mask ) and old_mask [k ] for k in range (n )]
123
144
for b in bit_positions :
124
145
new_mask [b ] = not new_mask [b ]
125
146
return MeasurementGate (
126
- self .num_qubits (), key = self .key , invert_mask = tuple (new_mask ), qid_shape = self ._qid_shape
147
+ self .num_qubits (),
148
+ key = self .key ,
149
+ invert_mask = tuple (new_mask ),
150
+ qid_shape = self ._qid_shape ,
151
+ confusion_map = self .confusion_map ,
127
152
)
128
153
129
154
def full_invert_mask (self ) -> Tuple [bool , ...]:
@@ -166,12 +191,17 @@ def _circuit_diagram_info_(
166
191
self , args : 'cirq.CircuitDiagramInfoArgs'
167
192
) -> 'cirq.CircuitDiagramInfo' :
168
193
symbols = ['M' ] * self .num_qubits ()
169
-
170
- # Show which output bits are negated.
171
- if self .invert_mask :
172
- for i , b in enumerate (self .invert_mask ):
173
- if b :
174
- symbols [i ] = '!M'
194
+ flipped_indices = {i for i , x in enumerate (self .full_invert_mask ()) if x }
195
+ confused_indices = {x for idxs in self .confusion_map for x in idxs }
196
+
197
+ # Show which output bits are negated and/or confused.
198
+ for i in range (self .num_qubits ()):
199
+ prefix = ''
200
+ if i in flipped_indices :
201
+ prefix += '!'
202
+ if i in confused_indices :
203
+ prefix += '?'
204
+ symbols [i ] = prefix + symbols [i ]
175
205
176
206
# Mention the measurement key.
177
207
label_map = args .label_map or {}
@@ -184,7 +214,7 @@ def _circuit_diagram_info_(
184
214
return protocols .CircuitDiagramInfo (symbols )
185
215
186
216
def _qasm_ (self , args : 'cirq.QasmArgs' , qubits : Tuple ['cirq.Qid' , ...]) -> Optional [str ]:
187
- if not all (d == 2 for d in self ._qid_shape ):
217
+ if self . confusion_map or not all (d == 2 for d in self ._qid_shape ):
188
218
return NotImplemented
189
219
args .validate_version ('2.0' )
190
220
invert_mask = self .invert_mask
@@ -202,7 +232,7 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
202
232
def _quil_ (
203
233
self , qubits : Tuple ['cirq.Qid' , ...], formatter : 'cirq.QuilFormatter'
204
234
) -> Optional [str ]:
205
- if not all (d == 2 for d in self ._qid_shape ):
235
+ if self . confusion_map or not all (d == 2 for d in self ._qid_shape ):
206
236
return NotImplemented
207
237
invert_mask = self .invert_mask
208
238
if len (invert_mask ) < len (qubits ):
@@ -222,28 +252,39 @@ def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
222
252
args .append (f'key={ self .mkey !r} ' )
223
253
if self .invert_mask :
224
254
args .append (f'invert_mask={ self .invert_mask !r} ' )
255
+ if self .confusion_map :
256
+ proper_map_str = ', ' .join (
257
+ f"{ k !r} : { _compat .proper_repr (v )} " for k , v in self .confusion_map .items ()
258
+ )
259
+ args .append (f'confusion_map={{{ proper_map_str } }}' )
225
260
arg_list = ', ' .join (args )
226
261
return f'cirq.measure({ arg_list } )'
227
262
228
263
def __repr__ (self ):
229
- qid_shape_arg = ''
264
+ args = [ f' { self . num_qubits ()!r } ' , f' { self . mkey !r } ' , f' { self . invert_mask } ' ]
230
265
if any (d != 2 for d in self ._qid_shape ):
231
- qid_shape_arg = f', { self ._qid_shape !r} '
232
- return (
233
- f'cirq.MeasurementGate('
234
- f'{ self .num_qubits ()!r} , '
235
- f'{ self .mkey !r} , '
236
- f'{ self .invert_mask } '
237
- f'{ qid_shape_arg } )'
238
- )
266
+ args .append (f'qid_shape={ self ._qid_shape !r} ' )
267
+ if self .confusion_map :
268
+ proper_map_str = ', ' .join (
269
+ f"{ k !r} : { _compat .proper_repr (v )} " for k , v in self .confusion_map .items ()
270
+ )
271
+ args .append (f'confusion_map={{{ proper_map_str } }}' )
272
+ return f'cirq.MeasurementGate({ ", " .join (args )} )'
239
273
240
274
def _value_equality_values_ (self ) -> Any :
241
- return self .key , self .invert_mask , self ._qid_shape
275
+ hashable_cmap = frozenset (
276
+ (idxs , tuple (v for _ , v in np .ndenumerate (cmap )))
277
+ for idxs , cmap in self ._confusion_map .items ()
278
+ )
279
+ return self .key , self .invert_mask , self ._qid_shape , hashable_cmap
242
280
243
281
def _json_dict_ (self ) -> Dict [str , Any ]:
244
- other = {}
282
+ other : Dict [ str , Any ] = {}
245
283
if not all (d == 2 for d in self ._qid_shape ):
246
284
other ['qid_shape' ] = self ._qid_shape
285
+ if self .confusion_map :
286
+ json_cmap = [(k , v .tolist ()) for k , v in self .confusion_map .items ()]
287
+ other ['confusion_map' ] = json_cmap
247
288
return {
248
289
'num_qubits' : len (self ._qid_shape ),
249
290
'key' : self .key ,
@@ -252,12 +293,15 @@ def _json_dict_(self) -> Dict[str, Any]:
252
293
}
253
294
254
295
@classmethod
255
- def _from_json_dict_ (cls , num_qubits , key , invert_mask , qid_shape = None , ** kwargs ):
296
+ def _from_json_dict_ (
297
+ cls , num_qubits , key , invert_mask , qid_shape = None , confusion_map = None , ** kwargs
298
+ ):
256
299
return cls (
257
300
num_qubits = num_qubits ,
258
301
key = value .MeasurementKey .parse_serialized (key ),
259
302
invert_mask = tuple (invert_mask ),
260
303
qid_shape = None if qid_shape is None else tuple (qid_shape ),
304
+ confusion_map = {tuple (k ): np .array (v ) for k , v in confusion_map or []},
261
305
)
262
306
263
307
def _has_stabilizer_effect_ (self ) -> Optional [bool ]:
@@ -268,7 +312,7 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq
268
312
269
313
if not isinstance (sim_state , SimulationState ):
270
314
return NotImplemented
271
- sim_state .measure (qubits , self .key , self .full_invert_mask ())
315
+ sim_state .measure (qubits , self .key , self .full_invert_mask (), self . confusion_map )
272
316
return True
273
317
274
318
0 commit comments