Skip to content

Commit 61fefe6

Browse files
Lock down CircuitOperation and ParamResolver (#5548)
* Lock down CircuitOperation attributes. * Reduce attribute lockdown * Resolve type conflicts * review comments * docs and defensive copies * document error modes Co-authored-by: Cirq Bot <[email protected]>
1 parent dd4c0a2 commit 61fefe6

File tree

11 files changed

+254
-174
lines changed

11 files changed

+254
-174
lines changed

cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@
514514
Linspace,
515515
ListSweep,
516516
ParamDictType,
517+
ParamMappingType,
517518
ParamResolver,
518519
ParamResolverOrSimilarType,
519520
Points,

cirq-core/cirq/circuits/circuit_operation.py

+226-161
Large diffs are not rendered by default.

cirq-core/cirq/circuits/circuit_operation_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -994,8 +994,10 @@ def test_keys_under_parent_path():
994994
assert cirq.measurement_key_names(op1) == {'A'}
995995
op2 = op1.with_key_path(('B',))
996996
assert cirq.measurement_key_names(op2) == {'B:A'}
997-
op3 = op2.repeat(2)
998-
assert cirq.measurement_key_names(op3) == {'B:0:A', 'B:1:A'}
997+
op3 = cirq.with_key_path_prefix(op2, ('C',))
998+
assert cirq.measurement_key_names(op3) == {'C:B:A'}
999+
op4 = op3.repeat(2)
1000+
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}
9991001

10001002

10011003
def test_mapped_circuit_preserves_moments():

cirq-core/cirq/protocols/json_test_data/spec.py

+1
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@
185185
'TParamValComplex',
186186
'TRANSFORMER',
187187
'ParamDictType',
188+
'ParamMappingType',
188189
# utility:
189190
'CliffordSimulator',
190191
'NoiseModelFromNoiseProperties',

cirq-core/cirq/study/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
flatten_with_sweep,
2222
)
2323

24-
from cirq.study.resolver import ParamDictType, ParamResolver, ParamResolverOrSimilarType
24+
from cirq.study.resolver import (
25+
ParamDictType,
26+
ParamMappingType,
27+
ParamResolver,
28+
ParamResolverOrSimilarType,
29+
)
2530

2631
from cirq.study.sweepable import Sweepable, to_resolvers, to_sweep, to_sweeps
2732

cirq-core/cirq/study/flatten_expressions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def value_of(
278278
return out
279279
# Create a new symbol
280280
symbol = self._next_symbol(value)
281-
self.param_dict[value] = symbol
281+
self._param_dict[value] = symbol
282282
self._taken_symbols.add(symbol)
283283
return symbol
284284

@@ -292,9 +292,9 @@ def __bool__(self) -> bool:
292292

293293
def __repr__(self) -> str:
294294
if self.get_param_name == self.default_get_param_name:
295-
return f'_ParamFlattener({self.param_dict!r})'
295+
return f'_ParamFlattener({self._param_dict!r})'
296296
else:
297-
return f'_ParamFlattener({self.param_dict!r}, get_param_name={self.get_param_name!r})'
297+
return f'_ParamFlattener({self._param_dict!r}, get_param_name={self.get_param_name!r})'
298298

299299
def flatten(self, val: Any) -> Any:
300300
"""Returns a copy of `val` with any symbols or expressions replaced with

cirq-core/cirq/study/resolver.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Resolves ParameterValues to assigned values."""
1616
import numbers
17-
from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Union, cast
17+
from typing import Any, Dict, Iterator, Mapping, Optional, TYPE_CHECKING, Union, cast
1818

1919
import numpy as np
2020
import sympy
@@ -27,9 +27,11 @@
2727

2828

2929
ParamDictType = Dict['cirq.TParamKey', 'cirq.TParamValComplex']
30+
ParamMappingType = Mapping['cirq.TParamKey', 'cirq.TParamValComplex']
3031
document(ParamDictType, """Dictionary from symbols to values.""") # type: ignore
32+
document(ParamMappingType, """Immutable map from symbols to values.""") # type: ignore
3133

32-
ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamDictType, None]
34+
ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamMappingType, None]
3335
document(
3436
ParamResolverOrSimilarType, # type: ignore
3537
"""Something that can be used to turn parameters into values.""",
@@ -70,12 +72,16 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None
7072
return # Already initialized. Got wrapped as part of the __new__.
7173

7274
self._param_hash: Optional[int] = None
73-
self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
75+
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
7476
for key in self.param_dict:
7577
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
7678
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
7779
self._deep_eval_map: ParamDictType = {}
7880

81+
@property
82+
def param_dict(self) -> ParamMappingType:
83+
return self._param_dict
84+
7985
def value_of(
8086
self, value: Union['cirq.TParamKey', 'cirq.TParamValComplex'], recursive: bool = True
8187
) -> 'cirq.TParamValComplex':

cirq-core/cirq/work/observable_measurement.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def measure_grouped_settings(
531531
for max_setting, param_resolver in itertools.product(
532532
grouped_settings.keys(), study.to_resolvers(circuit_sweep)
533533
):
534-
circuit_params = param_resolver.param_dict
534+
circuit_params = dict(param_resolver.param_dict)
535535
meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
536536
accumulator = BitstringAccumulator(
537537
meas_spec=meas_spec,

cirq-core/cirq/work/sampler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def sample_expectation_values(
353353
# Flatten Circuit Sweep into one big list of Params.
354354
# Keep track of their indices so we can map back.
355355
flat_params: List['cirq.ParamDictType'] = [
356-
pr.param_dict for pr in study.to_resolvers(params)
356+
dict(pr.param_dict) for pr in study.to_resolvers(params)
357357
]
358358
circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]], int] = {
359359
_hashable_param(param.items()): i for i, param in enumerate(flat_params)

cirq-rigetti/cirq_rigetti/circuit_sweep_executors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _get_param_dict(resolver: cirq.ParamResolverOrSimilarType) -> Dict[Union[str
9797
"""
9898
param_dict: Dict[Union[str, sympy.Expr], Any] = {}
9999
if isinstance(resolver, cirq.ParamResolver):
100-
param_dict = resolver.param_dict
100+
param_dict = dict(resolver.param_dict)
101101
elif isinstance(resolver, dict):
102102
param_dict = resolver
103103
return param_dict

cirq-rigetti/cirq_rigetti/circuit_sweep_executors_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_with_quilc_parametric_compilation(
6060

6161
param_resolvers: List[Union[cirq.ParamResolver, cirq.ParamDictType]]
6262
if pass_dict:
63-
param_resolvers = [params.param_dict for params in sweepable]
63+
param_resolvers = [dict(params.param_dict) for params in sweepable]
6464
else:
6565
param_resolvers = [r for r in cirq.to_resolvers(sweepable)]
6666
expected_results = [

0 commit comments

Comments
 (0)