Skip to content

Commit 3c1c802

Browse files
authored
Add repeat-until functionality to subcircuits (#5018)
* Allow flattening of subcircuits * format * Add serialization logic and tests * Change flatten_repetitions (default False) to use_repetition_ids (default True) * Add shape tests for simulation results from flattened subcircuits * docs * add repeat_until * repr/json/etc * format * chagne do_while to repeat_until * merge fix * make mapped_single_loop private * Address code review comments. * Fix test * simplify branch * simplify branch * simplify branch * add unbound controls in repeat_until to control_keys
1 parent d90b09c commit 3c1c802

File tree

4 files changed

+231
-23
lines changed

4 files changed

+231
-23
lines changed

cirq-core/cirq/circuits/circuit_operation.py

+68-22
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
component operations in order, including any nested CircuitOperations.
1919
"""
2020
from typing import (
21-
TYPE_CHECKING,
2221
AbstractSet,
2322
Callable,
23+
cast,
2424
Dict,
2525
FrozenSet,
2626
Iterator,
2727
List,
2828
Optional,
2929
Tuple,
30+
TYPE_CHECKING,
3031
Union,
3132
)
3233

@@ -94,6 +95,12 @@ class CircuitOperation(ops.Operation):
9495
will have its path prepended with the repetition id for each
9596
repetition. When False, this will not happen and the measurement
9697
key will be repeated.
98+
repeat_until: A condition that will be tested after each iteration of
99+
the subcircuit. The subcircuit will repeat until condition returns
100+
True, but will always run at least once, and the measurement key
101+
need not be defined prior to the subcircuit (but must be defined in
102+
a measurement within the subcircuit). This field is incompatible
103+
with repetitions or repetition_ids.
97104
"""
98105

99106
_hash: Optional[int] = dataclasses.field(default=None, init=False)
@@ -103,6 +110,9 @@ class CircuitOperation(ops.Operation):
103110
_cached_control_keys: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field(
104111
default=None, init=False
105112
)
113+
_cached_mapped_single_loop: Optional['cirq.Circuit'] = dataclasses.field(
114+
default=None, init=False
115+
)
106116

107117
circuit: 'cirq.FrozenCircuit'
108118
repetitions: int = 1
@@ -113,6 +123,7 @@ class CircuitOperation(ops.Operation):
113123
parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
114124
extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset)
115125
use_repetition_ids: bool = True
126+
repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None)
116127

117128
def __post_init__(self):
118129
if not isinstance(self.circuit, circuits.FrozenCircuit):
@@ -148,6 +159,14 @@ def __post_init__(self):
148159
if q_new.dimension != q.dimension:
149160
raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}')
150161

162+
if self.repeat_until:
163+
if self.use_repetition_ids or self.repetitions != 1:
164+
raise ValueError('Cannot use repetitions with repeat_until')
165+
if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint(
166+
self.repeat_until.keys
167+
):
168+
raise ValueError('Infinite loop: condition is not modified in subcircuit.')
169+
151170
# Ensure that param_resolver is converted to an actual ParamResolver.
152171
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))
153172

@@ -174,6 +193,7 @@ def __eq__(self, other) -> bool:
174193
and self.repetition_ids == other.repetition_ids
175194
and self.parent_path == other.parent_path
176195
and self.use_repetition_ids == other.use_repetition_ids
196+
and self.repeat_until == other.repeat_until
177197
)
178198

179199
# Methods for getting post-mapping properties of the contained circuit.
@@ -223,6 +243,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
223243
if not protocols.control_keys(self.circuit)
224244
else protocols.control_keys(self.mapped_circuit())
225245
)
246+
if self.repeat_until is not None:
247+
keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_()
226248
object.__setattr__(self, '_cached_control_keys', keys)
227249
return self._cached_control_keys # type: ignore
228250

@@ -235,6 +257,27 @@ def _parameter_names_(self) -> AbstractSet[str]:
235257
)
236258
}
237259

260+
def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit':
261+
if self._cached_mapped_single_loop is None:
262+
circuit = self.circuit.unfreeze()
263+
if self.qubit_map:
264+
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
265+
if self.repetitions < 0:
266+
circuit = circuit ** -1
267+
if self.measurement_key_map:
268+
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
269+
if self.param_resolver:
270+
circuit = protocols.resolve_parameters(
271+
circuit, self.param_resolver, recursive=False
272+
)
273+
object.__setattr__(self, '_cached_mapped_single_loop', circuit)
274+
circuit = cast(circuits.Circuit, self._cached_mapped_single_loop)
275+
if repetition_id:
276+
circuit = protocols.with_rescoped_keys(circuit, (repetition_id,))
277+
return protocols.with_rescoped_keys(
278+
circuit, self.parent_path, bindable_keys=self.extern_keys
279+
)
280+
238281
def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
239282
"""Applies all maps to the contained circuit and returns the result.
240283
@@ -249,24 +292,12 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
249292
"""
250293
if self.repetitions == 0:
251294
return circuits.Circuit()
252-
circuit = self.circuit.unfreeze()
253-
if self.qubit_map:
254-
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
255-
if self.repetitions < 0:
256-
circuit = circuit ** -1
257-
if self.measurement_key_map:
258-
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
259-
if self.param_resolver:
260-
circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False)
261-
if self.repetition_ids is not None:
262-
if not self.use_repetition_ids or not protocols.is_measurement(circuit):
263-
circuit = circuit * abs(self.repetitions)
264-
else:
265-
circuit = circuits.Circuit(
266-
protocols.with_rescoped_keys(circuit, (rep,)) for rep in self.repetition_ids
267-
)
268-
circuit = protocols.with_rescoped_keys(
269-
circuit, self.parent_path, bindable_keys=self.extern_keys
295+
circuit = (
296+
circuits.Circuit(self._mapped_single_loop(rep) for rep in self.repetition_ids)
297+
if self.repetition_ids is not None
298+
and self.use_repetition_ids
299+
and protocols.is_measurement(self.circuit)
300+
else self._mapped_single_loop() * abs(self.repetitions)
270301
)
271302
if deep:
272303
circuit = circuit.map_operations(
@@ -282,8 +313,16 @@ def _decompose_(self) -> Iterator['cirq.Operation']:
282313
return self.mapped_circuit(deep=False).all_operations()
283314

284315
def _act_on_(self, args: 'cirq.OperationTarget') -> bool:
285-
for op in self._decompose_():
286-
protocols.act_on(op, args)
316+
if self.repeat_until:
317+
circuit = self._mapped_single_loop()
318+
while True:
319+
for op in circuit.all_operations():
320+
protocols.act_on(op, args)
321+
if self.repeat_until.resolve(args.classical_data):
322+
break
323+
else:
324+
for op in self._decompose_():
325+
protocols.act_on(op, args)
287326
return True
288327

289328
# Methods for string representation of the operation.
@@ -305,6 +344,8 @@ def __repr__(self):
305344
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
306345
if not self.use_repetition_ids:
307346
args += 'use_repetition_ids=False,\n'
347+
if self.repeat_until:
348+
args += f'repeat_until={self.repeat_until!r},\n'
308349
indented_args = args.replace('\n', '\n ')
309350
return f'cirq.CircuitOperation({indented_args[:-4]})'
310351

@@ -337,6 +378,8 @@ def dict_str(d: Dict) -> str:
337378
args.append(f'loops={self.repetitions}')
338379
if not self.use_repetition_ids:
339380
args.append('no_rep_ids')
381+
if self.repeat_until:
382+
args.append(f'until={self.repeat_until}')
340383
if not args:
341384
return circuit_msg
342385
return f'{circuit_msg}({", ".join(args)})'
@@ -375,6 +418,8 @@ def _json_dict_(self):
375418
}
376419
if not self.use_repetition_ids:
377420
resp['use_repetition_ids'] = False
421+
if self.repeat_until:
422+
resp['repeat_until'] = self.repeat_until
378423
return resp
379424

380425
@classmethod
@@ -388,10 +433,11 @@ def _from_json_dict_(
388433
repetition_ids,
389434
parent_path=(),
390435
use_repetition_ids=True,
436+
repeat_until=None,
391437
**kwargs,
392438
):
393439
return (
394-
cls(circuit, use_repetition_ids=use_repetition_ids)
440+
cls(circuit, use_repetition_ids=use_repetition_ids, repeat_until=repeat_until)
395441
.with_qubit_mapping(dict(qubit_map))
396442
.with_measurement_key_mapping(measurement_key_map)
397443
.with_params(param_resolver)

cirq-core/cirq/circuits/circuit_operation_test.py

+121
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,24 @@ def test_string_format():
533533
use_repetition_ids=False,
534534
)"""
535535
)
536+
op7 = cirq.CircuitOperation(
537+
cirq.FrozenCircuit(cirq.measure(x, key='a')),
538+
use_repetition_ids=False,
539+
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
540+
)
541+
assert (
542+
repr(op7)
543+
== """\
544+
cirq.CircuitOperation(
545+
circuit=cirq.FrozenCircuit([
546+
cirq.Moment(
547+
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
548+
),
549+
]),
550+
use_repetition_ids=False,
551+
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
552+
)"""
553+
)
536554

537555

538556
def test_json_dict():
@@ -977,4 +995,107 @@ def test_simulate_no_repetition_ids_inner(sim):
977995
assert result.records['1:a'].shape == (1, 2, 1)
978996

979997

998+
@pytest.mark.parametrize('sim', ALL_SIMULATORS)
999+
def test_repeat_until(sim):
1000+
q = cirq.LineQubit(0)
1001+
key = cirq.MeasurementKey('m')
1002+
c = cirq.Circuit(
1003+
cirq.X(q),
1004+
cirq.CircuitOperation(
1005+
cirq.FrozenCircuit(
1006+
cirq.X(q),
1007+
cirq.measure(q, key=key),
1008+
),
1009+
use_repetition_ids=False,
1010+
repeat_until=cirq.KeyCondition(key),
1011+
),
1012+
)
1013+
measurements = sim.run(c).records['m'][0]
1014+
assert len(measurements) == 2
1015+
assert measurements[0] == (0,)
1016+
assert measurements[1] == (1,)
1017+
1018+
1019+
@pytest.mark.parametrize('sim', ALL_SIMULATORS)
1020+
def test_repeat_until_sympy(sim):
1021+
q1, q2 = cirq.LineQubit.range(2)
1022+
circuitop = cirq.CircuitOperation(
1023+
cirq.FrozenCircuit(
1024+
cirq.X(q2),
1025+
cirq.measure(q2, key='b'),
1026+
),
1027+
use_repetition_ids=False,
1028+
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
1029+
)
1030+
c = cirq.Circuit(
1031+
cirq.measure(q1, key='a'),
1032+
circuitop,
1033+
)
1034+
# Validate commutation
1035+
assert len(c) == 2
1036+
assert cirq.control_keys(circuitop) == {cirq.MeasurementKey('a')}
1037+
measurements = sim.run(c).records['b'][0]
1038+
assert len(measurements) == 2
1039+
assert measurements[0] == (1,)
1040+
assert measurements[1] == (0,)
1041+
1042+
1043+
@pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()])
1044+
def test_post_selection(sim):
1045+
q = cirq.LineQubit(0)
1046+
key = cirq.MeasurementKey('m')
1047+
c = cirq.Circuit(
1048+
cirq.CircuitOperation(
1049+
cirq.FrozenCircuit(
1050+
cirq.X(q) ** 0.2,
1051+
cirq.measure(q, key=key),
1052+
),
1053+
use_repetition_ids=False,
1054+
repeat_until=cirq.KeyCondition(key),
1055+
),
1056+
)
1057+
result = sim.run(c)
1058+
assert result.records['m'][0][-1] == (1,)
1059+
for i in range(len(result.records['m'][0]) - 1):
1060+
assert result.records['m'][0][i] == (0,)
1061+
1062+
1063+
def test_repeat_until_diagram():
1064+
q = cirq.LineQubit(0)
1065+
key = cirq.MeasurementKey('m')
1066+
c = cirq.Circuit(
1067+
cirq.CircuitOperation(
1068+
cirq.FrozenCircuit(
1069+
cirq.X(q) ** 0.2,
1070+
cirq.measure(q, key=key),
1071+
),
1072+
use_repetition_ids=False,
1073+
repeat_until=cirq.KeyCondition(key),
1074+
),
1075+
)
1076+
cirq.testing.assert_has_diagram(
1077+
c,
1078+
"""
1079+
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
1080+
""",
1081+
use_unicode_characters=True,
1082+
)
1083+
1084+
1085+
def test_repeat_until_error():
1086+
q = cirq.LineQubit(0)
1087+
with pytest.raises(ValueError, match='Cannot use repetitions with repeat_until'):
1088+
cirq.CircuitOperation(
1089+
cirq.FrozenCircuit(),
1090+
use_repetition_ids=True,
1091+
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
1092+
)
1093+
with pytest.raises(ValueError, match='Infinite loop'):
1094+
cirq.CircuitOperation(
1095+
cirq.FrozenCircuit(cirq.measure(q, key='m')),
1096+
use_repetition_ids=False,
1097+
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
1098+
)
1099+
1100+
9801101
# TODO: Operation has a "gate" property. What is this for a CircuitOperation?

cirq-core/cirq/protocols/json_test_data/CircuitOperation.json

+26
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,32 @@
294294
"parent_path": [],
295295
"repetition_ids": null,
296296
"use_repetition_ids": false
297+
},
298+
{
299+
"cirq_type": "CircuitOperation",
300+
"circuit": {
301+
"cirq_type": "_SerializedKey",
302+
"key": 1
303+
},
304+
"repetitions": 1,
305+
"qubit_map": [],
306+
"measurement_key_map": {},
307+
"param_resolver": {
308+
"cirq_type": "ParamResolver",
309+
"param_dict": []
310+
},
311+
"parent_path": [],
312+
"repetition_ids": null,
313+
"use_repetition_ids": false,
314+
"repeat_until": {
315+
"cirq_type": "KeyCondition",
316+
"key": {
317+
"cirq_type": "MeasurementKey",
318+
"name": "0,1,2,3,4",
319+
"path": []
320+
},
321+
"index": -1
322+
}
297323
}
298324
]
299325
]

cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr

+16-1
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,19 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([
3636
),
3737
]),
3838
param_resolver={sympy.Symbol('theta'): 1.5},
39-
use_repetition_ids=False)]
39+
use_repetition_ids=False),
40+
cirq.CircuitOperation(circuit=cirq.FrozenCircuit([
41+
cirq.Moment(
42+
cirq.H(cirq.LineQubit(0)),
43+
cirq.H(cirq.LineQubit(1)),
44+
cirq.H(cirq.LineQubit(2)),
45+
cirq.H(cirq.LineQubit(3)),
46+
cirq.H(cirq.LineQubit(4)),
47+
),
48+
cirq.Moment(
49+
cirq.MeasurementGate(5, '0,1,2,3,4', ()).on(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), cirq.LineQubit(3), cirq.LineQubit(4)),
50+
),
51+
]),
52+
use_repetition_ids=False,
53+
repeat_until=cirq.KeyCondition(key=cirq.MeasurementKey('0,1,2,3,4')),
54+
)]

0 commit comments

Comments
 (0)