Skip to content

Commit 451b2f1

Browse files
authored
Repetition IDs for CircuitOperation (#3741)
* Repetition IDs in CircuitOperation * Small cleanup * address some comments * Make repetitions optional and add override_all_ids arg * remove rogue newline * simplify * rep_id propagation even for non-measurement COp * reformat * Address comments
1 parent d14eb7c commit 451b2f1

File tree

4 files changed

+306
-43
lines changed

4 files changed

+306
-43
lines changed

cirq/circuits/circuit_operation.py

+139-16
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
component operations in order, including any nested CircuitOperations.
1919
"""
2020

21-
from typing import TYPE_CHECKING, AbstractSet, Callable, Dict, Optional, Tuple, Union
21+
from typing import TYPE_CHECKING, AbstractSet, Callable, Dict, List, Optional, Tuple, Union
2222

2323
import dataclasses
2424
import numpy as np
@@ -33,6 +33,22 @@
3333
INT_TYPE = Union[int, np.integer]
3434

3535

36+
def default_repetition_ids(repetitions: int) -> Optional[List[str]]:
37+
if abs(repetitions) > 1:
38+
return [str(i) for i in range(abs(repetitions))]
39+
return None
40+
41+
42+
def cartesian_product_of_string_lists(list1: Optional[List[str]], list2: Optional[List[str]]):
43+
if list1 is None and list2 is None:
44+
return None # coverage: ignore
45+
if list1 is None:
46+
return list2 # coverage: ignore
47+
if list2 is None:
48+
return list1
49+
return [f'{first}-{second}' for first in list1 for second in list2]
50+
51+
3652
@dataclasses.dataclass(frozen=True)
3753
class CircuitOperation(ops.Operation):
3854
"""An operation that encapsulates a circuit.
@@ -47,6 +63,10 @@ class CircuitOperation(ops.Operation):
4763
qubit_map: Remappings for qubits in the circuit.
4864
measurement_key_map: Remappings for measurement keys in the circuit.
4965
param_resolver: Resolved values for parameters in the circuit.
66+
repetition_ids: List of identifiers for each repetition of the
67+
CircuitOperation. If populated, the length should be equal to the
68+
repetitions. If not populated and abs(`repetitions`) > 1, it is
69+
initialized to strings for numbers in `range(repetitions)`.
5070
"""
5171

5272
_hash: Optional[int] = dataclasses.field(default=None, init=False)
@@ -56,10 +76,28 @@ class CircuitOperation(ops.Operation):
5676
qubit_map: Dict['cirq.Qid', 'cirq.Qid'] = dataclasses.field(default_factory=dict)
5777
measurement_key_map: Dict[str, str] = dataclasses.field(default_factory=dict)
5878
param_resolver: study.ParamResolver = study.ParamResolver()
79+
repetition_ids: Optional[List[str]] = dataclasses.field(default=None)
5980

6081
def __post_init__(self):
6182
if not isinstance(self.circuit, circuits.FrozenCircuit):
6283
raise TypeError(f'Expected circuit of type FrozenCircuit, got: {type(self.circuit)!r}')
84+
85+
# Ensure that the circuit is invertible if the repetitions are negative.
86+
if self.repetitions < 0:
87+
try:
88+
protocols.inverse(self.circuit.unfreeze())
89+
except TypeError:
90+
raise ValueError(f'repetitions are negative but the circuit is not invertible')
91+
92+
# Initialize repetition_ids to default, if unspecified. Else, validate their length.
93+
loop_size = abs(self.repetitions)
94+
if not self.repetition_ids:
95+
object.__setattr__(self, 'repetition_ids', self._default_repetition_ids())
96+
elif len(self.repetition_ids) != loop_size:
97+
raise ValueError(
98+
f'Expected repetition_ids to be a list of length {loop_size}, '
99+
f'got: {self.repetition_ids}'
100+
)
63101
# Ensure that param_resolver is converted to an actual ParamResolver.
64102
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))
65103

@@ -83,6 +121,7 @@ def __eq__(self, other) -> bool:
83121
and self.measurement_key_map == other.measurement_key_map
84122
and self.param_resolver == other.param_resolver
85123
and self.repetitions == other.repetitions
124+
and self.repetition_ids == other.repetition_ids
86125
)
87126

88127
# Methods for getting post-mapping properties of the contained circuit.
@@ -93,6 +132,9 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
93132
ordered_qubits = ops.QubitOrder.DEFAULT.order_for(self.circuit.all_qubits())
94133
return tuple(self.qubit_map.get(q, q) for q in ordered_qubits)
95134

135+
def _default_repetition_ids(self) -> Optional[List[str]]:
136+
return default_repetition_ids(self.repetitions)
137+
96138
def _qid_shape_(self) -> Tuple[int, ...]:
97139
return tuple(q.dimension for q in self.qubits)
98140

@@ -117,8 +159,39 @@ def _decompose_(self) -> 'cirq.OP_TREE':
117159
result = result ** -1
118160
result = protocols.with_measurement_key_mapping(result, self.measurement_key_map)
119161
result = protocols.resolve_parameters(result, self.param_resolver, recursive=False)
120-
121-
return list(result.all_operations()) * abs(self.repetitions)
162+
# repetition_ids don't need to be taken into account if the circuit has no measurements
163+
# or if repetition_ids are unset.
164+
if self.repetition_ids is None or not protocols.is_measurement(result):
165+
return list(result.all_operations()) * abs(self.repetitions)
166+
# If it's a measurement circuit with repetitions/repetition_ids, prefix the repetition_ids
167+
# to measurements. Details at https://tinyurl.com/measurement-repeated-circuitop.
168+
ops = [] # type: List[cirq.Operation]
169+
for parent_id in self.repetition_ids:
170+
for op in result.all_operations():
171+
if isinstance(op, CircuitOperation):
172+
# For a CircuitOperation, prefix the current repetition_id to the children
173+
# repetition_ids.
174+
ops.append(
175+
op.with_repetition_ids(
176+
# If `op.repetition_ids` is None, this will return `[parent_id]`.
177+
cartesian_product_of_string_lists([parent_id], op.repetition_ids)
178+
)
179+
)
180+
elif protocols.is_measurement(op):
181+
# For a non-CircuitOperation measurement, prefix the current repetition_id
182+
# to the children measurement keys. Implemented by creating a mapping and
183+
# using the with_measurement_key_mapping protocol.
184+
ops.append(
185+
protocols.with_measurement_key_mapping(
186+
op,
187+
key_map={
188+
key: f'{parent_id}-{key}' for key in protocols.measurement_keys(op)
189+
},
190+
)
191+
)
192+
else:
193+
ops.append(op)
194+
return ops
122195

123196
# Methods for string representation of the operation.
124197

@@ -132,6 +205,9 @@ def __repr__(self):
132205
args += f'measurement_key_map={proper_repr(self.measurement_key_map)},\n'
133206
if self.param_resolver:
134207
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
208+
if self.repetition_ids != self._default_repetition_ids():
209+
# Default repetition_ids need not be specified.
210+
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
135211
indented_args = args.replace('\n', '\n ')
136212
return f'cirq.CircuitOperation({indented_args[:-4]})'
137213

@@ -155,7 +231,11 @@ def dict_str(d: Dict) -> str:
155231
args.append(f'key_map={dict_str(self.measurement_key_map)}')
156232
if self.param_resolver:
157233
args.append(f'params={self.param_resolver.param_dict}')
158-
if self.repetitions != 1:
234+
if self.repetition_ids != self._default_repetition_ids():
235+
# Default repetition_ids need not be specified.
236+
args.append(f'repetition_ids={self.repetition_ids}')
237+
elif self.repetitions != 1:
238+
# Only add loops if we haven't added repetition_ids.
159239
args.append(f'loops={self.repetitions}')
160240
if not args:
161241
return f'{header}\n{circuit_msg}'
@@ -173,6 +253,7 @@ def __hash__(self):
173253
frozenset(self.qubit_map.items()),
174254
frozenset(self.measurement_key_map.items()),
175255
self.param_resolver,
256+
tuple([] if self.repetition_ids is None else self.repetition_ids),
176257
)
177258
),
178259
)
@@ -188,53 +269,95 @@ def _json_dict_(self):
188269
'qubit_map': sorted(self.qubit_map.items()),
189270
'measurement_key_map': self.measurement_key_map,
190271
'param_resolver': self.param_resolver,
272+
'repetition_ids': self.repetition_ids,
191273
}
192274

193275
@classmethod
194276
def _from_json_dict_(
195-
cls, circuit, repetitions, qubit_map, measurement_key_map, param_resolver, **kwargs
277+
cls,
278+
circuit,
279+
repetitions,
280+
qubit_map,
281+
measurement_key_map,
282+
param_resolver,
283+
repetition_ids,
284+
**kwargs,
196285
):
197286
return (
198287
cls(circuit)
199288
.with_qubit_mapping(dict(qubit_map))
200289
.with_measurement_key_mapping(measurement_key_map)
201290
.with_params(param_resolver)
202-
.repeat(repetitions)
291+
.repeat(repetitions, repetition_ids)
203292
)
204293

205294
# Methods for constructing a similar object with one field modified.
206295

207296
def repeat(
208297
self,
209-
repetitions: INT_TYPE,
298+
repetitions: Optional[INT_TYPE] = None,
299+
repetition_ids: Optional[List[str]] = None,
210300
) -> 'CircuitOperation':
211301
"""Returns a copy of this operation repeated 'repetitions' times.
302+
Each repetition instance will be identified by a single repetition_id.
212303
213304
Args:
214305
repetitions: Number of times this operation should repeat. This
215-
is multiplied with any pre-existing repetitions.
306+
is multiplied with any pre-existing repetitions. If unset, it
307+
defaults to the length of `repetition_ids`.
308+
repetition_ids: List of IDs, one for each repetition. If unset,
309+
defaults to `default_repetition_ids(repetitions)`.
216310
217311
Returns:
218-
A copy of this operation repeated 'repetitions' times.
312+
A copy of this operation repeated `repetitions` times with the
313+
appropriate `repetition_ids`. The output `repetition_ids` are the
314+
cartesian product of input `repetition_ids` with the base
315+
operation's `repetition_ids`. If the base operation has unset
316+
`repetition_ids` (indicates {-1, 0, 1} `repetitions` with no custom
317+
IDs), the input `repetition_ids` are directly used.
219318
220319
Raises:
221320
TypeError: `repetitions` is not an integer value.
222-
NotImplementedError: The operation contains measurements and
223-
cannot have repetitions.
321+
ValueError: Unexpected length of `repetition_ids`.
322+
ValueError: Both `repetitions` and `repetition_ids` are None.
224323
"""
324+
if repetitions is None:
325+
if repetition_ids is None:
326+
raise ValueError('At least one of repetitions and repetition_ids must be set')
327+
repetitions = len(repetition_ids)
328+
225329
if not isinstance(repetitions, (int, np.integer)):
226330
raise TypeError('Only integer repetitions are allowed.')
227-
if repetitions == 1:
331+
332+
repetitions = int(repetitions)
333+
334+
if repetitions == 1 and repetition_ids is None:
228335
# As CircuitOperation is immutable, this can safely return the original.
229336
return self
230-
repetitions = int(repetitions)
231-
if protocols.is_measurement(self.circuit):
232-
raise NotImplementedError('Loops over measurements are not supported.')
233-
return self.replace(repetitions=self.repetitions * repetitions)
337+
338+
expected_repetition_id_length = abs(repetitions)
339+
# The eventual number of repetitions of the returned CircuitOperation.
340+
final_repetitions = self.repetitions * repetitions
341+
342+
if repetition_ids is None:
343+
repetition_ids = default_repetition_ids(expected_repetition_id_length)
344+
elif len(repetition_ids) != expected_repetition_id_length:
345+
raise ValueError(
346+
f'Expected repetition_ids={repetition_ids} length to be '
347+
f'{expected_repetition_id_length}'
348+
)
349+
350+
# If `self.repetition_ids` is None, this will just return `repetition_ids`.
351+
repetition_ids = cartesian_product_of_string_lists(repetition_ids, self.repetition_ids)
352+
353+
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)
234354

235355
def __pow__(self, power: int) -> 'CircuitOperation':
236356
return self.repeat(power)
237357

358+
def with_repetition_ids(self, repetition_ids: List[str]) -> 'CircuitOperation':
359+
return self.replace(repetition_ids=repetition_ids)
360+
238361
def with_qubit_mapping(
239362
self,
240363
qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],

0 commit comments

Comments
 (0)