Skip to content

Commit 95bebae

Browse files
authored
Reject formulas as keys of ParamResolvers (#5384)
* Reject formulas as keys of ParamResolvers - A ParamResolver resolves variables into values. - Having non-trivial formulas as keys allows a significant complexity and ambiguity into ParamResolvers, since it is unclear how much is supported. Prevent this case altogether by raising an error if non-symbol formulas are used in ParamResolvers. Fixes: #3550
1 parent 1b7a800 commit 95bebae

File tree

4 files changed

+25
-24
lines changed

4 files changed

+25
-24
lines changed

cirq-core/cirq/study/resolver.py

+6
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class ParamResolver:
5555
Attributes:
5656
param_dict: A dictionary from the ParameterValue key (str) to its
5757
assigned value.
58+
59+
Raises:
60+
TypeError if formulas are passed as keys.
5861
"""
5962

6063
def __new__(cls, param_dict: 'cirq.ParamResolverOrSimilarType' = None):
@@ -68,6 +71,9 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None
6871

6972
self._param_hash: Optional[int] = None
7073
self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
74+
for key in self.param_dict:
75+
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
76+
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
7177
self._deep_eval_map: ParamDictType = {}
7278

7379
def value_of(

cirq-core/cirq/study/resolver_test.py

+3-17
Original file line numberDiff line numberDiff line change
@@ -156,27 +156,13 @@ def test_param_dict_iter():
156156

157157

158158
def test_formulas_in_param_dict():
159-
"""Test formulas in a `param_dict`.
160-
161-
Param dicts are allowed to have str or sympy.Symbol as keys and
162-
floats or sympy.Symbol as values. This should not be a common use case,
163-
but this tests makes sure something reasonable is returned when
164-
mixing these types and using formulas in ParamResolvers.
165-
166-
Note that sympy orders expressions for deterministic resolution, so
167-
depending on the operands sent to sub(), the expression may not fully
168-
resolve if it needs to take several iterations of resolution.
169-
"""
159+
"""Tests that formula keys are rejected in a `param_dict`."""
170160
a = sympy.Symbol('a')
171161
b = sympy.Symbol('b')
172162
c = sympy.Symbol('c')
173163
e = sympy.Symbol('e')
174-
r = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})
175-
assert sympy.Eq(r.value_of('a'), 3)
176-
assert sympy.Eq(r.value_of('b'), 2)
177-
assert sympy.Eq(r.value_of(b + c), 101)
178-
assert sympy.Eq(r.value_of('c'), c)
179-
assert sympy.Eq(r.value_of('d'), 2 * e)
164+
with pytest.raises(TypeError, match='formula'):
165+
_ = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})
180166

181167

182168
def test_recursive_evaluation():

cirq-google/cirq_google/api/v2/sweeps.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,9 @@ def sweep_to_proto(
6161
sweep_dict: Dict[str, List[float]] = {}
6262
for param_resolver in sweep:
6363
for key in param_resolver:
64-
if isinstance(key, sympy.Expr):
65-
raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}')
6664
if key not in sweep_dict:
67-
sweep_dict[key] = []
68-
sweep_dict[key].append(cast(float, param_resolver.value_of(key)))
65+
sweep_dict[cast(str, key)] = []
66+
sweep_dict[cast(str, key)].append(cast(float, param_resolver.value_of(key)))
6967
out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
7068
for key in sweep_dict:
7169
sweep_to_proto(cirq.Points(key, sweep_dict[key]), out=out.sweep_function.sweeps.add())

cirq-google/cirq_google/api/v2/sweeps_test.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,20 @@ def test_sweep_to_proto_linspace():
7373

7474

7575
def test_list_sweep_bad_expression():
76-
sweep = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a') + sympy.Symbol('b'): 4.0})])
77-
with pytest.raises(ValueError, match='cannot convert'):
78-
v2.sweep_to_proto(sweep)
76+
with pytest.raises(TypeError, match='formula'):
77+
_ = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a') + sympy.Symbol('b'): 4.0})])
78+
79+
80+
def test_symbol_to_string_conversion():
81+
sweep = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a'): 4.0})])
82+
proto = v2.sweep_to_proto(sweep)
83+
assert isinstance(proto, v2.run_context_pb2.Sweep)
84+
expected = v2.run_context_pb2.Sweep()
85+
expected.sweep_function.function_type = v2.run_context_pb2.SweepFunction.ZIP
86+
p1 = expected.sweep_function.sweeps.add()
87+
p1.single_sweep.parameter_key = 'a'
88+
p1.single_sweep.points.points.extend([4.0])
89+
assert proto == expected
7990

8091

8192
def test_sweep_to_proto_points():

0 commit comments

Comments
 (0)