Skip to content

Fix sympy error #5930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def value_of(
Raises:
RecursionError: If the ParamResolver detects a loop in recursive
resolution.
sympy.SympifyError: If the resulting value cannot be interpreted.
"""

# Input is a pass through type, no resolution needed: return early
Expand Down Expand Up @@ -179,7 +180,12 @@ def value_of(
if not recursive:
# Resolves one step at a time. For example:
# a.subs({a: b, b: c}) == b
#
# Note that a sympy.SympifyError here likely means
# that one of the expressions was not parsable by sympy
# (such as a function returning NotImplemented)
v = value.subs(self.param_dict, simultaneous=True)

if v.free_symbols:
return v
elif sympy.im(v):
Expand Down
26 changes: 16 additions & 10 deletions cirq-core/cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,25 +227,31 @@ class Foo:
def _resolved_value_(self):
return self

class Bar:
def _resolved_value_(self):
return NotImplemented

class Baz:
def _resolved_value_(self):
return 'Baz'

foo = Foo()
bar = Bar()
baz = Baz()

a = sympy.Symbol('a')
b = sympy.Symbol('b')
c = sympy.Symbol('c')
r = cirq.ParamResolver({a: foo, b: bar, c: baz})
b = sympy.Symbol('c')
r = cirq.ParamResolver({a: foo, b: baz})
assert r.value_of(a) is foo
assert r.value_of(b) is b
assert r.value_of(c) == 'Baz'
assert r.value_of(b) == 'Baz'


@pytest.mark.xfail(reason='this test requires sympy 1.12', strict=True)
def test_custom_value_not_implemented():
class Bar:
def _resolved_value_(self):
return NotImplemented

b = sympy.Symbol('b')
bar = Bar()
r = cirq.ParamResolver({b: bar})
with pytest.raises(sympy.SympifyError):
_ = r.value_of(b)


def test_compose():
Expand Down