-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Optimize ParamResolver.value_of #6341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
e3d618b
Optimize ParamResolver.value_of
maffoo cfe0a50
fmt, lint, types
maffoo c7848c1
fmt
maffoo 9a4037c
Fix simulator tests
maffoo e3a7d61
Fix ionq test
maffoo abf9532
fix
maffoo b4b3acf
Handle resolving to string
maffoo 440806d
Merge branch 'master' into u/maffoo/resolver
maffoo 68631ed
Fixes from review
maffoo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,9 @@ | |
ParamResolverOrSimilarType, """Something that can be used to turn parameters into values.""" | ||
) | ||
|
||
# Used to mark values that are not found in a dict. | ||
_NotFound = object() | ||
|
||
# Used to mark values that are being resolved recursively to detect loops. | ||
_RecursionFlag = object() | ||
|
||
|
@@ -72,7 +75,7 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None | |
|
||
self._param_hash: Optional[int] = None | ||
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict) | ||
for key in self.param_dict: | ||
for key in self._param_dict: | ||
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol): | ||
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})') | ||
self._deep_eval_map: ParamDictType = {} | ||
|
@@ -120,32 +123,28 @@ def value_of( | |
if v is not NotImplemented: | ||
return v | ||
|
||
# Handles 2 cases: | ||
# Input is a string and maps to a number in the dictionary | ||
# Input is a symbol and maps to a number in the dictionary | ||
# In both cases, return it directly. | ||
if value in self.param_dict: | ||
# Note: if the value is in the dictionary, it will be a key type | ||
# Add a cast to make mypy happy. | ||
param_value = self.param_dict[cast('cirq.TParamKey', value)] | ||
# Handle string or symbol | ||
if isinstance(value, (str, sympy.Symbol)): | ||
string = value if isinstance(value, str) else value.name | ||
symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value) | ||
param_value = self._param_dict.get(string, _NotFound) | ||
if param_value is _NotFound: | ||
param_value = self._param_dict.get(symbol, _NotFound) | ||
if param_value is _NotFound: | ||
# Symbol or string that is not in the param_dict cannot be resolved futher; return as symbol. | ||
return symbol | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure why we decided to always return a symbol here even if resolving a string. My preference would be to return the given value unchanged if the resolver doesn't need to change it, but callers could rely on this behavior so we should be careful if we want to change it. |
||
v = _resolve_value(param_value) | ||
if v is not NotImplemented: | ||
return v | ||
if not isinstance(param_value, sympy.Basic): | ||
return value | ||
if recursive: | ||
param_value = self._value_of_recursive(value) | ||
return param_value | ||
|
||
# Input is a string and is not in the dictionary. | ||
# Treat it as a symbol instead. | ||
if isinstance(value, str): | ||
# If the string is in the param_dict as a value, return it. | ||
# Otherwise, try using the symbol instead. | ||
return self.value_of(sympy.Symbol(value), recursive) | ||
|
||
# Input is a symbol (sympy.Symbol('a')) and its string maps to a number | ||
# in the dictionary ({'a': 1.0}). Return it. | ||
if isinstance(value, sympy.Symbol) and value.name in self.param_dict: | ||
param_value = self.param_dict[value.name] | ||
v = _resolve_value(param_value) | ||
if v is not NotImplemented: | ||
return v | ||
if not isinstance(value, sympy.Basic): | ||
# No known way to resolve this variable, return unchanged. | ||
return value | ||
|
||
# The following resolves common sympy expressions | ||
# If sympy did its job and wasn't slower than molasses, | ||
|
@@ -171,10 +170,6 @@ def value_of( | |
return np.float_power(cast(complex, base), cast(complex, exponent)) | ||
return np.power(cast(complex, base), cast(complex, exponent)) | ||
|
||
if not isinstance(value, sympy.Basic): | ||
# No known way to resolve this variable, return unchanged. | ||
return value | ||
|
||
# Input is either a sympy formula or the dictionary maps to a | ||
# formula. Use sympy to resolve the value. | ||
# Note that sympy.subs() is slow, so we want to avoid this and | ||
|
@@ -186,7 +181,7 @@ def value_of( | |
# Note that a sympy.SympifyError here likely means | ||
# that one of the expressions was not parsable by sympy | ||
# (such as a function returning NotImplemented) | ||
v = value.subs(self.param_dict, simultaneous=True) | ||
v = value.subs(self._param_dict, simultaneous=True) | ||
|
||
if v.free_symbols: | ||
return v | ||
|
@@ -197,13 +192,18 @@ def value_of( | |
else: | ||
return float(v) | ||
|
||
return self._value_of_recursive(value) | ||
|
||
def _value_of_recursive( | ||
self, value: Union['cirq.TParamKey', 'cirq.TParamValComplex'] | ||
) -> 'cirq.TParamValComplex': | ||
# Recursive parameter resolution. We can safely assume that value is a | ||
# single symbol, since combinations are handled earlier in the method. | ||
if value in self._deep_eval_map: | ||
v = self._deep_eval_map[value] | ||
if v is not _RecursionFlag: | ||
return v | ||
raise RecursionError('Evaluation of {value} indirectly contains itself.') | ||
if v is _RecursionFlag: | ||
raise RecursionError('Evaluation of {value} indirectly contains itself.') | ||
return v | ||
|
||
# There isn't a full evaluation for 'value' yet. Until it's ready, | ||
# map value to None to identify loops in component evaluation. | ||
|
@@ -213,7 +213,7 @@ def value_of( | |
if v == value: | ||
self._deep_eval_map[value] = v | ||
else: | ||
self._deep_eval_map[value] = self.value_of(v, recursive) | ||
self._deep_eval_map[value] = self.value_of(v, recursive=True) | ||
return self._deep_eval_map[value] | ||
|
||
def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'ParamResolver': | ||
|
@@ -224,17 +224,17 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P | |
new_dict.update( | ||
{k: resolver.value_of(v, recursive) for k, v in new_dict.items()} # type: ignore[misc] | ||
) | ||
if recursive and self.param_dict: | ||
if recursive and self._param_dict: | ||
new_resolver = ParamResolver(cast(ParamDictType, new_dict)) | ||
# Resolve down to single-step mappings. | ||
return ParamResolver()._resolve_parameters_(new_resolver, recursive=True) | ||
return ParamResolver(cast(ParamDictType, new_dict)) | ||
|
||
def __iter__(self) -> Iterator[Union[str, sympy.Expr]]: | ||
return iter(self.param_dict) | ||
return iter(self._param_dict) | ||
|
||
def __bool__(self) -> bool: | ||
return bool(self.param_dict) | ||
return bool(self._param_dict) | ||
|
||
def __getitem__( | ||
self, key: Union['cirq.TParamKey', 'cirq.TParamValComplex'] | ||
|
@@ -243,29 +243,31 @@ def __getitem__( | |
|
||
def __hash__(self) -> int: | ||
if self._param_hash is None: | ||
self._param_hash = hash(frozenset(self.param_dict.items())) | ||
self._param_hash = hash(frozenset(self._param_dict.items())) | ||
return self._param_hash | ||
|
||
def __eq__(self, other): | ||
if not isinstance(other, ParamResolver): | ||
return NotImplemented | ||
return self.param_dict == other.param_dict | ||
return self._param_dict == other._param_dict | ||
|
||
def __ne__(self, other): | ||
return not self == other | ||
|
||
def __repr__(self) -> str: | ||
param_dict_repr = ( | ||
'{' | ||
+ ', '.join([f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self.param_dict.items()]) | ||
+ ', '.join( | ||
[f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self._param_dict.items()] | ||
) | ||
+ '}' | ||
) | ||
return f'cirq.ParamResolver({param_dict_repr})' | ||
|
||
def _json_dict_(self) -> Dict[str, Any]: | ||
return { | ||
# JSON requires mappings to have keys of basic types. | ||
'param_dict': list(self.param_dict.items()) | ||
'param_dict': list(self._param_dict.items()) | ||
} | ||
|
||
@classmethod | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't global variables be named like
_NOT_FOUND
according to the style guide? Same with RecursionFlag below if you want to clean it up.go/pystyle#guidelines-derived-from-guidos-recommendations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to
_NOT_FOUND
and_RECURSION_FLAG
.