-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Optimize ParamResolver.value_of #6341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
e3d618b
cfe0a50
c7848c1
9a4037c
e3a7d61
abf9532
b4b3acf
440806d
68631ed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,9 @@ | |
ParamResolverOrSimilarType, """Something that can be used to turn parameters into values.""" | ||
) | ||
|
||
# Used to mark values that are not found in a dict. | ||
_NotFound = object() | ||
|
||
# Used to mark values that are being resolved recursively to detect loops. | ||
_RecursionFlag = object() | ||
|
||
|
@@ -72,7 +75,7 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None | |
|
||
self._param_hash: Optional[int] = None | ||
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict) | ||
for key in self.param_dict: | ||
for key in self._param_dict: | ||
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol): | ||
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})') | ||
self._deep_eval_map: ParamDictType = {} | ||
|
@@ -120,32 +123,30 @@ def value_of( | |
if v is not NotImplemented: | ||
return v | ||
|
||
# Handles 2 cases: | ||
# Input is a string and maps to a number in the dictionary | ||
# Input is a symbol and maps to a number in the dictionary | ||
# In both cases, return it directly. | ||
if value in self.param_dict: | ||
# Note: if the value is in the dictionary, it will be a key type | ||
# Add a cast to make mypy happy. | ||
param_value = self.param_dict[cast('cirq.TParamKey', value)] | ||
# Handle string or symbol | ||
if isinstance(value, (str, sympy.Symbol)): | ||
string = value if isinstance(value, str) else value.name | ||
symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value) | ||
param_value = self._param_dict.get(string, _NotFound) | ||
if param_value is _NotFound: | ||
param_value = self._param_dict.get(symbol, _NotFound) | ||
if param_value is _NotFound: | ||
# Symbol or string cannot be resolved if not in param dict; return as symbol. | ||
return symbol | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure why we decided to always return a symbol here even if resolving a string. My preference would be to return the given value unchanged if the resolver doesn't need to change it, but callers could rely on this behavior so we should be careful if we want to change it. |
||
v = _resolve_value(param_value) | ||
if v is not NotImplemented: | ||
return v | ||
if isinstance(param_value, str): | ||
param_value = sympy.Symbol(param_value) | ||
elif not isinstance(param_value, sympy.Basic): | ||
return value # type: ignore[return-value] | ||
if recursive: | ||
param_value = self._value_of_recursive(value) | ||
return param_value # type: ignore[return-value] | ||
|
||
# Input is a string and is not in the dictionary. | ||
# Treat it as a symbol instead. | ||
if isinstance(value, str): | ||
# If the string is in the param_dict as a value, return it. | ||
# Otherwise, try using the symbol instead. | ||
return self.value_of(sympy.Symbol(value), recursive) | ||
|
||
# Input is a symbol (sympy.Symbol('a')) and its string maps to a number | ||
# in the dictionary ({'a': 1.0}). Return it. | ||
if isinstance(value, sympy.Symbol) and value.name in self.param_dict: | ||
param_value = self.param_dict[value.name] | ||
v = _resolve_value(param_value) | ||
if v is not NotImplemented: | ||
return v | ||
if not isinstance(value, sympy.Basic): | ||
# No known way to resolve this variable, return unchanged. | ||
return value | ||
|
||
# The following resolves common sympy expressions | ||
# If sympy did its job and wasn't slower than molasses, | ||
|
@@ -171,10 +172,6 @@ def value_of( | |
return np.float_power(cast(complex, base), cast(complex, exponent)) | ||
return np.power(cast(complex, base), cast(complex, exponent)) | ||
|
||
if not isinstance(value, sympy.Basic): | ||
# No known way to resolve this variable, return unchanged. | ||
return value | ||
|
||
# Input is either a sympy formula or the dictionary maps to a | ||
# formula. Use sympy to resolve the value. | ||
# Note that sympy.subs() is slow, so we want to avoid this and | ||
|
@@ -186,7 +183,7 @@ def value_of( | |
# Note that a sympy.SympifyError here likely means | ||
# that one of the expressions was not parsable by sympy | ||
# (such as a function returning NotImplemented) | ||
v = value.subs(self.param_dict, simultaneous=True) | ||
v = value.subs(self._param_dict, simultaneous=True) | ||
|
||
if v.free_symbols: | ||
return v | ||
|
@@ -197,13 +194,16 @@ def value_of( | |
else: | ||
return float(v) | ||
|
||
return self._value_of_recursive(value) | ||
|
||
def _value_of_recursive(self, value: 'cirq.TParamKey') -> 'cirq.TParamValComplex': | ||
# Recursive parameter resolution. We can safely assume that value is a | ||
# single symbol, since combinations are handled earlier in the method. | ||
if value in self._deep_eval_map: | ||
v = self._deep_eval_map[value] | ||
if v is not _RecursionFlag: | ||
return v | ||
raise RecursionError('Evaluation of {value} indirectly contains itself.') | ||
if v is _RecursionFlag: | ||
raise RecursionError('Evaluation of {value} indirectly contains itself.') | ||
return v | ||
|
||
# There isn't a full evaluation for 'value' yet. Until it's ready, | ||
# map value to None to identify loops in component evaluation. | ||
|
@@ -213,7 +213,7 @@ def value_of( | |
if v == value: | ||
self._deep_eval_map[value] = v | ||
else: | ||
self._deep_eval_map[value] = self.value_of(v, recursive) | ||
self._deep_eval_map[value] = self.value_of(v, recursive=True) | ||
return self._deep_eval_map[value] | ||
|
||
def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'ParamResolver': | ||
|
@@ -224,17 +224,17 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P | |
new_dict.update( | ||
{k: resolver.value_of(v, recursive) for k, v in new_dict.items()} # type: ignore[misc] | ||
) | ||
if recursive and self.param_dict: | ||
if recursive and self._param_dict: | ||
new_resolver = ParamResolver(cast(ParamDictType, new_dict)) | ||
# Resolve down to single-step mappings. | ||
return ParamResolver()._resolve_parameters_(new_resolver, recursive=True) | ||
return ParamResolver(cast(ParamDictType, new_dict)) | ||
|
||
def __iter__(self) -> Iterator[Union[str, sympy.Expr]]: | ||
return iter(self.param_dict) | ||
return iter(self._param_dict) | ||
|
||
def __bool__(self) -> bool: | ||
return bool(self.param_dict) | ||
return bool(self._param_dict) | ||
|
||
def __getitem__( | ||
self, key: Union['cirq.TParamKey', 'cirq.TParamValComplex'] | ||
|
@@ -243,29 +243,29 @@ def __getitem__( | |
|
||
def __hash__(self) -> int: | ||
if self._param_hash is None: | ||
self._param_hash = hash(frozenset(self.param_dict.items())) | ||
self._param_hash = hash(frozenset(self._param_dict.items())) | ||
return self._param_hash | ||
|
||
def __eq__(self, other): | ||
if not isinstance(other, ParamResolver): | ||
return NotImplemented | ||
return self.param_dict == other.param_dict | ||
return self._param_dict == other._param_dict | ||
|
||
def __ne__(self, other): | ||
return not self == other | ||
|
||
def __repr__(self) -> str: | ||
param_dict_repr = ( | ||
'{' | ||
+ ', '.join([f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self.param_dict.items()]) | ||
+ ', '.join(f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self._param_dict.items()) | ||
+ '}' | ||
) | ||
return f'cirq.ParamResolver({param_dict_repr})' | ||
|
||
def _json_dict_(self) -> Dict[str, Any]: | ||
return { | ||
# JSON requires mappings to have keys of basic types. | ||
'param_dict': list(self.param_dict.items()) | ||
'param_dict': list(self._param_dict.items()) | ||
} | ||
|
||
@classmethod | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -100,7 +100,7 @@ def test_sampler_multiple_jobs(): | |
results = sampler.sample( | ||
program=circuit, | ||
repetitions=4, | ||
params=[cirq.ParamResolver({x: '0.5'}), cirq.ParamResolver({x: '0.6'})], | ||
params=[cirq.ParamResolver({x: 0.5}), cirq.ParamResolver({x: 0.6})], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was coincidentally working before because the |
||
) | ||
pd.testing.assert_frame_equal( | ||
results, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't global variables be named like
_NOT_FOUND
according to the style guide? Same with RecursionFlag below if you want to clean it up.go/pystyle#guidelines-derived-from-guidos-recommendations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to
_NOT_FOUND
and_RECURSION_FLAG
.