Skip to content

Commit c45b2c5

Browse files
authored
Fix sympy error (#5930)
1 parent 19e7a42 commit c45b2c5

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

Diff for: cirq-core/cirq/study/resolver.py

+6
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def value_of(
112112
Raises:
113113
RecursionError: If the ParamResolver detects a loop in recursive
114114
resolution.
115+
sympy.SympifyError: If the resulting value cannot be interpreted.
115116
"""
116117

117118
# Input is a pass through type, no resolution needed: return early
@@ -179,7 +180,12 @@ def value_of(
179180
if not recursive:
180181
# Resolves one step at a time. For example:
181182
# a.subs({a: b, b: c}) == b
183+
#
184+
# Note that a sympy.SympifyError here likely means
185+
# that one of the expressions was not parsable by sympy
186+
# (such as a function returning NotImplemented)
182187
v = value.subs(self.param_dict, simultaneous=True)
188+
183189
if v.free_symbols:
184190
return v
185191
elif sympy.im(v):

Diff for: cirq-core/cirq/study/resolver_test.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -227,25 +227,31 @@ class Foo:
227227
def _resolved_value_(self):
228228
return self
229229

230-
class Bar:
231-
def _resolved_value_(self):
232-
return NotImplemented
233-
234230
class Baz:
235231
def _resolved_value_(self):
236232
return 'Baz'
237233

238234
foo = Foo()
239-
bar = Bar()
240235
baz = Baz()
241236

242237
a = sympy.Symbol('a')
243-
b = sympy.Symbol('b')
244-
c = sympy.Symbol('c')
245-
r = cirq.ParamResolver({a: foo, b: bar, c: baz})
238+
b = sympy.Symbol('c')
239+
r = cirq.ParamResolver({a: foo, b: baz})
246240
assert r.value_of(a) is foo
247-
assert r.value_of(b) is b
248-
assert r.value_of(c) == 'Baz'
241+
assert r.value_of(b) == 'Baz'
242+
243+
244+
@pytest.mark.xfail(reason='this test requires sympy 1.12', strict=True)
245+
def test_custom_value_not_implemented():
246+
class Bar:
247+
def _resolved_value_(self):
248+
return NotImplemented
249+
250+
b = sympy.Symbol('b')
251+
bar = Bar()
252+
r = cirq.ParamResolver({b: bar})
253+
with pytest.raises(sympy.SympifyError):
254+
_ = r.value_of(b)
249255

250256

251257
def test_compose():

0 commit comments

Comments
 (0)