Skip to content

Commit b4c445f

Browse files
authored
Use MeasurementKey in CircuitOperation (#4086)
* Add the concept of measurement key path and use it for nested/repeated CircuitOperations. Also add `with_key_path` protocol. * Format and docstrings * json and other fixes * Change to immutable default for pylint * Make full_join_string_lists private
1 parent 6f9eedc commit b4c445f

26 files changed

+620
-135
lines changed

cirq-core/cirq/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@
430430
Duration,
431431
DURATION_LIKE,
432432
LinearDict,
433+
MEASUREMENT_KEY_SEPARATOR,
433434
MeasurementKey,
434435
PeriodicValue,
435436
RANDOM_STATE_OR_SEED_LIKE,
@@ -531,6 +532,7 @@
531532
trace_distance_from_angle_list,
532533
unitary,
533534
validate_mixture,
535+
with_key_path,
534536
with_measurement_key_mapping,
535537
)
536538

cirq-core/cirq/circuits/circuit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,11 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
897897
[protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments]
898898
)
899899

900+
def _with_key_path_(self, path: Tuple[str, ...]):
901+
return self._with_sliced_moments(
902+
[protocols.with_key_path(moment, path) for moment in self.moments]
903+
)
904+
900905
def _qid_shape_(self) -> Tuple[int, ...]:
901906
return self.qid_shape()
902907

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 41 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@
2323
import dataclasses
2424
import numpy as np
2525

26-
from cirq import circuits, ops, protocols, study
26+
from cirq import circuits, ops, protocols, value, study
2727
from cirq._compat import proper_repr
2828

2929
if TYPE_CHECKING:
3030
import cirq
3131

3232

3333
INT_TYPE = Union[int, np.integer]
34-
MEASUREMENT_KEY_SEPARATOR = ':'
34+
REPETITION_ID_SEPARATOR = '-'
3535

3636

3737
def default_repetition_ids(repetitions: int) -> Optional[List[str]]:
@@ -40,40 +40,18 @@ def default_repetition_ids(repetitions: int) -> Optional[List[str]]:
4040
return None
4141

4242

43-
def cartesian_product_of_string_lists(list1: Optional[List[str]], list2: Optional[List[str]]):
43+
def _full_join_string_lists(list1: Optional[List[str]], list2: Optional[List[str]]):
4444
if list1 is None and list2 is None:
4545
return None # coverage: ignore
4646
if list1 is None:
4747
return list2 # coverage: ignore
4848
if list2 is None:
4949
return list1
5050
return [
51-
f'{MEASUREMENT_KEY_SEPARATOR.join([first, second])}' for first in list1 for second in list2
51+
f'{REPETITION_ID_SEPARATOR.join([first, second])}' for first in list1 for second in list2
5252
]
5353

5454

55-
def split_maybe_indexed_key(maybe_indexed_key: str) -> List[str]:
56-
"""Given a measurement_key, splits into index (series of repetition_ids) and unindexed key
57-
parts. For a key without index, returns the unaltered key in a list. Assumes that the
58-
unindexed measurement key does not contain the MEASUREMENT_KEY_SEPARATOR. This is validated by
59-
the `CircuitOperation` constructor."""
60-
return maybe_indexed_key.rsplit(MEASUREMENT_KEY_SEPARATOR, maxsplit=1)
61-
62-
63-
def get_unindexed_key(maybe_indexed_key: str) -> str:
64-
"""Given a measurement_key, returns the unindexed key part (without the series of prefixed
65-
repetition_ids). For an already unindexed key, returns the unaltered key."""
66-
return split_maybe_indexed_key(maybe_indexed_key)[-1]
67-
68-
69-
def remap_maybe_indexed_key(key_map: Dict[str, str], key: str) -> str:
70-
"""Given a key map and a measurement_key (indexed or unindexed), returns the remapped key in
71-
the same format. Does not modify the index (series of repetition_ids) part, if it exists."""
72-
split_key = split_maybe_indexed_key(key)
73-
split_key[-1] = key_map.get(split_key[-1], split_key[-1])
74-
return MEASUREMENT_KEY_SEPARATOR.join(split_key)
75-
76-
7755
@dataclasses.dataclass(frozen=True)
7856
class CircuitOperation(ops.Operation):
7957
"""An operation that encapsulates a circuit.
@@ -90,6 +68,7 @@ class CircuitOperation(ops.Operation):
9068
The keys and values should be unindexed (i.e. without repetition_ids).
9169
The values cannot contain the `MEASUREMENT_KEY_SEPARATOR`.
9270
param_resolver: Resolved values for parameters in the circuit.
71+
parent_path: A tuple of identifiers for any parent CircuitOperations containing this one.
9372
repetition_ids: List of identifiers for each repetition of the
9473
CircuitOperation. If populated, the length should be equal to the
9574
repetitions. If not populated and abs(`repetitions`) > 1, it is
@@ -104,6 +83,7 @@ class CircuitOperation(ops.Operation):
10483
measurement_key_map: Dict[str, str] = dataclasses.field(default_factory=dict)
10584
param_resolver: study.ParamResolver = study.ParamResolver()
10685
repetition_ids: Optional[List[str]] = dataclasses.field(default=None)
86+
parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
10787

10888
def __post_init__(self):
10989
if not isinstance(self.circuit, circuits.FrozenCircuit):
@@ -128,27 +108,12 @@ def __post_init__(self):
128108

129109
# Disallow mapping to keys containing the `MEASUREMENT_KEY_SEPARATOR`
130110
for mapped_key in self.measurement_key_map.values():
131-
if MEASUREMENT_KEY_SEPARATOR in mapped_key:
111+
if value.MEASUREMENT_KEY_SEPARATOR in mapped_key:
132112
raise ValueError(
133-
f'Mapping to invalid key: {mapped_key}. "{MEASUREMENT_KEY_SEPARATOR}" '
113+
f'Mapping to invalid key: {mapped_key}. "{value.MEASUREMENT_KEY_SEPARATOR}" '
134114
'is not allowed for measurement keys in a CircuitOperation'
135115
)
136116

137-
# Validate the keys for all direct child measurements. They are not allowed to contain
138-
# `MEASUREMENT_KEY_SEPARATOR`
139-
for _, op in self.circuit.findall_operations(
140-
lambda op: not isinstance(op, CircuitOperation) and protocols.is_measurement(op)
141-
):
142-
for key in protocols.measurement_keys(op):
143-
key = self.measurement_key_map.get(key, key)
144-
if MEASUREMENT_KEY_SEPARATOR in key:
145-
raise ValueError(
146-
f'Measurement {op} found to have invalid key: {key}. '
147-
f'"{MEASUREMENT_KEY_SEPARATOR}" is not allowed for measurement keys '
148-
'in a CircuitOperation. Consider remapping the key using '
149-
'`measurement_key_map` in the CircuitOperation constructor.'
150-
)
151-
152117
# Disallow qid mapping dimension conflicts.
153118
for q, q_new in self.qubit_map.items():
154119
if q_new.dimension != q.dimension:
@@ -178,6 +143,7 @@ def __eq__(self, other) -> bool:
178143
and self.param_resolver == other.param_resolver
179144
and self.repetitions == other.repetitions
180145
and self.repetition_ids == other.repetition_ids
146+
and self.parent_path == other.parent_path
181147
)
182148

183149
# Methods for getting post-mapping properties of the contained circuit.
@@ -195,12 +161,20 @@ def _qid_shape_(self) -> Tuple[int, ...]:
195161
return tuple(q.dimension for q in self.qubits)
196162

197163
def _measurement_keys_(self) -> AbstractSet[str]:
198-
circuit_keys = self.circuit.all_measurement_keys()
164+
circuit_keys = [
165+
value.MeasurementKey.parse_serialized(key_str)
166+
for key_str in self.circuit.all_measurement_keys()
167+
]
199168
if self.repetition_ids is not None:
200-
circuit_keys = cartesian_product_of_string_lists(
201-
self.repetition_ids, list(circuit_keys)
202-
)
203-
return {remap_maybe_indexed_key(self.measurement_key_map, key) for key in circuit_keys}
169+
circuit_keys = [
170+
key.with_key_path_prefix(repetition_id)
171+
for repetition_id in self.repetition_ids
172+
for key in circuit_keys
173+
]
174+
return {
175+
str(protocols.with_measurement_key_mapping(key, self.measurement_key_map))
176+
for key in circuit_keys
177+
}
204178

205179
def _parameter_names_(self) -> AbstractSet[str]:
206180
return {
@@ -225,32 +199,9 @@ def _decompose_(self) -> 'cirq.OP_TREE':
225199
# If it's a measurement circuit with repetitions/repetition_ids, prefix the repetition_ids
226200
# to measurements. Details at https://tinyurl.com/measurement-repeated-circuitop.
227201
ops = [] # type: List[cirq.Operation]
228-
for parent_id in self.repetition_ids:
229-
for op in result.all_operations():
230-
if isinstance(op, CircuitOperation):
231-
# For a CircuitOperation, prefix the current repetition_id to the children
232-
# repetition_ids.
233-
ops.append(
234-
op.with_repetition_ids(
235-
# If `op.repetition_ids` is None, this will return `[parent_id]`.
236-
cartesian_product_of_string_lists([parent_id], op.repetition_ids)
237-
)
238-
)
239-
elif protocols.is_measurement(op):
240-
# For a non-CircuitOperation measurement, prefix the current repetition_id
241-
# to the children measurement keys. Implemented by creating a mapping and
242-
# using the with_measurement_key_mapping protocol.
243-
ops.append(
244-
protocols.with_measurement_key_mapping(
245-
op,
246-
key_map={
247-
key: f'{MEASUREMENT_KEY_SEPARATOR.join([parent_id, key])}'
248-
for key in protocols.measurement_keys(op)
249-
},
250-
)
251-
)
252-
else:
253-
ops.append(op)
202+
for repetition_id in self.repetition_ids:
203+
path = self.parent_path + (repetition_id,)
204+
ops += protocols.with_key_path(result, path).all_operations()
254205
return ops
255206

256207
# Methods for string representation of the operation.
@@ -265,6 +216,8 @@ def __repr__(self):
265216
args += f'measurement_key_map={proper_repr(self.measurement_key_map)},\n'
266217
if self.param_resolver:
267218
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
219+
if self.parent_path:
220+
args += f'parent_path={proper_repr(self.parent_path)},\n'
268221
if self.repetition_ids != self._default_repetition_ids():
269222
# Default repetition_ids need not be specified.
270223
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
@@ -291,6 +244,8 @@ def dict_str(d: Dict) -> str:
291244
args.append(f'key_map={dict_str(self.measurement_key_map)}')
292245
if self.param_resolver:
293246
args.append(f'params={self.param_resolver.param_dict}')
247+
if self.parent_path:
248+
args.append(f'parent_path={self.parent_path}')
294249
if self.repetition_ids != self._default_repetition_ids():
295250
# Default repetition_ids need not be specified.
296251
args.append(f'repetition_ids={self.repetition_ids}')
@@ -313,6 +268,7 @@ def __hash__(self):
313268
frozenset(self.qubit_map.items()),
314269
frozenset(self.measurement_key_map.items()),
315270
self.param_resolver,
271+
self.parent_path,
316272
tuple([] if self.repetition_ids is None else self.repetition_ids),
317273
)
318274
),
@@ -330,6 +286,7 @@ def _json_dict_(self):
330286
'measurement_key_map': self.measurement_key_map,
331287
'param_resolver': self.param_resolver,
332288
'repetition_ids': self.repetition_ids,
289+
'parent_path': self.parent_path,
333290
}
334291

335292
@classmethod
@@ -341,13 +298,15 @@ def _from_json_dict_(
341298
measurement_key_map,
342299
param_resolver,
343300
repetition_ids,
301+
parent_path=(),
344302
**kwargs,
345303
):
346304
return (
347305
cls(circuit)
348306
.with_qubit_mapping(dict(qubit_map))
349307
.with_measurement_key_mapping(measurement_key_map)
350308
.with_params(param_resolver)
309+
.with_key_path(tuple(parent_path))
351310
.repeat(repetitions, repetition_ids)
352311
)
353312

@@ -408,13 +367,19 @@ def repeat(
408367
)
409368

410369
# If `self.repetition_ids` is None, this will just return `repetition_ids`.
411-
repetition_ids = cartesian_product_of_string_lists(repetition_ids, self.repetition_ids)
370+
repetition_ids = _full_join_string_lists(repetition_ids, self.repetition_ids)
412371

413372
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)
414373

415374
def __pow__(self, power: int) -> 'CircuitOperation':
416375
return self.repeat(power)
417376

377+
def _with_key_path_(self, path: Tuple[str, ...]):
378+
return dataclasses.replace(self, parent_path=path)
379+
380+
def with_key_path(self, path: Tuple[str, ...]):
381+
return self._with_key_path_(path)
382+
418383
def with_repetition_ids(self, repetition_ids: List[str]) -> 'CircuitOperation':
419384
return self.replace(repetition_ids=repetition_ids)
420385

@@ -501,7 +466,7 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
501466
"""
502467
new_map = {}
503468
for k in self.circuit.all_measurement_keys():
504-
k = get_unindexed_key(k)
469+
k = value.MeasurementKey.parse_serialized(k).name
505470
k_new = self.measurement_key_map.get(k, k)
506471
k_new = key_map.get(k_new, k_new)
507472
if k_new != k:

0 commit comments

Comments
 (0)