23
23
import dataclasses
24
24
import numpy as np
25
25
26
- from cirq import circuits , ops , protocols , study
26
+ from cirq import circuits , ops , protocols , value , study
27
27
from cirq ._compat import proper_repr
28
28
29
29
if TYPE_CHECKING :
30
30
import cirq
31
31
32
32
33
33
INT_TYPE = Union [int , np .integer ]
34
- MEASUREMENT_KEY_SEPARATOR = ': '
34
+ REPETITION_ID_SEPARATOR = '- '
35
35
36
36
37
37
def default_repetition_ids (repetitions : int ) -> Optional [List [str ]]:
@@ -40,40 +40,18 @@ def default_repetition_ids(repetitions: int) -> Optional[List[str]]:
40
40
return None
41
41
42
42
43
- def cartesian_product_of_string_lists (list1 : Optional [List [str ]], list2 : Optional [List [str ]]):
43
+ def _full_join_string_lists (list1 : Optional [List [str ]], list2 : Optional [List [str ]]):
44
44
if list1 is None and list2 is None :
45
45
return None # coverage: ignore
46
46
if list1 is None :
47
47
return list2 # coverage: ignore
48
48
if list2 is None :
49
49
return list1
50
50
return [
51
- f'{ MEASUREMENT_KEY_SEPARATOR .join ([first , second ])} ' for first in list1 for second in list2
51
+ f'{ REPETITION_ID_SEPARATOR .join ([first , second ])} ' for first in list1 for second in list2
52
52
]
53
53
54
54
55
- def split_maybe_indexed_key (maybe_indexed_key : str ) -> List [str ]:
56
- """Given a measurement_key, splits into index (series of repetition_ids) and unindexed key
57
- parts. For a key without index, returns the unaltered key in a list. Assumes that the
58
- unindexed measurement key does not contain the MEASUREMENT_KEY_SEPARATOR. This is validated by
59
- the `CircuitOperation` constructor."""
60
- return maybe_indexed_key .rsplit (MEASUREMENT_KEY_SEPARATOR , maxsplit = 1 )
61
-
62
-
63
- def get_unindexed_key (maybe_indexed_key : str ) -> str :
64
- """Given a measurement_key, returns the unindexed key part (without the series of prefixed
65
- repetition_ids). For an already unindexed key, returns the unaltered key."""
66
- return split_maybe_indexed_key (maybe_indexed_key )[- 1 ]
67
-
68
-
69
- def remap_maybe_indexed_key (key_map : Dict [str , str ], key : str ) -> str :
70
- """Given a key map and a measurement_key (indexed or unindexed), returns the remapped key in
71
- the same format. Does not modify the index (series of repetition_ids) part, if it exists."""
72
- split_key = split_maybe_indexed_key (key )
73
- split_key [- 1 ] = key_map .get (split_key [- 1 ], split_key [- 1 ])
74
- return MEASUREMENT_KEY_SEPARATOR .join (split_key )
75
-
76
-
77
55
@dataclasses .dataclass (frozen = True )
78
56
class CircuitOperation (ops .Operation ):
79
57
"""An operation that encapsulates a circuit.
@@ -90,6 +68,7 @@ class CircuitOperation(ops.Operation):
90
68
The keys and values should be unindexed (i.e. without repetition_ids).
91
69
The values cannot contain the `MEASUREMENT_KEY_SEPARATOR`.
92
70
param_resolver: Resolved values for parameters in the circuit.
71
+ parent_path: A tuple of identifiers for any parent CircuitOperations containing this one.
93
72
repetition_ids: List of identifiers for each repetition of the
94
73
CircuitOperation. If populated, the length should be equal to the
95
74
repetitions. If not populated and abs(`repetitions`) > 1, it is
@@ -104,6 +83,7 @@ class CircuitOperation(ops.Operation):
104
83
measurement_key_map : Dict [str , str ] = dataclasses .field (default_factory = dict )
105
84
param_resolver : study .ParamResolver = study .ParamResolver ()
106
85
repetition_ids : Optional [List [str ]] = dataclasses .field (default = None )
86
+ parent_path : Tuple [str , ...] = dataclasses .field (default_factory = tuple )
107
87
108
88
def __post_init__ (self ):
109
89
if not isinstance (self .circuit , circuits .FrozenCircuit ):
@@ -128,27 +108,12 @@ def __post_init__(self):
128
108
129
109
# Disallow mapping to keys containing the `MEASUREMENT_KEY_SEPARATOR`
130
110
for mapped_key in self .measurement_key_map .values ():
131
- if MEASUREMENT_KEY_SEPARATOR in mapped_key :
111
+ if value . MEASUREMENT_KEY_SEPARATOR in mapped_key :
132
112
raise ValueError (
133
- f'Mapping to invalid key: { mapped_key } . "{ MEASUREMENT_KEY_SEPARATOR } " '
113
+ f'Mapping to invalid key: { mapped_key } . "{ value . MEASUREMENT_KEY_SEPARATOR } " '
134
114
'is not allowed for measurement keys in a CircuitOperation'
135
115
)
136
116
137
- # Validate the keys for all direct child measurements. They are not allowed to contain
138
- # `MEASUREMENT_KEY_SEPARATOR`
139
- for _ , op in self .circuit .findall_operations (
140
- lambda op : not isinstance (op , CircuitOperation ) and protocols .is_measurement (op )
141
- ):
142
- for key in protocols .measurement_keys (op ):
143
- key = self .measurement_key_map .get (key , key )
144
- if MEASUREMENT_KEY_SEPARATOR in key :
145
- raise ValueError (
146
- f'Measurement { op } found to have invalid key: { key } . '
147
- f'"{ MEASUREMENT_KEY_SEPARATOR } " is not allowed for measurement keys '
148
- 'in a CircuitOperation. Consider remapping the key using '
149
- '`measurement_key_map` in the CircuitOperation constructor.'
150
- )
151
-
152
117
# Disallow qid mapping dimension conflicts.
153
118
for q , q_new in self .qubit_map .items ():
154
119
if q_new .dimension != q .dimension :
@@ -178,6 +143,7 @@ def __eq__(self, other) -> bool:
178
143
and self .param_resolver == other .param_resolver
179
144
and self .repetitions == other .repetitions
180
145
and self .repetition_ids == other .repetition_ids
146
+ and self .parent_path == other .parent_path
181
147
)
182
148
183
149
# Methods for getting post-mapping properties of the contained circuit.
@@ -195,12 +161,20 @@ def _qid_shape_(self) -> Tuple[int, ...]:
195
161
return tuple (q .dimension for q in self .qubits )
196
162
197
163
def _measurement_keys_ (self ) -> AbstractSet [str ]:
198
- circuit_keys = self .circuit .all_measurement_keys ()
164
+ circuit_keys = [
165
+ value .MeasurementKey .parse_serialized (key_str )
166
+ for key_str in self .circuit .all_measurement_keys ()
167
+ ]
199
168
if self .repetition_ids is not None :
200
- circuit_keys = cartesian_product_of_string_lists (
201
- self .repetition_ids , list (circuit_keys )
202
- )
203
- return {remap_maybe_indexed_key (self .measurement_key_map , key ) for key in circuit_keys }
169
+ circuit_keys = [
170
+ key .with_key_path_prefix (repetition_id )
171
+ for repetition_id in self .repetition_ids
172
+ for key in circuit_keys
173
+ ]
174
+ return {
175
+ str (protocols .with_measurement_key_mapping (key , self .measurement_key_map ))
176
+ for key in circuit_keys
177
+ }
204
178
205
179
def _parameter_names_ (self ) -> AbstractSet [str ]:
206
180
return {
@@ -225,32 +199,9 @@ def _decompose_(self) -> 'cirq.OP_TREE':
225
199
# If it's a measurement circuit with repetitions/repetition_ids, prefix the repetition_ids
226
200
# to measurements. Details at https://tinyurl.com/measurement-repeated-circuitop.
227
201
ops = [] # type: List[cirq.Operation]
228
- for parent_id in self .repetition_ids :
229
- for op in result .all_operations ():
230
- if isinstance (op , CircuitOperation ):
231
- # For a CircuitOperation, prefix the current repetition_id to the children
232
- # repetition_ids.
233
- ops .append (
234
- op .with_repetition_ids (
235
- # If `op.repetition_ids` is None, this will return `[parent_id]`.
236
- cartesian_product_of_string_lists ([parent_id ], op .repetition_ids )
237
- )
238
- )
239
- elif protocols .is_measurement (op ):
240
- # For a non-CircuitOperation measurement, prefix the current repetition_id
241
- # to the children measurement keys. Implemented by creating a mapping and
242
- # using the with_measurement_key_mapping protocol.
243
- ops .append (
244
- protocols .with_measurement_key_mapping (
245
- op ,
246
- key_map = {
247
- key : f'{ MEASUREMENT_KEY_SEPARATOR .join ([parent_id , key ])} '
248
- for key in protocols .measurement_keys (op )
249
- },
250
- )
251
- )
252
- else :
253
- ops .append (op )
202
+ for repetition_id in self .repetition_ids :
203
+ path = self .parent_path + (repetition_id ,)
204
+ ops += protocols .with_key_path (result , path ).all_operations ()
254
205
return ops
255
206
256
207
# Methods for string representation of the operation.
@@ -265,6 +216,8 @@ def __repr__(self):
265
216
args += f'measurement_key_map={ proper_repr (self .measurement_key_map )} ,\n '
266
217
if self .param_resolver :
267
218
args += f'param_resolver={ proper_repr (self .param_resolver )} ,\n '
219
+ if self .parent_path :
220
+ args += f'parent_path={ proper_repr (self .parent_path )} ,\n '
268
221
if self .repetition_ids != self ._default_repetition_ids ():
269
222
# Default repetition_ids need not be specified.
270
223
args += f'repetition_ids={ proper_repr (self .repetition_ids )} ,\n '
@@ -291,6 +244,8 @@ def dict_str(d: Dict) -> str:
291
244
args .append (f'key_map={ dict_str (self .measurement_key_map )} ' )
292
245
if self .param_resolver :
293
246
args .append (f'params={ self .param_resolver .param_dict } ' )
247
+ if self .parent_path :
248
+ args .append (f'parent_path={ self .parent_path } ' )
294
249
if self .repetition_ids != self ._default_repetition_ids ():
295
250
# Default repetition_ids need not be specified.
296
251
args .append (f'repetition_ids={ self .repetition_ids } ' )
@@ -313,6 +268,7 @@ def __hash__(self):
313
268
frozenset (self .qubit_map .items ()),
314
269
frozenset (self .measurement_key_map .items ()),
315
270
self .param_resolver ,
271
+ self .parent_path ,
316
272
tuple ([] if self .repetition_ids is None else self .repetition_ids ),
317
273
)
318
274
),
@@ -330,6 +286,7 @@ def _json_dict_(self):
330
286
'measurement_key_map' : self .measurement_key_map ,
331
287
'param_resolver' : self .param_resolver ,
332
288
'repetition_ids' : self .repetition_ids ,
289
+ 'parent_path' : self .parent_path ,
333
290
}
334
291
335
292
@classmethod
@@ -341,13 +298,15 @@ def _from_json_dict_(
341
298
measurement_key_map ,
342
299
param_resolver ,
343
300
repetition_ids ,
301
+ parent_path = (),
344
302
** kwargs ,
345
303
):
346
304
return (
347
305
cls (circuit )
348
306
.with_qubit_mapping (dict (qubit_map ))
349
307
.with_measurement_key_mapping (measurement_key_map )
350
308
.with_params (param_resolver )
309
+ .with_key_path (tuple (parent_path ))
351
310
.repeat (repetitions , repetition_ids )
352
311
)
353
312
@@ -408,13 +367,19 @@ def repeat(
408
367
)
409
368
410
369
# If `self.repetition_ids` is None, this will just return `repetition_ids`.
411
- repetition_ids = cartesian_product_of_string_lists (repetition_ids , self .repetition_ids )
370
+ repetition_ids = _full_join_string_lists (repetition_ids , self .repetition_ids )
412
371
413
372
return self .replace (repetitions = final_repetitions , repetition_ids = repetition_ids )
414
373
415
374
def __pow__ (self , power : int ) -> 'CircuitOperation' :
416
375
return self .repeat (power )
417
376
377
+ def _with_key_path_ (self , path : Tuple [str , ...]):
378
+ return dataclasses .replace (self , parent_path = path )
379
+
380
+ def with_key_path (self , path : Tuple [str , ...]):
381
+ return self ._with_key_path_ (path )
382
+
418
383
def with_repetition_ids (self , repetition_ids : List [str ]) -> 'CircuitOperation' :
419
384
return self .replace (repetition_ids = repetition_ids )
420
385
@@ -501,7 +466,7 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
501
466
"""
502
467
new_map = {}
503
468
for k in self .circuit .all_measurement_keys ():
504
- k = get_unindexed_key (k )
469
+ k = value . MeasurementKey . parse_serialized (k ). name
505
470
k_new = self .measurement_key_map .get (k , k )
506
471
k_new = key_map .get (k_new , k_new )
507
472
if k_new != k :
0 commit comments