16
16
Any ,
17
17
Dict ,
18
18
FrozenSet ,
19
+ List ,
19
20
Optional ,
20
21
Sequence ,
21
22
TYPE_CHECKING ,
22
23
Tuple ,
23
24
Union ,
24
25
)
25
26
27
+ import sympy
28
+
26
29
from cirq import protocols , value
27
30
from cirq .ops import raw_types
28
31
@@ -46,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation):
46
49
def __init__ (
47
50
self ,
48
51
sub_operation : 'cirq.Operation' ,
49
- conditions : Sequence [Union [str , 'cirq.MeasurementKey' ]],
52
+ conditions : Sequence [Union [str , 'cirq.MeasurementKey' , 'cirq.Condition' , sympy . Basic ]],
50
53
):
51
54
"""Initializes a `ClassicallyControlledOperation`.
52
55
@@ -68,13 +71,26 @@ def __init__(
68
71
raise ValueError (
69
72
f'Cannot conditionally run operations with measurements: { sub_operation } '
70
73
)
71
- keys = tuple (value . MeasurementKey ( c ) if isinstance ( c , str ) else c for c in conditions )
74
+ conditions = tuple (conditions )
72
75
if isinstance (sub_operation , ClassicallyControlledOperation ):
73
- keys += sub_operation ._control_keys
76
+ conditions += sub_operation ._conditions
74
77
sub_operation = sub_operation ._sub_operation
75
- self ._control_keys : Tuple ['cirq.MeasurementKey' , ...] = keys
78
+ conds : List ['cirq.Condition' ] = []
79
+ for c in conditions :
80
+ if isinstance (c , str ):
81
+ c = value .MeasurementKey .parse_serialized (c )
82
+ if isinstance (c , value .MeasurementKey ):
83
+ c = value .KeyCondition (c )
84
+ if isinstance (c , sympy .Basic ):
85
+ c = value .SympyCondition (c )
86
+ conds .append (c )
87
+ self ._conditions : Tuple ['cirq.Condition' , ...] = tuple (conds )
76
88
self ._sub_operation : 'cirq.Operation' = sub_operation
77
89
90
+ @property
91
+ def classical_controls (self ) -> FrozenSet ['cirq.Condition' ]:
92
+ return frozenset (self ._conditions ).union (self ._sub_operation .classical_controls )
93
+
78
94
def without_classical_controls (self ) -> 'cirq.Operation' :
79
95
return self ._sub_operation .without_classical_controls ()
80
96
@@ -84,27 +100,27 @@ def qubits(self):
84
100
85
101
def with_qubits (self , * new_qubits ):
86
102
return self ._sub_operation .with_qubits (* new_qubits ).with_classical_controls (
87
- * self ._control_keys
103
+ * self ._conditions
88
104
)
89
105
90
106
def _decompose_ (self ):
91
107
result = protocols .decompose_once (self ._sub_operation , NotImplemented )
92
108
if result is NotImplemented :
93
109
return NotImplemented
94
110
95
- return [ClassicallyControlledOperation (op , self ._control_keys ) for op in result ]
111
+ return [ClassicallyControlledOperation (op , self ._conditions ) for op in result ]
96
112
97
113
def _value_equality_values_ (self ):
98
- return (frozenset (self ._control_keys ), self ._sub_operation )
114
+ return (frozenset (self ._conditions ), self ._sub_operation )
99
115
100
116
def __str__ (self ) -> str :
101
- keys = ', ' .join (map (str , self ._control_keys ))
117
+ keys = ', ' .join (map (str , self ._conditions ))
102
118
return f'{ self ._sub_operation } .with_classical_controls({ keys } )'
103
119
104
120
def __repr__ (self ):
105
121
return (
106
122
f'cirq.ClassicallyControlledOperation('
107
- f'{ self ._sub_operation !r} , { list (self ._control_keys )!r} )'
123
+ f'{ self ._sub_operation !r} , { list (self ._conditions )!r} )'
108
124
)
109
125
110
126
def _is_parameterized_ (self ) -> bool :
@@ -117,7 +133,7 @@ def _resolve_parameters_(
117
133
self , resolver : 'cirq.ParamResolver' , recursive : bool
118
134
) -> 'ClassicallyControlledOperation' :
119
135
new_sub_op = protocols .resolve_parameters (self ._sub_operation , resolver , recursive )
120
- return new_sub_op .with_classical_controls (* self ._control_keys )
136
+ return new_sub_op .with_classical_controls (* self ._conditions )
121
137
122
138
def _circuit_diagram_info_ (
123
139
self , args : 'cirq.CircuitDiagramInfoArgs'
@@ -133,12 +149,20 @@ def _circuit_diagram_info_(
133
149
if sub_info is None :
134
150
return NotImplemented # coverage: ignore
135
151
136
- wire_symbols = sub_info .wire_symbols + ('^' ,) * len (self ._control_keys )
152
+ control_count = len ({k for c in self ._conditions for k in c .keys })
153
+ wire_symbols = sub_info .wire_symbols + ('^' ,) * control_count
154
+ if any (not isinstance (c , value .KeyCondition ) for c in self ._conditions ):
155
+ wire_symbols = (
156
+ wire_symbols [0 ]
157
+ + '(conditions=['
158
+ + ', ' .join (str (c ) for c in self ._conditions )
159
+ + '])' ,
160
+ ) + wire_symbols [1 :]
137
161
exponent_qubit_index = None
138
162
if sub_info .exponent_qubit_index is not None :
139
- exponent_qubit_index = sub_info .exponent_qubit_index + len ( self . _control_keys )
163
+ exponent_qubit_index = sub_info .exponent_qubit_index + control_count
140
164
elif sub_info .exponent is not None :
141
- exponent_qubit_index = len ( self . _control_keys )
165
+ exponent_qubit_index = control_count
142
166
return protocols .CircuitDiagramInfo (
143
167
wire_symbols = wire_symbols ,
144
168
exponent = sub_info .exponent ,
@@ -148,58 +172,45 @@ def _circuit_diagram_info_(
148
172
def _json_dict_ (self ) -> Dict [str , Any ]:
149
173
return {
150
174
'cirq_type' : self .__class__ .__name__ ,
151
- 'conditions' : self ._control_keys ,
175
+ 'conditions' : self ._conditions ,
152
176
'sub_operation' : self ._sub_operation ,
153
177
}
154
178
155
179
def _act_on_ (self , args : 'cirq.ActOnArgs' ) -> bool :
156
- def not_zero (measurement ):
157
- return any (i != 0 for i in measurement )
158
-
159
- measurements = [
160
- args .log_of_measurement_results .get (str (key ), str (key )) for key in self ._control_keys
161
- ]
162
- missing = [m for m in measurements if isinstance (m , str )]
163
- if missing :
164
- raise ValueError (f'Measurement keys { missing } missing when performing { self } ' )
165
- if all (not_zero (measurement ) for measurement in measurements ):
180
+ if all (c .resolve (args .log_of_measurement_results ) for c in self ._conditions ):
166
181
protocols .act_on (self ._sub_operation , args )
167
182
return True
168
183
169
184
def _with_measurement_key_mapping_ (
170
185
self , key_map : Dict [str , str ]
171
186
) -> 'ClassicallyControlledOperation' :
187
+ conditions = [protocols .with_measurement_key_mapping (c , key_map ) for c in self ._conditions ]
172
188
sub_operation = protocols .with_measurement_key_mapping (self ._sub_operation , key_map )
173
189
sub_operation = self ._sub_operation if sub_operation is NotImplemented else sub_operation
174
- return sub_operation .with_classical_controls (
175
- * [protocols .with_measurement_key_mapping (k , key_map ) for k in self ._control_keys ]
176
- )
190
+ return sub_operation .with_classical_controls (* conditions )
177
191
178
- def _with_key_path_prefix_ (self , path : Tuple [str , ...]) -> 'ClassicallyControlledOperation' :
179
- keys = [protocols .with_key_path_prefix (k , path ) for k in self ._control_keys ]
180
- return self ._sub_operation .with_classical_controls (* keys )
192
+ def _with_key_path_prefix_ (self , prefix : Tuple [str , ...]) -> 'ClassicallyControlledOperation' :
193
+ conditions = [protocols .with_key_path_prefix (c , prefix ) for c in self ._conditions ]
194
+ sub_operation = protocols .with_key_path_prefix (self ._sub_operation , prefix )
195
+ sub_operation = self ._sub_operation if sub_operation is NotImplemented else sub_operation
196
+ return sub_operation .with_classical_controls (* conditions )
181
197
182
198
def _with_rescoped_keys_ (
183
199
self ,
184
200
path : Tuple [str , ...],
185
201
bindable_keys : FrozenSet ['cirq.MeasurementKey' ],
186
202
) -> 'ClassicallyControlledOperation' :
187
- def map_key (key : 'cirq.MeasurementKey' ) -> 'cirq.MeasurementKey' :
188
- for i in range (len (path ) + 1 ):
189
- back_path = path [: len (path ) - i ]
190
- new_key = key .with_key_path_prefix (* back_path )
191
- if new_key in bindable_keys :
192
- return new_key
193
- return key
194
-
203
+ conds = [protocols .with_rescoped_keys (c , path , bindable_keys ) for c in self ._conditions ]
195
204
sub_operation = protocols .with_rescoped_keys (self ._sub_operation , path , bindable_keys )
196
- return sub_operation .with_classical_controls (* [ map_key ( k ) for k in self . _control_keys ] )
205
+ return sub_operation .with_classical_controls (* conds )
197
206
198
207
def _control_keys_ (self ) -> FrozenSet ['cirq.MeasurementKey' ]:
199
- return frozenset (self ._control_keys ).union (protocols .control_keys (self ._sub_operation ))
208
+ local_keys : FrozenSet ['cirq.MeasurementKey' ] = frozenset (
209
+ k for condition in self ._conditions for k in condition .keys
210
+ )
211
+ return local_keys .union (protocols .control_keys (self ._sub_operation ))
200
212
201
213
def _qasm_ (self , args : 'cirq.QasmArgs' ) -> Optional [str ]:
202
214
args .validate_version ('2.0' )
203
- keys = [f'm_{ key } !=0' for key in self ._control_keys ]
204
- all_keys = " && " .join (keys )
215
+ all_keys = " && " .join (c .qasm for c in self ._conditions )
205
216
return args .format ('if ({0}) {1}' , all_keys , protocols .qasm (self ._sub_operation , args = args ))
0 commit comments