Skip to content

Commit 8d07cab

Browse files
authored
Optimize ParamResolver.value_of (#6341)
Review: @dstrain115
1 parent 392083b commit 8d07cab

File tree

4 files changed

+57
-59
lines changed

4 files changed

+57
-59
lines changed

Diff for: cirq-core/cirq/sim/simulator_test.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ def steps(*args, **kwargs):
134134

135135
simulator.simulate_moment_steps.side_effect = steps
136136
circuit = mock.Mock(cirq.Circuit)
137-
param_resolver = mock.Mock(cirq.ParamResolver)
138-
param_resolver.param_dict = {}
137+
param_resolver = cirq.ParamResolver({})
139138
qubit_order = mock.Mock(cirq.QubitOrder)
140139
result = simulator.simulate(
141140
program=circuit, param_resolver=param_resolver, qubit_order=qubit_order, initial_state=2
@@ -163,9 +162,7 @@ def steps(*args, **kwargs):
163162

164163
simulator.simulate_moment_steps.side_effect = steps
165164
circuit = mock.Mock(cirq.Circuit)
166-
param_resolvers = [mock.Mock(cirq.ParamResolver), mock.Mock(cirq.ParamResolver)]
167-
for resolver in param_resolvers:
168-
resolver.param_dict = {}
165+
param_resolvers = [cirq.ParamResolver({}), cirq.ParamResolver({})]
169166
qubit_order = mock.Mock(cirq.QubitOrder)
170167
results = simulator.simulate_sweep(
171168
program=circuit, params=param_resolvers, qubit_order=qubit_order, initial_state=2

Diff for: cirq-core/cirq/study/resolver.py

+41-41
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@
3636
ParamResolverOrSimilarType, """Something that can be used to turn parameters into values."""
3737
)
3838

39+
# Used to mark values that are not found in a dict.
40+
_NOT_FOUND = object()
41+
3942
# Used to mark values that are being resolved recursively to detect loops.
40-
_RecursionFlag = object()
43+
_RECURSION_FLAG = object()
4144

4245

4346
def _is_param_resolver_or_similar_type(obj: Any):
@@ -72,7 +75,7 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None
7275

7376
self._param_hash: Optional[int] = None
7477
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
75-
for key in self.param_dict:
78+
for key in self._param_dict:
7679
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
7780
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
7881
self._deep_eval_map: ParamDictType = {}
@@ -120,32 +123,30 @@ def value_of(
120123
if v is not NotImplemented:
121124
return v
122125

123-
# Handles 2 cases:
124-
# Input is a string and maps to a number in the dictionary
125-
# Input is a symbol and maps to a number in the dictionary
126-
# In both cases, return it directly.
127-
if value in self.param_dict:
128-
# Note: if the value is in the dictionary, it will be a key type
129-
# Add a cast to make mypy happy.
130-
param_value = self.param_dict[cast('cirq.TParamKey', value)]
126+
# Handle string or symbol
127+
if isinstance(value, (str, sympy.Symbol)):
128+
string = value if isinstance(value, str) else value.name
129+
symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value)
130+
param_value = self._param_dict.get(string, _NOT_FOUND)
131+
if param_value is _NOT_FOUND:
132+
param_value = self._param_dict.get(symbol, _NOT_FOUND)
133+
if param_value is _NOT_FOUND:
134+
# Symbol or string cannot be resolved if not in param dict; return as symbol.
135+
return symbol
131136
v = _resolve_value(param_value)
132137
if v is not NotImplemented:
133138
return v
139+
if isinstance(param_value, str):
140+
param_value = sympy.Symbol(param_value)
141+
elif not isinstance(param_value, sympy.Basic):
142+
return value # type: ignore[return-value]
143+
if recursive:
144+
param_value = self._value_of_recursive(value)
145+
return param_value # type: ignore[return-value]
134146

135-
# Input is a string and is not in the dictionary.
136-
# Treat it as a symbol instead.
137-
if isinstance(value, str):
138-
# If the string is in the param_dict as a value, return it.
139-
# Otherwise, try using the symbol instead.
140-
return self.value_of(sympy.Symbol(value), recursive)
141-
142-
# Input is a symbol (sympy.Symbol('a')) and its string maps to a number
143-
# in the dictionary ({'a': 1.0}). Return it.
144-
if isinstance(value, sympy.Symbol) and value.name in self.param_dict:
145-
param_value = self.param_dict[value.name]
146-
v = _resolve_value(param_value)
147-
if v is not NotImplemented:
148-
return v
147+
if not isinstance(value, sympy.Basic):
148+
# No known way to resolve this variable, return unchanged.
149+
return value
149150

150151
# The following resolves common sympy expressions
151152
# If sympy did its job and wasn't slower than molasses,
@@ -171,10 +172,6 @@ def value_of(
171172
return np.float_power(cast(complex, base), cast(complex, exponent))
172173
return np.power(cast(complex, base), cast(complex, exponent))
173174

174-
if not isinstance(value, sympy.Basic):
175-
# No known way to resolve this variable, return unchanged.
176-
return value
177-
178175
# Input is either a sympy formula or the dictionary maps to a
179176
# formula. Use sympy to resolve the value.
180177
# Note that sympy.subs() is slow, so we want to avoid this and
@@ -186,7 +183,7 @@ def value_of(
186183
# Note that a sympy.SympifyError here likely means
187184
# that one of the expressions was not parsable by sympy
188185
# (such as a function returning NotImplemented)
189-
v = value.subs(self.param_dict, simultaneous=True)
186+
v = value.subs(self._param_dict, simultaneous=True)
190187

191188
if v.free_symbols:
192189
return v
@@ -197,23 +194,26 @@ def value_of(
197194
else:
198195
return float(v)
199196

197+
return self._value_of_recursive(value)
198+
199+
def _value_of_recursive(self, value: 'cirq.TParamKey') -> 'cirq.TParamValComplex':
200200
# Recursive parameter resolution. We can safely assume that value is a
201201
# single symbol, since combinations are handled earlier in the method.
202202
if value in self._deep_eval_map:
203203
v = self._deep_eval_map[value]
204-
if v is not _RecursionFlag:
205-
return v
206-
raise RecursionError('Evaluation of {value} indirectly contains itself.')
204+
if v is _RECURSION_FLAG:
205+
raise RecursionError('Evaluation of {value} indirectly contains itself.')
206+
return v
207207

208208
# There isn't a full evaluation for 'value' yet. Until it's ready,
209209
# map value to None to identify loops in component evaluation.
210-
self._deep_eval_map[value] = _RecursionFlag # type: ignore
210+
self._deep_eval_map[value] = _RECURSION_FLAG # type: ignore
211211

212212
v = self.value_of(value, recursive=False)
213213
if v == value:
214214
self._deep_eval_map[value] = v
215215
else:
216-
self._deep_eval_map[value] = self.value_of(v, recursive)
216+
self._deep_eval_map[value] = self.value_of(v, recursive=True)
217217
return self._deep_eval_map[value]
218218

219219
def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'ParamResolver':
@@ -224,17 +224,17 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P
224224
new_dict.update(
225225
{k: resolver.value_of(v, recursive) for k, v in new_dict.items()} # type: ignore[misc]
226226
)
227-
if recursive and self.param_dict:
227+
if recursive and self._param_dict:
228228
new_resolver = ParamResolver(cast(ParamDictType, new_dict))
229229
# Resolve down to single-step mappings.
230230
return ParamResolver()._resolve_parameters_(new_resolver, recursive=True)
231231
return ParamResolver(cast(ParamDictType, new_dict))
232232

233233
def __iter__(self) -> Iterator[Union[str, sympy.Expr]]:
234-
return iter(self.param_dict)
234+
return iter(self._param_dict)
235235

236236
def __bool__(self) -> bool:
237-
return bool(self.param_dict)
237+
return bool(self._param_dict)
238238

239239
def __getitem__(
240240
self, key: Union['cirq.TParamKey', 'cirq.TParamValComplex']
@@ -243,29 +243,29 @@ def __getitem__(
243243

244244
def __hash__(self) -> int:
245245
if self._param_hash is None:
246-
self._param_hash = hash(frozenset(self.param_dict.items()))
246+
self._param_hash = hash(frozenset(self._param_dict.items()))
247247
return self._param_hash
248248

249249
def __eq__(self, other):
250250
if not isinstance(other, ParamResolver):
251251
return NotImplemented
252-
return self.param_dict == other.param_dict
252+
return self._param_dict == other._param_dict
253253

254254
def __ne__(self, other):
255255
return not self == other
256256

257257
def __repr__(self) -> str:
258258
param_dict_repr = (
259259
'{'
260-
+ ', '.join([f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self.param_dict.items()])
260+
+ ', '.join(f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self._param_dict.items())
261261
+ '}'
262262
)
263263
return f'cirq.ParamResolver({param_dict_repr})'
264264

265265
def _json_dict_(self) -> Dict[str, Any]:
266266
return {
267267
# JSON requires mappings to have keys of basic types.
268-
'param_dict': list(self.param_dict.items())
268+
'param_dict': list(self._param_dict.items())
269269
}
270270

271271
@classmethod

Diff for: cirq-core/cirq/study/resolver_test.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ def test_value_of_transformed_types(val, resolved):
5353

5454
@pytest.mark.parametrize('val,resolved', [(sympy.I, 1j)])
5555
def test_value_of_substituted_types(val, resolved):
56-
_assert_consistent_resolution(val, resolved, True)
56+
_assert_consistent_resolution(val, resolved)
5757

5858

59-
def _assert_consistent_resolution(v, resolved, subs_called=False):
59+
def _assert_consistent_resolution(v, resolved):
6060
"""Asserts that parameter resolution works consistently.
6161
6262
The ParamResolver.value_of method can resolve any Sympy expression -
@@ -70,7 +70,7 @@ def _assert_consistent_resolution(v, resolved, subs_called=False):
7070
Args:
7171
v: the value to resolve
7272
resolved: the expected resolution result
73-
subs_called: if True, it is expected that the slow subs method is called
73+
7474
Raises:
7575
AssertionError in case resolution assertion fail.
7676
"""
@@ -93,9 +93,7 @@ def subs(self, *args, **kwargs):
9393
# symbol based resolution
9494
s = SubsAwareSymbol('a')
9595
assert r.value_of(s) == resolved, f"expected {resolved}, got {r.value_of(s)}"
96-
assert (
97-
subs_called == s.called
98-
), f"For pass-through type {type(v)} sympy.subs shouldn't have been called."
96+
assert not s.called, f"For pass-through type {type(v)} sympy.subs shouldn't have been called."
9997
assert isinstance(
10098
r.value_of(s), type(resolved)
10199
), f"expected {type(resolved)} got {type(r.value_of(s))}"
@@ -243,15 +241,18 @@ def _resolved_value_(self):
243241

244242

245243
def test_custom_value_not_implemented():
246-
class Bar:
244+
class BarImplicit:
245+
pass
246+
247+
class BarExplicit:
247248
def _resolved_value_(self):
248249
return NotImplemented
249250

250-
b = sympy.Symbol('b')
251-
bar = Bar()
252-
r = cirq.ParamResolver({b: bar})
253-
with pytest.raises(sympy.SympifyError):
254-
_ = r.value_of(b)
251+
for cls in [BarImplicit, BarExplicit]:
252+
b = sympy.Symbol('b')
253+
bar = cls()
254+
r = cirq.ParamResolver({b: bar})
255+
assert r.value_of(b) == b
255256

256257

257258
def test_compose():

Diff for: cirq-ionq/cirq_ionq/sampler_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_sampler_multiple_jobs():
100100
results = sampler.sample(
101101
program=circuit,
102102
repetitions=4,
103-
params=[cirq.ParamResolver({x: '0.5'}), cirq.ParamResolver({x: '0.6'})],
103+
params=[cirq.ParamResolver({x: 0.5}), cirq.ParamResolver({x: 0.6})],
104104
)
105105
pd.testing.assert_frame_equal(
106106
results,

0 commit comments

Comments
 (0)