12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
"""An immutable version of the Circuit data structure."""
15
- from typing import AbstractSet , FrozenSet , Iterable , Iterator , Sequence , Tuple , TYPE_CHECKING , Union
15
+ from typing import (
16
+ AbstractSet ,
17
+ FrozenSet ,
18
+ Hashable ,
19
+ Iterable ,
20
+ Iterator ,
21
+ Sequence ,
22
+ Tuple ,
23
+ TYPE_CHECKING ,
24
+ Union ,
25
+ )
16
26
17
27
import numpy as np
18
28
@@ -34,7 +44,10 @@ class FrozenCircuit(AbstractCircuit, protocols.SerializableByKey):
34
44
"""
35
45
36
46
def __init__ (
37
- self , * contents : 'cirq.OP_TREE' , strategy : 'cirq.InsertStrategy' = InsertStrategy .EARLIEST
47
+ self ,
48
+ * contents : 'cirq.OP_TREE' ,
49
+ strategy : 'cirq.InsertStrategy' = InsertStrategy .EARLIEST ,
50
+ tags : Sequence [Hashable ] = (),
38
51
) -> None :
39
52
"""Initializes a frozen circuit.
40
53
@@ -47,9 +60,14 @@ def __init__(
47
60
strategy: When initializing the circuit with operations and moments
48
61
from `contents`, this determines how the operations are packed
49
62
together.
63
+ tags: A sequence of any type of object that is useful to attach metadata
64
+ to this circuit as long as the type is hashable. If you wish the
65
+ resulting circuit to be eventually serialized into JSON, you should
66
+ also restrict the tags to be JSON serializable.
50
67
"""
51
68
base = Circuit (contents , strategy = strategy )
52
69
self ._moments = tuple (base .moments )
70
+ self ._tags = tuple (tags )
53
71
54
72
@classmethod
55
73
def _from_moments (cls , moments : Iterable ['cirq.Moment' ]) -> 'FrozenCircuit' :
@@ -61,10 +79,35 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
61
79
def moments (self ) -> Sequence ['cirq.Moment' ]:
62
80
return self ._moments
63
81
82
+ @property
83
+ def tags (self ) -> Tuple [Hashable , ...]:
84
+ """Returns a tuple of the Circuit's tags."""
85
+ return self ._tags
86
+
87
+ @_compat .cached_property
88
+ def untagged (self ) -> 'cirq.FrozenCircuit' :
89
+ """Returns the underlying FrozenCircuit without any tags."""
90
+ return self ._from_moments (self ._moments ) if self .tags else self
91
+
92
+ def with_tags (self , * new_tags : Hashable ) -> 'cirq.FrozenCircuit' :
93
+ """Creates a new tagged `FrozenCircuit` with `self.tags` and `new_tags` combined."""
94
+ if not new_tags :
95
+ return self
96
+ new_circuit = FrozenCircuit (tags = self .tags + new_tags )
97
+ new_circuit ._moments = self ._moments
98
+ return new_circuit
99
+
64
100
@_compat .cached_method
65
101
def __hash__ (self ) -> int :
66
102
# Explicitly cached for performance
67
- return hash ((self .moments ,))
103
+ return hash ((self .moments , self .tags ))
104
+
105
+ def __eq__ (self , other ):
106
+ super_eq = super ().__eq__ (other )
107
+ if super_eq is not True :
108
+ return super_eq
109
+ other_tags = other .tags if isinstance (other , FrozenCircuit ) else ()
110
+ return self .tags == other_tags
68
111
69
112
def __getstate__ (self ):
70
113
# Don't save hash when pickling; see #3777.
@@ -130,11 +173,23 @@ def all_measurement_key_names(self) -> FrozenSet[str]:
130
173
131
174
@_compat .cached_method
132
175
def _is_parameterized_ (self ) -> bool :
133
- return super ()._is_parameterized_ ()
176
+ return super ()._is_parameterized_ () or any (
177
+ protocols .is_parameterized (tag ) for tag in self .tags
178
+ )
134
179
135
180
@_compat .cached_method
136
181
def _parameter_names_ (self ) -> AbstractSet [str ]:
137
- return super ()._parameter_names_ ()
182
+ tag_params = {name for tag in self .tags for name in protocols .parameter_names (tag )}
183
+ return super ()._parameter_names_ () | tag_params
184
+
185
+ def _resolve_parameters_ (
186
+ self , resolver : 'cirq.ParamResolver' , recursive : bool
187
+ ) -> 'cirq.FrozenCircuit' :
188
+ resolved_circuit = super ()._resolve_parameters_ (resolver , recursive )
189
+ resolved_tags = [
190
+ protocols .resolve_parameters (tag , resolver , recursive ) for tag in self .tags
191
+ ]
192
+ return resolved_circuit .with_tags (* resolved_tags )
138
193
139
194
def _measurement_key_names_ (self ) -> FrozenSet [str ]:
140
195
return self .all_measurement_key_names ()
@@ -161,6 +216,20 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
161
216
except :
162
217
return NotImplemented
163
218
219
+ def _repr_args (self ) -> str :
220
+ moments_repr = super ()._repr_args ()
221
+ tag_repr = ',' .join (_compat .proper_repr (t ) for t in self ._tags )
222
+ return f'{ moments_repr } , tags=[{ tag_repr } ]' if self .tags else moments_repr
223
+
224
+ def _json_dict_ (self ):
225
+ attribute_names = ['moments' , 'tags' ] if self .tags else ['moments' ]
226
+ ret = protocols .obj_to_dict_helper (self , attribute_names )
227
+ return ret
228
+
229
+ @classmethod
230
+ def _from_json_dict_ (cls , moments , * , tags = (), ** kwargs ):
231
+ return cls (moments , strategy = InsertStrategy .EARLIEST , tags = tags )
232
+
164
233
def concat_ragged (
165
234
* circuits : 'cirq.AbstractCircuit' , align : Union ['cirq.Alignment' , str ] = Alignment .LEFT
166
235
) -> 'cirq.FrozenCircuit' :
0 commit comments