Skip to content

Commit ff671ae

Browse files
authored
Allow sympy expressions as classical controls (#4740)
Part 14 of https://tinyurl.com/cirq-feedforward. Adds the ability to create classical control conditions based on sympy expressions. To account for the fact that measurement key strings can contain characters not allowed in sympy variables, the measurement keys in a sympy condition string must be wrapped in curly braces to denote them. For example, to create an expression that checks if measurement A was greater than measurement B, the proper syntax is `cirq.parse_sympy_condition('{A} > {B}')`. This PR does not yet handle qudits completely, as multi-qubit measurements are interpreted as base-2 when converting to integer. A subsequent PR (https://github.com/daxfohl/Cirq/compare/sympy3...daxfohl:qudits?expand=1) will allow this functionality.
1 parent 65d783e commit ff671ae

16 files changed

+545
-59
lines changed

cirq-core/cirq/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -483,15 +483,18 @@
483483
canonicalize_half_turns,
484484
chosen_angle_to_canonical_half_turns,
485485
chosen_angle_to_half_turns,
486+
Condition,
486487
Duration,
487488
DURATION_LIKE,
488489
GenericMetaImplementAnyOneOf,
490+
KeyCondition,
489491
LinearDict,
490492
MEASUREMENT_KEY_SEPARATOR,
491493
MeasurementKey,
492494
PeriodicValue,
493495
RANDOM_STATE_OR_SEED_LIKE,
494496
state_vector_to_probabilities,
497+
SympyCondition,
495498
Timestamp,
496499
TParamKey,
497500
TParamVal,

cirq-core/cirq/_compat.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import numpy as np
2828
import pandas as pd
2929
import sympy
30+
import sympy.printing.repr
3031

3132

3233
def proper_repr(value: Any) -> str:

cirq-core/cirq/json_resolver_cache.py

+2
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def _parallel_gate_op(gate, qubits):
9494
'ISwapPowGate': cirq.ISwapPowGate,
9595
'IdentityGate': cirq.IdentityGate,
9696
'InitObsSetting': cirq.work.InitObsSetting,
97+
'KeyCondition': cirq.KeyCondition,
9798
'KrausChannel': cirq.KrausChannel,
9899
'LinearDict': cirq.LinearDict,
99100
'LineQubit': cirq.LineQubit,
@@ -150,6 +151,7 @@ def _parallel_gate_op(gate, qubits):
150151
'StatePreparationChannel': cirq.StatePreparationChannel,
151152
'SwapPowGate': cirq.SwapPowGate,
152153
'SymmetricalQidPair': cirq.SymmetricalQidPair,
154+
'SympyCondition': cirq.SympyCondition,
153155
'TaggedOperation': cirq.TaggedOperation,
154156
'TiltedSquareLattice': cirq.TiltedSquareLattice,
155157
'TrialResult': cirq.Result, # keep support for Cirq < 0.11.

cirq-core/cirq/ops/classically_controlled_operation.py

+53-42
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
Any,
1717
Dict,
1818
FrozenSet,
19+
List,
1920
Optional,
2021
Sequence,
2122
TYPE_CHECKING,
2223
Tuple,
2324
Union,
2425
)
2526

27+
import sympy
28+
2629
from cirq import protocols, value
2730
from cirq.ops import raw_types
2831

@@ -46,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation):
4649
def __init__(
4750
self,
4851
sub_operation: 'cirq.Operation',
49-
conditions: Sequence[Union[str, 'cirq.MeasurementKey']],
52+
conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Basic]],
5053
):
5154
"""Initializes a `ClassicallyControlledOperation`.
5255
@@ -68,13 +71,26 @@ def __init__(
6871
raise ValueError(
6972
f'Cannot conditionally run operations with measurements: {sub_operation}'
7073
)
71-
keys = tuple(value.MeasurementKey(c) if isinstance(c, str) else c for c in conditions)
74+
conditions = tuple(conditions)
7275
if isinstance(sub_operation, ClassicallyControlledOperation):
73-
keys += sub_operation._control_keys
76+
conditions += sub_operation._conditions
7477
sub_operation = sub_operation._sub_operation
75-
self._control_keys: Tuple['cirq.MeasurementKey', ...] = keys
78+
conds: List['cirq.Condition'] = []
79+
for c in conditions:
80+
if isinstance(c, str):
81+
c = value.MeasurementKey.parse_serialized(c)
82+
if isinstance(c, value.MeasurementKey):
83+
c = value.KeyCondition(c)
84+
if isinstance(c, sympy.Basic):
85+
c = value.SympyCondition(c)
86+
conds.append(c)
87+
self._conditions: Tuple['cirq.Condition', ...] = tuple(conds)
7688
self._sub_operation: 'cirq.Operation' = sub_operation
7789

90+
@property
91+
def classical_controls(self) -> FrozenSet['cirq.Condition']:
92+
return frozenset(self._conditions).union(self._sub_operation.classical_controls)
93+
7894
def without_classical_controls(self) -> 'cirq.Operation':
7995
return self._sub_operation.without_classical_controls()
8096

@@ -84,27 +100,27 @@ def qubits(self):
84100

85101
def with_qubits(self, *new_qubits):
86102
return self._sub_operation.with_qubits(*new_qubits).with_classical_controls(
87-
*self._control_keys
103+
*self._conditions
88104
)
89105

90106
def _decompose_(self):
91107
result = protocols.decompose_once(self._sub_operation, NotImplemented)
92108
if result is NotImplemented:
93109
return NotImplemented
94110

95-
return [ClassicallyControlledOperation(op, self._control_keys) for op in result]
111+
return [ClassicallyControlledOperation(op, self._conditions) for op in result]
96112

97113
def _value_equality_values_(self):
98-
return (frozenset(self._control_keys), self._sub_operation)
114+
return (frozenset(self._conditions), self._sub_operation)
99115

100116
def __str__(self) -> str:
101-
keys = ', '.join(map(str, self._control_keys))
117+
keys = ', '.join(map(str, self._conditions))
102118
return f'{self._sub_operation}.with_classical_controls({keys})'
103119

104120
def __repr__(self):
105121
return (
106122
f'cirq.ClassicallyControlledOperation('
107-
f'{self._sub_operation!r}, {list(self._control_keys)!r})'
123+
f'{self._sub_operation!r}, {list(self._conditions)!r})'
108124
)
109125

110126
def _is_parameterized_(self) -> bool:
@@ -117,7 +133,7 @@ def _resolve_parameters_(
117133
self, resolver: 'cirq.ParamResolver', recursive: bool
118134
) -> 'ClassicallyControlledOperation':
119135
new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive)
120-
return new_sub_op.with_classical_controls(*self._control_keys)
136+
return new_sub_op.with_classical_controls(*self._conditions)
121137

122138
def _circuit_diagram_info_(
123139
self, args: 'cirq.CircuitDiagramInfoArgs'
@@ -133,12 +149,20 @@ def _circuit_diagram_info_(
133149
if sub_info is None:
134150
return NotImplemented # coverage: ignore
135151

136-
wire_symbols = sub_info.wire_symbols + ('^',) * len(self._control_keys)
152+
control_count = len({k for c in self._conditions for k in c.keys})
153+
wire_symbols = sub_info.wire_symbols + ('^',) * control_count
154+
if any(not isinstance(c, value.KeyCondition) for c in self._conditions):
155+
wire_symbols = (
156+
wire_symbols[0]
157+
+ '(conditions=['
158+
+ ', '.join(str(c) for c in self._conditions)
159+
+ '])',
160+
) + wire_symbols[1:]
137161
exponent_qubit_index = None
138162
if sub_info.exponent_qubit_index is not None:
139-
exponent_qubit_index = sub_info.exponent_qubit_index + len(self._control_keys)
163+
exponent_qubit_index = sub_info.exponent_qubit_index + control_count
140164
elif sub_info.exponent is not None:
141-
exponent_qubit_index = len(self._control_keys)
165+
exponent_qubit_index = control_count
142166
return protocols.CircuitDiagramInfo(
143167
wire_symbols=wire_symbols,
144168
exponent=sub_info.exponent,
@@ -148,58 +172,45 @@ def _circuit_diagram_info_(
148172
def _json_dict_(self) -> Dict[str, Any]:
149173
return {
150174
'cirq_type': self.__class__.__name__,
151-
'conditions': self._control_keys,
175+
'conditions': self._conditions,
152176
'sub_operation': self._sub_operation,
153177
}
154178

155179
def _act_on_(self, args: 'cirq.ActOnArgs') -> bool:
156-
def not_zero(measurement):
157-
return any(i != 0 for i in measurement)
158-
159-
measurements = [
160-
args.log_of_measurement_results.get(str(key), str(key)) for key in self._control_keys
161-
]
162-
missing = [m for m in measurements if isinstance(m, str)]
163-
if missing:
164-
raise ValueError(f'Measurement keys {missing} missing when performing {self}')
165-
if all(not_zero(measurement) for measurement in measurements):
180+
if all(c.resolve(args.log_of_measurement_results) for c in self._conditions):
166181
protocols.act_on(self._sub_operation, args)
167182
return True
168183

169184
def _with_measurement_key_mapping_(
170185
self, key_map: Dict[str, str]
171186
) -> 'ClassicallyControlledOperation':
187+
conditions = [protocols.with_measurement_key_mapping(c, key_map) for c in self._conditions]
172188
sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map)
173189
sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation
174-
return sub_operation.with_classical_controls(
175-
*[protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys]
176-
)
190+
return sub_operation.with_classical_controls(*conditions)
177191

178-
def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation':
179-
keys = [protocols.with_key_path_prefix(k, path) for k in self._control_keys]
180-
return self._sub_operation.with_classical_controls(*keys)
192+
def _with_key_path_prefix_(self, prefix: Tuple[str, ...]) -> 'ClassicallyControlledOperation':
193+
conditions = [protocols.with_key_path_prefix(c, prefix) for c in self._conditions]
194+
sub_operation = protocols.with_key_path_prefix(self._sub_operation, prefix)
195+
sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation
196+
return sub_operation.with_classical_controls(*conditions)
181197

182198
def _with_rescoped_keys_(
183199
self,
184200
path: Tuple[str, ...],
185201
bindable_keys: FrozenSet['cirq.MeasurementKey'],
186202
) -> 'ClassicallyControlledOperation':
187-
def map_key(key: 'cirq.MeasurementKey') -> 'cirq.MeasurementKey':
188-
for i in range(len(path) + 1):
189-
back_path = path[: len(path) - i]
190-
new_key = key.with_key_path_prefix(*back_path)
191-
if new_key in bindable_keys:
192-
return new_key
193-
return key
194-
203+
conds = [protocols.with_rescoped_keys(c, path, bindable_keys) for c in self._conditions]
195204
sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys)
196-
return sub_operation.with_classical_controls(*[map_key(k) for k in self._control_keys])
205+
return sub_operation.with_classical_controls(*conds)
197206

198207
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
199-
return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation))
208+
local_keys: FrozenSet['cirq.MeasurementKey'] = frozenset(
209+
k for condition in self._conditions for k in condition.keys
210+
)
211+
return local_keys.union(protocols.control_keys(self._sub_operation))
200212

201213
def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
202214
args.validate_version('2.0')
203-
keys = [f'm_{key}!=0' for key in self._control_keys]
204-
all_keys = " && ".join(keys)
215+
all_keys = " && ".join(c.qasm for c in self._conditions)
205216
return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args))

0 commit comments

Comments
 (0)