36
36
ParamResolverOrSimilarType , """Something that can be used to turn parameters into values."""
37
37
)
38
38
39
+ # Used to mark values that are not found in a dict.
40
+ _NOT_FOUND = object ()
41
+
39
42
# Used to mark values that are being resolved recursively to detect loops.
40
- _RecursionFlag = object ()
43
+ _RECURSION_FLAG = object ()
41
44
42
45
43
46
def _is_param_resolver_or_similar_type (obj : Any ):
@@ -72,7 +75,7 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None
72
75
73
76
self ._param_hash : Optional [int ] = None
74
77
self ._param_dict = cast (ParamDictType , {} if param_dict is None else param_dict )
75
- for key in self .param_dict :
78
+ for key in self ._param_dict :
76
79
if isinstance (key , sympy .Expr ) and not isinstance (key , sympy .Symbol ):
77
80
raise TypeError (f'ParamResolver keys cannot be (non-symbol) formulas ({ key } )' )
78
81
self ._deep_eval_map : ParamDictType = {}
@@ -120,32 +123,30 @@ def value_of(
120
123
if v is not NotImplemented :
121
124
return v
122
125
123
- # Handles 2 cases:
124
- # Input is a string and maps to a number in the dictionary
125
- # Input is a symbol and maps to a number in the dictionary
126
- # In both cases, return it directly.
127
- if value in self .param_dict :
128
- # Note: if the value is in the dictionary, it will be a key type
129
- # Add a cast to make mypy happy.
130
- param_value = self .param_dict [cast ('cirq.TParamKey' , value )]
126
+ # Handle string or symbol
127
+ if isinstance (value , (str , sympy .Symbol )):
128
+ string = value if isinstance (value , str ) else value .name
129
+ symbol = value if isinstance (value , sympy .Symbol ) else sympy .Symbol (value )
130
+ param_value = self ._param_dict .get (string , _NOT_FOUND )
131
+ if param_value is _NOT_FOUND :
132
+ param_value = self ._param_dict .get (symbol , _NOT_FOUND )
133
+ if param_value is _NOT_FOUND :
134
+ # Symbol or string cannot be resolved if not in param dict; return as symbol.
135
+ return symbol
131
136
v = _resolve_value (param_value )
132
137
if v is not NotImplemented :
133
138
return v
139
+ if isinstance (param_value , str ):
140
+ param_value = sympy .Symbol (param_value )
141
+ elif not isinstance (param_value , sympy .Basic ):
142
+ return value # type: ignore[return-value]
143
+ if recursive :
144
+ param_value = self ._value_of_recursive (value )
145
+ return param_value # type: ignore[return-value]
134
146
135
- # Input is a string and is not in the dictionary.
136
- # Treat it as a symbol instead.
137
- if isinstance (value , str ):
138
- # If the string is in the param_dict as a value, return it.
139
- # Otherwise, try using the symbol instead.
140
- return self .value_of (sympy .Symbol (value ), recursive )
141
-
142
- # Input is a symbol (sympy.Symbol('a')) and its string maps to a number
143
- # in the dictionary ({'a': 1.0}). Return it.
144
- if isinstance (value , sympy .Symbol ) and value .name in self .param_dict :
145
- param_value = self .param_dict [value .name ]
146
- v = _resolve_value (param_value )
147
- if v is not NotImplemented :
148
- return v
147
+ if not isinstance (value , sympy .Basic ):
148
+ # No known way to resolve this variable, return unchanged.
149
+ return value
149
150
150
151
# The following resolves common sympy expressions
151
152
# If sympy did its job and wasn't slower than molasses,
@@ -171,10 +172,6 @@ def value_of(
171
172
return np .float_power (cast (complex , base ), cast (complex , exponent ))
172
173
return np .power (cast (complex , base ), cast (complex , exponent ))
173
174
174
- if not isinstance (value , sympy .Basic ):
175
- # No known way to resolve this variable, return unchanged.
176
- return value
177
-
178
175
# Input is either a sympy formula or the dictionary maps to a
179
176
# formula. Use sympy to resolve the value.
180
177
# Note that sympy.subs() is slow, so we want to avoid this and
@@ -186,7 +183,7 @@ def value_of(
186
183
# Note that a sympy.SympifyError here likely means
187
184
# that one of the expressions was not parsable by sympy
188
185
# (such as a function returning NotImplemented)
189
- v = value .subs (self .param_dict , simultaneous = True )
186
+ v = value .subs (self ._param_dict , simultaneous = True )
190
187
191
188
if v .free_symbols :
192
189
return v
@@ -197,23 +194,26 @@ def value_of(
197
194
else :
198
195
return float (v )
199
196
197
+ return self ._value_of_recursive (value )
198
+
199
+ def _value_of_recursive (self , value : 'cirq.TParamKey' ) -> 'cirq.TParamValComplex' :
200
200
# Recursive parameter resolution. We can safely assume that value is a
201
201
# single symbol, since combinations are handled earlier in the method.
202
202
if value in self ._deep_eval_map :
203
203
v = self ._deep_eval_map [value ]
204
- if v is not _RecursionFlag :
205
- return v
206
- raise RecursionError ( 'Evaluation of {value} indirectly contains itself.' )
204
+ if v is _RECURSION_FLAG :
205
+ raise RecursionError ( 'Evaluation of {value} indirectly contains itself.' )
206
+ return v
207
207
208
208
# There isn't a full evaluation for 'value' yet. Until it's ready,
209
209
# map value to None to identify loops in component evaluation.
210
- self ._deep_eval_map [value ] = _RecursionFlag # type: ignore
210
+ self ._deep_eval_map [value ] = _RECURSION_FLAG # type: ignore
211
211
212
212
v = self .value_of (value , recursive = False )
213
213
if v == value :
214
214
self ._deep_eval_map [value ] = v
215
215
else :
216
- self ._deep_eval_map [value ] = self .value_of (v , recursive )
216
+ self ._deep_eval_map [value ] = self .value_of (v , recursive = True )
217
217
return self ._deep_eval_map [value ]
218
218
219
219
def _resolve_parameters_ (self , resolver : 'ParamResolver' , recursive : bool ) -> 'ParamResolver' :
@@ -224,17 +224,17 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P
224
224
new_dict .update (
225
225
{k : resolver .value_of (v , recursive ) for k , v in new_dict .items ()} # type: ignore[misc]
226
226
)
227
- if recursive and self .param_dict :
227
+ if recursive and self ._param_dict :
228
228
new_resolver = ParamResolver (cast (ParamDictType , new_dict ))
229
229
# Resolve down to single-step mappings.
230
230
return ParamResolver ()._resolve_parameters_ (new_resolver , recursive = True )
231
231
return ParamResolver (cast (ParamDictType , new_dict ))
232
232
233
233
def __iter__ (self ) -> Iterator [Union [str , sympy .Expr ]]:
234
- return iter (self .param_dict )
234
+ return iter (self ._param_dict )
235
235
236
236
def __bool__ (self ) -> bool :
237
- return bool (self .param_dict )
237
+ return bool (self ._param_dict )
238
238
239
239
def __getitem__ (
240
240
self , key : Union ['cirq.TParamKey' , 'cirq.TParamValComplex' ]
@@ -243,29 +243,29 @@ def __getitem__(
243
243
244
244
def __hash__ (self ) -> int :
245
245
if self ._param_hash is None :
246
- self ._param_hash = hash (frozenset (self .param_dict .items ()))
246
+ self ._param_hash = hash (frozenset (self ._param_dict .items ()))
247
247
return self ._param_hash
248
248
249
249
def __eq__ (self , other ):
250
250
if not isinstance (other , ParamResolver ):
251
251
return NotImplemented
252
- return self .param_dict == other .param_dict
252
+ return self ._param_dict == other ._param_dict
253
253
254
254
def __ne__ (self , other ):
255
255
return not self == other
256
256
257
257
def __repr__ (self ) -> str :
258
258
param_dict_repr = (
259
259
'{'
260
- + ', ' .join ([ f'{ proper_repr (k )} : { proper_repr (v )} ' for k , v in self .param_dict .items ()] )
260
+ + ', ' .join (f'{ proper_repr (k )} : { proper_repr (v )} ' for k , v in self ._param_dict .items ())
261
261
+ '}'
262
262
)
263
263
return f'cirq.ParamResolver({ param_dict_repr } )'
264
264
265
265
def _json_dict_ (self ) -> Dict [str , Any ]:
266
266
return {
267
267
# JSON requires mappings to have keys of basic types.
268
- 'param_dict' : list (self .param_dict .items ())
268
+ 'param_dict' : list (self ._param_dict .items ())
269
269
}
270
270
271
271
@classmethod
0 commit comments