Skip to content

Commit e092889

Browse files
authored
Reject formulas as keys of ParamResolvers (quantumlib#5384)
* Reject formulas as keys of ParamResolvers - A ParamResolver resolves variables into values. - Having non-trivial formulas as keys allows a significant complexity and ambiguity into ParamResolvers, since it is unclear how much is supported. Prevent this case altogether by raising an error if non-symbol formulas are used in ParamResolvers. Fixes: quantumlib#3550
1 parent b0deb26 commit e092889

File tree

2 files changed

+9
-17
lines changed

2 files changed

+9
-17
lines changed

cirq/study/resolver.py

+6
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class ParamResolver:
5555
Attributes:
5656
param_dict: A dictionary from the ParameterValue key (str) to its
5757
assigned value.
58+
59+
Raises:
60+
TypeError if formulas are passed as keys.
5861
"""
5962

6063
def __new__(cls, param_dict: 'cirq.ParamResolverOrSimilarType' = None):
@@ -68,6 +71,9 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None
6871

6972
self._param_hash: Optional[int] = None
7073
self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
74+
for key in self.param_dict:
75+
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
76+
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
7177
self._deep_eval_map: ParamDictType = {}
7278

7379
def value_of(

cirq/study/resolver_test.py

+3-17
Original file line numberDiff line numberDiff line change
@@ -156,27 +156,13 @@ def test_param_dict_iter():
156156

157157

158158
def test_formulas_in_param_dict():
159-
"""Test formulas in a `param_dict`.
160-
161-
Param dicts are allowed to have str or sympy.Symbol as keys and
162-
floats or sympy.Symbol as values. This should not be a common use case,
163-
but this tests makes sure something reasonable is returned when
164-
mixing these types and using formulas in ParamResolvers.
165-
166-
Note that sympy orders expressions for deterministic resolution, so
167-
depending on the operands sent to sub(), the expression may not fully
168-
resolve if it needs to take several iterations of resolution.
169-
"""
159+
"""Tests that formula keys are rejected in a `param_dict`."""
170160
a = sympy.Symbol('a')
171161
b = sympy.Symbol('b')
172162
c = sympy.Symbol('c')
173163
e = sympy.Symbol('e')
174-
r = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})
175-
assert sympy.Eq(r.value_of('a'), 3)
176-
assert sympy.Eq(r.value_of('b'), 2)
177-
assert sympy.Eq(r.value_of(b + c), 101)
178-
assert sympy.Eq(r.value_of('c'), c)
179-
assert sympy.Eq(r.value_of('d'), 2 * e)
164+
with pytest.raises(TypeError, match='formula'):
165+
_ = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})
180166

181167

182168
def test_recursive_evaluation():

0 commit comments

Comments
 (0)