diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 1fd6eefdf16..3c4d6242d17 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -134,8 +134,7 @@ def steps(*args, **kwargs): simulator.simulate_moment_steps.side_effect = steps circuit = mock.Mock(cirq.Circuit) - param_resolver = mock.Mock(cirq.ParamResolver) - param_resolver.param_dict = {} + param_resolver = cirq.ParamResolver({}) qubit_order = mock.Mock(cirq.QubitOrder) result = simulator.simulate( program=circuit, param_resolver=param_resolver, qubit_order=qubit_order, initial_state=2 @@ -163,9 +162,7 @@ def steps(*args, **kwargs): simulator.simulate_moment_steps.side_effect = steps circuit = mock.Mock(cirq.Circuit) - param_resolvers = [mock.Mock(cirq.ParamResolver), mock.Mock(cirq.ParamResolver)] - for resolver in param_resolvers: - resolver.param_dict = {} + param_resolvers = [cirq.ParamResolver({}), cirq.ParamResolver({})] qubit_order = mock.Mock(cirq.QubitOrder) results = simulator.simulate_sweep( program=circuit, params=param_resolvers, qubit_order=qubit_order, initial_state=2 diff --git a/cirq-core/cirq/study/resolver.py b/cirq-core/cirq/study/resolver.py index e603caeda02..b66c31ac884 100644 --- a/cirq-core/cirq/study/resolver.py +++ b/cirq-core/cirq/study/resolver.py @@ -36,8 +36,11 @@ ParamResolverOrSimilarType, """Something that can be used to turn parameters into values.""" ) +# Used to mark values that are not found in a dict. +_NOT_FOUND = object() + # Used to mark values that are being resolved recursively to detect loops. -_RecursionFlag = object() +_RECURSION_FLAG = object() def _is_param_resolver_or_similar_type(obj: Any): @@ -72,7 +75,7 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None self._param_hash: Optional[int] = None self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict) - for key in self.param_dict: + for key in self._param_dict: if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol): raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})') self._deep_eval_map: ParamDictType = {} @@ -120,32 +123,30 @@ def value_of( if v is not NotImplemented: return v - # Handles 2 cases: - # Input is a string and maps to a number in the dictionary - # Input is a symbol and maps to a number in the dictionary - # In both cases, return it directly. - if value in self.param_dict: - # Note: if the value is in the dictionary, it will be a key type - # Add a cast to make mypy happy. - param_value = self.param_dict[cast('cirq.TParamKey', value)] + # Handle string or symbol + if isinstance(value, (str, sympy.Symbol)): + string = value if isinstance(value, str) else value.name + symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value) + param_value = self._param_dict.get(string, _NOT_FOUND) + if param_value is _NOT_FOUND: + param_value = self._param_dict.get(symbol, _NOT_FOUND) + if param_value is _NOT_FOUND: + # Symbol or string cannot be resolved if not in param dict; return as symbol. + return symbol v = _resolve_value(param_value) if v is not NotImplemented: return v + if isinstance(param_value, str): + param_value = sympy.Symbol(param_value) + elif not isinstance(param_value, sympy.Basic): + return value # type: ignore[return-value] + if recursive: + param_value = self._value_of_recursive(value) + return param_value # type: ignore[return-value] - # Input is a string and is not in the dictionary. - # Treat it as a symbol instead. - if isinstance(value, str): - # If the string is in the param_dict as a value, return it. - # Otherwise, try using the symbol instead. - return self.value_of(sympy.Symbol(value), recursive) - - # Input is a symbol (sympy.Symbol('a')) and its string maps to a number - # in the dictionary ({'a': 1.0}). Return it. - if isinstance(value, sympy.Symbol) and value.name in self.param_dict: - param_value = self.param_dict[value.name] - v = _resolve_value(param_value) - if v is not NotImplemented: - return v + if not isinstance(value, sympy.Basic): + # No known way to resolve this variable, return unchanged. + return value # The following resolves common sympy expressions # If sympy did its job and wasn't slower than molasses, @@ -171,10 +172,6 @@ def value_of( return np.float_power(cast(complex, base), cast(complex, exponent)) return np.power(cast(complex, base), cast(complex, exponent)) - if not isinstance(value, sympy.Basic): - # No known way to resolve this variable, return unchanged. - return value - # Input is either a sympy formula or the dictionary maps to a # formula. Use sympy to resolve the value. # Note that sympy.subs() is slow, so we want to avoid this and @@ -186,7 +183,7 @@ def value_of( # Note that a sympy.SympifyError here likely means # that one of the expressions was not parsable by sympy # (such as a function returning NotImplemented) - v = value.subs(self.param_dict, simultaneous=True) + v = value.subs(self._param_dict, simultaneous=True) if v.free_symbols: return v @@ -197,23 +194,26 @@ def value_of( else: return float(v) + return self._value_of_recursive(value) + + def _value_of_recursive(self, value: 'cirq.TParamKey') -> 'cirq.TParamValComplex': # Recursive parameter resolution. We can safely assume that value is a # single symbol, since combinations are handled earlier in the method. if value in self._deep_eval_map: v = self._deep_eval_map[value] - if v is not _RecursionFlag: - return v - raise RecursionError('Evaluation of {value} indirectly contains itself.') + if v is _RECURSION_FLAG: + raise RecursionError('Evaluation of {value} indirectly contains itself.') + return v # There isn't a full evaluation for 'value' yet. Until it's ready, # map value to None to identify loops in component evaluation. - self._deep_eval_map[value] = _RecursionFlag # type: ignore + self._deep_eval_map[value] = _RECURSION_FLAG # type: ignore v = self.value_of(value, recursive=False) if v == value: self._deep_eval_map[value] = v else: - self._deep_eval_map[value] = self.value_of(v, recursive) + self._deep_eval_map[value] = self.value_of(v, recursive=True) return self._deep_eval_map[value] def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'ParamResolver': @@ -224,17 +224,17 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P new_dict.update( {k: resolver.value_of(v, recursive) for k, v in new_dict.items()} # type: ignore[misc] ) - if recursive and self.param_dict: + if recursive and self._param_dict: new_resolver = ParamResolver(cast(ParamDictType, new_dict)) # Resolve down to single-step mappings. return ParamResolver()._resolve_parameters_(new_resolver, recursive=True) return ParamResolver(cast(ParamDictType, new_dict)) def __iter__(self) -> Iterator[Union[str, sympy.Expr]]: - return iter(self.param_dict) + return iter(self._param_dict) def __bool__(self) -> bool: - return bool(self.param_dict) + return bool(self._param_dict) def __getitem__( self, key: Union['cirq.TParamKey', 'cirq.TParamValComplex'] @@ -243,13 +243,13 @@ def __getitem__( def __hash__(self) -> int: if self._param_hash is None: - self._param_hash = hash(frozenset(self.param_dict.items())) + self._param_hash = hash(frozenset(self._param_dict.items())) return self._param_hash def __eq__(self, other): if not isinstance(other, ParamResolver): return NotImplemented - return self.param_dict == other.param_dict + return self._param_dict == other._param_dict def __ne__(self, other): return not self == other @@ -257,7 +257,7 @@ def __ne__(self, other): def __repr__(self) -> str: param_dict_repr = ( '{' - + ', '.join([f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self.param_dict.items()]) + + ', '.join(f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self._param_dict.items()) + '}' ) return f'cirq.ParamResolver({param_dict_repr})' @@ -265,7 +265,7 @@ def __repr__(self) -> str: def _json_dict_(self) -> Dict[str, Any]: return { # JSON requires mappings to have keys of basic types. - 'param_dict': list(self.param_dict.items()) + 'param_dict': list(self._param_dict.items()) } @classmethod diff --git a/cirq-core/cirq/study/resolver_test.py b/cirq-core/cirq/study/resolver_test.py index 627c8540cb9..3981906d73f 100644 --- a/cirq-core/cirq/study/resolver_test.py +++ b/cirq-core/cirq/study/resolver_test.py @@ -53,10 +53,10 @@ def test_value_of_transformed_types(val, resolved): @pytest.mark.parametrize('val,resolved', [(sympy.I, 1j)]) def test_value_of_substituted_types(val, resolved): - _assert_consistent_resolution(val, resolved, True) + _assert_consistent_resolution(val, resolved) -def _assert_consistent_resolution(v, resolved, subs_called=False): +def _assert_consistent_resolution(v, resolved): """Asserts that parameter resolution works consistently. The ParamResolver.value_of method can resolve any Sympy expression - @@ -70,7 +70,7 @@ def _assert_consistent_resolution(v, resolved, subs_called=False): Args: v: the value to resolve resolved: the expected resolution result - subs_called: if True, it is expected that the slow subs method is called + Raises: AssertionError in case resolution assertion fail. """ @@ -93,9 +93,7 @@ def subs(self, *args, **kwargs): # symbol based resolution s = SubsAwareSymbol('a') assert r.value_of(s) == resolved, f"expected {resolved}, got {r.value_of(s)}" - assert ( - subs_called == s.called - ), f"For pass-through type {type(v)} sympy.subs shouldn't have been called." + assert not s.called, f"For pass-through type {type(v)} sympy.subs shouldn't have been called." assert isinstance( r.value_of(s), type(resolved) ), f"expected {type(resolved)} got {type(r.value_of(s))}" @@ -243,15 +241,18 @@ def _resolved_value_(self): def test_custom_value_not_implemented(): - class Bar: + class BarImplicit: + pass + + class BarExplicit: def _resolved_value_(self): return NotImplemented - b = sympy.Symbol('b') - bar = Bar() - r = cirq.ParamResolver({b: bar}) - with pytest.raises(sympy.SympifyError): - _ = r.value_of(b) + for cls in [BarImplicit, BarExplicit]: + b = sympy.Symbol('b') + bar = cls() + r = cirq.ParamResolver({b: bar}) + assert r.value_of(b) == b def test_compose(): diff --git a/cirq-ionq/cirq_ionq/sampler_test.py b/cirq-ionq/cirq_ionq/sampler_test.py index 2bc0c69f23f..0aba2251013 100644 --- a/cirq-ionq/cirq_ionq/sampler_test.py +++ b/cirq-ionq/cirq_ionq/sampler_test.py @@ -100,7 +100,7 @@ def test_sampler_multiple_jobs(): results = sampler.sample( program=circuit, repetitions=4, - params=[cirq.ParamResolver({x: '0.5'}), cirq.ParamResolver({x: '0.6'})], + params=[cirq.ParamResolver({x: 0.5}), cirq.ParamResolver({x: 0.6})], ) pd.testing.assert_frame_equal( results,