Skip to content

Optimize ParamResolver.value_of #6341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions cirq-core/cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def steps(*args, **kwargs):

simulator.simulate_moment_steps.side_effect = steps
circuit = mock.Mock(cirq.Circuit)
param_resolver = mock.Mock(cirq.ParamResolver)
param_resolver.param_dict = {}
param_resolver = cirq.ParamResolver({})
qubit_order = mock.Mock(cirq.QubitOrder)
result = simulator.simulate(
program=circuit, param_resolver=param_resolver, qubit_order=qubit_order, initial_state=2
Expand Down Expand Up @@ -163,9 +162,7 @@ def steps(*args, **kwargs):

simulator.simulate_moment_steps.side_effect = steps
circuit = mock.Mock(cirq.Circuit)
param_resolvers = [mock.Mock(cirq.ParamResolver), mock.Mock(cirq.ParamResolver)]
for resolver in param_resolvers:
resolver.param_dict = {}
param_resolvers = [cirq.ParamResolver({}), cirq.ParamResolver({})]
qubit_order = mock.Mock(cirq.QubitOrder)
results = simulator.simulate_sweep(
program=circuit, params=param_resolvers, qubit_order=qubit_order, initial_state=2
Expand Down
82 changes: 41 additions & 41 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@
ParamResolverOrSimilarType, """Something that can be used to turn parameters into values."""
)

# Used to mark values that are not found in a dict.
_NOT_FOUND = object()

# Used to mark values that are being resolved recursively to detect loops.
_RecursionFlag = object()
_RECURSION_FLAG = object()


def _is_param_resolver_or_similar_type(obj: Any):
Expand Down Expand Up @@ -72,7 +75,7 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None

self._param_hash: Optional[int] = None
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
for key in self.param_dict:
for key in self._param_dict:
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
self._deep_eval_map: ParamDictType = {}
Expand Down Expand Up @@ -120,32 +123,30 @@ def value_of(
if v is not NotImplemented:
return v

# Handles 2 cases:
# Input is a string and maps to a number in the dictionary
# Input is a symbol and maps to a number in the dictionary
# In both cases, return it directly.
if value in self.param_dict:
# Note: if the value is in the dictionary, it will be a key type
# Add a cast to make mypy happy.
param_value = self.param_dict[cast('cirq.TParamKey', value)]
# Handle string or symbol
if isinstance(value, (str, sympy.Symbol)):
string = value if isinstance(value, str) else value.name
symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value)
param_value = self._param_dict.get(string, _NOT_FOUND)
if param_value is _NOT_FOUND:
param_value = self._param_dict.get(symbol, _NOT_FOUND)
if param_value is _NOT_FOUND:
# Symbol or string cannot be resolved if not in param dict; return as symbol.
return symbol
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why we decided to always return a symbol here even if resolving a string. My preference would be to return the given value unchanged if the resolver doesn't need to change it, but callers could rely on this behavior so we should be careful if we want to change it.

v = _resolve_value(param_value)
if v is not NotImplemented:
return v
if isinstance(param_value, str):
param_value = sympy.Symbol(param_value)
elif not isinstance(param_value, sympy.Basic):
return value # type: ignore[return-value]
if recursive:
param_value = self._value_of_recursive(value)
return param_value # type: ignore[return-value]

# Input is a string and is not in the dictionary.
# Treat it as a symbol instead.
if isinstance(value, str):
# If the string is in the param_dict as a value, return it.
# Otherwise, try using the symbol instead.
return self.value_of(sympy.Symbol(value), recursive)

# Input is a symbol (sympy.Symbol('a')) and its string maps to a number
# in the dictionary ({'a': 1.0}). Return it.
if isinstance(value, sympy.Symbol) and value.name in self.param_dict:
param_value = self.param_dict[value.name]
v = _resolve_value(param_value)
if v is not NotImplemented:
return v
if not isinstance(value, sympy.Basic):
# No known way to resolve this variable, return unchanged.
return value

# The following resolves common sympy expressions
# If sympy did its job and wasn't slower than molasses,
Expand All @@ -171,10 +172,6 @@ def value_of(
return np.float_power(cast(complex, base), cast(complex, exponent))
return np.power(cast(complex, base), cast(complex, exponent))

if not isinstance(value, sympy.Basic):
# No known way to resolve this variable, return unchanged.
return value

# Input is either a sympy formula or the dictionary maps to a
# formula. Use sympy to resolve the value.
# Note that sympy.subs() is slow, so we want to avoid this and
Expand All @@ -186,7 +183,7 @@ def value_of(
# Note that a sympy.SympifyError here likely means
# that one of the expressions was not parsable by sympy
# (such as a function returning NotImplemented)
v = value.subs(self.param_dict, simultaneous=True)
v = value.subs(self._param_dict, simultaneous=True)

if v.free_symbols:
return v
Expand All @@ -197,23 +194,26 @@ def value_of(
else:
return float(v)

return self._value_of_recursive(value)

def _value_of_recursive(self, value: 'cirq.TParamKey') -> 'cirq.TParamValComplex':
# Recursive parameter resolution. We can safely assume that value is a
# single symbol, since combinations are handled earlier in the method.
if value in self._deep_eval_map:
v = self._deep_eval_map[value]
if v is not _RecursionFlag:
return v
raise RecursionError('Evaluation of {value} indirectly contains itself.')
if v is _RECURSION_FLAG:
raise RecursionError('Evaluation of {value} indirectly contains itself.')
return v

# There isn't a full evaluation for 'value' yet. Until it's ready,
# map value to None to identify loops in component evaluation.
self._deep_eval_map[value] = _RecursionFlag # type: ignore
self._deep_eval_map[value] = _RECURSION_FLAG # type: ignore

v = self.value_of(value, recursive=False)
if v == value:
self._deep_eval_map[value] = v
else:
self._deep_eval_map[value] = self.value_of(v, recursive)
self._deep_eval_map[value] = self.value_of(v, recursive=True)
return self._deep_eval_map[value]

def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'ParamResolver':
Expand All @@ -224,17 +224,17 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P
new_dict.update(
{k: resolver.value_of(v, recursive) for k, v in new_dict.items()} # type: ignore[misc]
)
if recursive and self.param_dict:
if recursive and self._param_dict:
new_resolver = ParamResolver(cast(ParamDictType, new_dict))
# Resolve down to single-step mappings.
return ParamResolver()._resolve_parameters_(new_resolver, recursive=True)
return ParamResolver(cast(ParamDictType, new_dict))

def __iter__(self) -> Iterator[Union[str, sympy.Expr]]:
return iter(self.param_dict)
return iter(self._param_dict)

def __bool__(self) -> bool:
return bool(self.param_dict)
return bool(self._param_dict)

def __getitem__(
self, key: Union['cirq.TParamKey', 'cirq.TParamValComplex']
Expand All @@ -243,29 +243,29 @@ def __getitem__(

def __hash__(self) -> int:
if self._param_hash is None:
self._param_hash = hash(frozenset(self.param_dict.items()))
self._param_hash = hash(frozenset(self._param_dict.items()))
return self._param_hash

def __eq__(self, other):
if not isinstance(other, ParamResolver):
return NotImplemented
return self.param_dict == other.param_dict
return self._param_dict == other._param_dict

def __ne__(self, other):
return not self == other

def __repr__(self) -> str:
param_dict_repr = (
'{'
+ ', '.join([f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self.param_dict.items()])
+ ', '.join(f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self._param_dict.items())
+ '}'
)
return f'cirq.ParamResolver({param_dict_repr})'

def _json_dict_(self) -> Dict[str, Any]:
return {
# JSON requires mappings to have keys of basic types.
'param_dict': list(self.param_dict.items())
'param_dict': list(self._param_dict.items())
}

@classmethod
Expand Down
25 changes: 13 additions & 12 deletions cirq-core/cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def test_value_of_transformed_types(val, resolved):

@pytest.mark.parametrize('val,resolved', [(sympy.I, 1j)])
def test_value_of_substituted_types(val, resolved):
_assert_consistent_resolution(val, resolved, True)
_assert_consistent_resolution(val, resolved)


def _assert_consistent_resolution(v, resolved, subs_called=False):
def _assert_consistent_resolution(v, resolved):
"""Asserts that parameter resolution works consistently.

The ParamResolver.value_of method can resolve any Sympy expression -
Expand All @@ -70,7 +70,7 @@ def _assert_consistent_resolution(v, resolved, subs_called=False):
Args:
v: the value to resolve
resolved: the expected resolution result
subs_called: if True, it is expected that the slow subs method is called

Raises:
AssertionError in case resolution assertion fail.
"""
Expand All @@ -93,9 +93,7 @@ def subs(self, *args, **kwargs):
# symbol based resolution
s = SubsAwareSymbol('a')
assert r.value_of(s) == resolved, f"expected {resolved}, got {r.value_of(s)}"
assert (
subs_called == s.called
), f"For pass-through type {type(v)} sympy.subs shouldn't have been called."
assert not s.called, f"For pass-through type {type(v)} sympy.subs shouldn't have been called."
assert isinstance(
r.value_of(s), type(resolved)
), f"expected {type(resolved)} got {type(r.value_of(s))}"
Expand Down Expand Up @@ -243,15 +241,18 @@ def _resolved_value_(self):


def test_custom_value_not_implemented():
class Bar:
class BarImplicit:
pass

class BarExplicit:
def _resolved_value_(self):
return NotImplemented

b = sympy.Symbol('b')
bar = Bar()
r = cirq.ParamResolver({b: bar})
with pytest.raises(sympy.SympifyError):
_ = r.value_of(b)
for cls in [BarImplicit, BarExplicit]:
b = sympy.Symbol('b')
bar = cls()
r = cirq.ParamResolver({b: bar})
assert r.value_of(b) == b


def test_compose():
Expand Down
2 changes: 1 addition & 1 deletion cirq-ionq/cirq_ionq/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_sampler_multiple_jobs():
results = sampler.sample(
program=circuit,
repetitions=4,
params=[cirq.ParamResolver({x: '0.5'}), cirq.ParamResolver({x: '0.6'})],
params=[cirq.ParamResolver({x: 0.5}), cirq.ParamResolver({x: 0.6})],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was coincidentally working before because the value_of code was calling float on the result after looking up in the param dict. But I think that behavior was unintentional because otherwise there would be no way to distinguish between strings used as symbols and strings containing floats. I can't think of any other place where we allow floats to be specified as strings, outside of serialization, and it certainly wasn't the intent to do so for parameter resolution.

)
pd.testing.assert_frame_equal(
results,
Expand Down