diff --git a/cirq-core/cirq/study/resolver.py b/cirq-core/cirq/study/resolver.py index 3d1e2360411..c3c104d5d86 100644 --- a/cirq-core/cirq/study/resolver.py +++ b/cirq-core/cirq/study/resolver.py @@ -112,6 +112,7 @@ def value_of( Raises: RecursionError: If the ParamResolver detects a loop in recursive resolution. + sympy.SympifyError: If the resulting value cannot be interpreted. """ # Input is a pass through type, no resolution needed: return early @@ -179,7 +180,12 @@ def value_of( if not recursive: # Resolves one step at a time. For example: # a.subs({a: b, b: c}) == b + # + # Note that a sympy.SympifyError here likely means + # that one of the expressions was not parsable by sympy + # (such as a function returning NotImplemented) v = value.subs(self.param_dict, simultaneous=True) + if v.free_symbols: return v elif sympy.im(v): diff --git a/cirq-core/cirq/study/resolver_test.py b/cirq-core/cirq/study/resolver_test.py index bf06fd20a62..c2e1b8dde4b 100644 --- a/cirq-core/cirq/study/resolver_test.py +++ b/cirq-core/cirq/study/resolver_test.py @@ -227,25 +227,31 @@ class Foo: def _resolved_value_(self): return self - class Bar: - def _resolved_value_(self): - return NotImplemented - class Baz: def _resolved_value_(self): return 'Baz' foo = Foo() - bar = Bar() baz = Baz() a = sympy.Symbol('a') - b = sympy.Symbol('b') - c = sympy.Symbol('c') - r = cirq.ParamResolver({a: foo, b: bar, c: baz}) + b = sympy.Symbol('c') + r = cirq.ParamResolver({a: foo, b: baz}) assert r.value_of(a) is foo - assert r.value_of(b) is b - assert r.value_of(c) == 'Baz' + assert r.value_of(b) == 'Baz' + + +@pytest.mark.xfail(reason='this test requires sympy 1.12', strict=True) +def test_custom_value_not_implemented(): + class Bar: + def _resolved_value_(self): + return NotImplemented + + b = sympy.Symbol('b') + bar = Bar() + r = cirq.ParamResolver({b: bar}) + with pytest.raises(sympy.SympifyError): + _ = r.value_of(b) def test_compose():