Skip to content

Commit 041ce5d

Browse files
authored
Add tags to cirq.FrozenCircuit (#6266)
* Add tags to FrozenCircuit * Address comments and fix tests * Address maffoo's comments
1 parent 83609eb commit 041ce5d

File tree

3 files changed

+112
-10
lines changed

3 files changed

+112
-10
lines changed

cirq-core/cirq/circuits/circuit.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,15 @@ def __getitem__(self, key):
272272
def __str__(self) -> str:
273273
return self.to_text_diagram()
274274

275-
def __repr__(self) -> str:
276-
cls_name = self.__class__.__name__
275+
def _repr_args(self) -> str:
277276
args = []
278277
if self.moments:
279278
args.append(_list_repr_with_indented_item_lines(self.moments))
280-
return f'cirq.{cls_name}({", ".join(args)})'
279+
return f'{", ".join(args)}'
280+
281+
def __repr__(self) -> str:
282+
cls_name = self.__class__.__name__
283+
return f'cirq.{cls_name}({self._repr_args()})'
281284

282285
def _repr_pretty_(self, p: Any, cycle: bool) -> None:
283286
"""Print ASCII diagram in Jupyter."""
@@ -1791,7 +1794,6 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
17911794

17921795
# "mop" means current moment-or-operation
17931796
for mop in ops.flatten_to_ops_or_moments(contents):
1794-
17951797
# Identify the index of the moment to place this `mop` into.
17961798
placement_index = get_earliest_accommodating_moment_index(
17971799
mop, qubit_indices, mkey_indices, ckey_indices, length
@@ -2450,7 +2452,6 @@ def _draw_moment_annotations(
24502452
first_annotation_row: int,
24512453
transpose: bool,
24522454
):
2453-
24542455
for k, annotation in enumerate(_get_moment_annotations(moment)):
24552456
args = protocols.CircuitDiagramInfoArgs(
24562457
known_qubits=(),

cirq-core/cirq/circuits/frozen_circuit.py

+74-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""An immutable version of the Circuit data structure."""
15-
from typing import AbstractSet, FrozenSet, Iterable, Iterator, Sequence, Tuple, TYPE_CHECKING, Union
15+
from typing import (
16+
AbstractSet,
17+
FrozenSet,
18+
Hashable,
19+
Iterable,
20+
Iterator,
21+
Sequence,
22+
Tuple,
23+
TYPE_CHECKING,
24+
Union,
25+
)
1626

1727
import numpy as np
1828

@@ -34,7 +44,10 @@ class FrozenCircuit(AbstractCircuit, protocols.SerializableByKey):
3444
"""
3545

3646
def __init__(
37-
self, *contents: 'cirq.OP_TREE', strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST
47+
self,
48+
*contents: 'cirq.OP_TREE',
49+
strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST,
50+
tags: Sequence[Hashable] = (),
3851
) -> None:
3952
"""Initializes a frozen circuit.
4053
@@ -47,9 +60,14 @@ def __init__(
4760
strategy: When initializing the circuit with operations and moments
4861
from `contents`, this determines how the operations are packed
4962
together.
63+
tags: A sequence of any type of object that is useful to attach metadata
64+
to this circuit as long as the type is hashable. If you wish the
65+
resulting circuit to be eventually serialized into JSON, you should
66+
also restrict the tags to be JSON serializable.
5067
"""
5168
base = Circuit(contents, strategy=strategy)
5269
self._moments = tuple(base.moments)
70+
self._tags = tuple(tags)
5371

5472
@classmethod
5573
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
@@ -61,10 +79,35 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
6179
def moments(self) -> Sequence['cirq.Moment']:
6280
return self._moments
6381

82+
@property
83+
def tags(self) -> Tuple[Hashable, ...]:
84+
"""Returns a tuple of the Circuit's tags."""
85+
return self._tags
86+
87+
@_compat.cached_property
88+
def untagged(self) -> 'cirq.FrozenCircuit':
89+
"""Returns the underlying FrozenCircuit without any tags."""
90+
return self._from_moments(self._moments) if self.tags else self
91+
92+
def with_tags(self, *new_tags: Hashable) -> 'cirq.FrozenCircuit':
93+
"""Creates a new tagged `FrozenCircuit` with `self.tags` and `new_tags` combined."""
94+
if not new_tags:
95+
return self
96+
new_circuit = FrozenCircuit(tags=self.tags + new_tags)
97+
new_circuit._moments = self._moments
98+
return new_circuit
99+
64100
@_compat.cached_method
65101
def __hash__(self) -> int:
66102
# Explicitly cached for performance
67-
return hash((self.moments,))
103+
return hash((self.moments, self.tags))
104+
105+
def __eq__(self, other):
106+
super_eq = super().__eq__(other)
107+
if super_eq is not True:
108+
return super_eq
109+
other_tags = other.tags if isinstance(other, FrozenCircuit) else ()
110+
return self.tags == other_tags
68111

69112
def __getstate__(self):
70113
# Don't save hash when pickling; see #3777.
@@ -130,11 +173,23 @@ def all_measurement_key_names(self) -> FrozenSet[str]:
130173

131174
@_compat.cached_method
132175
def _is_parameterized_(self) -> bool:
133-
return super()._is_parameterized_()
176+
return super()._is_parameterized_() or any(
177+
protocols.is_parameterized(tag) for tag in self.tags
178+
)
134179

135180
@_compat.cached_method
136181
def _parameter_names_(self) -> AbstractSet[str]:
137-
return super()._parameter_names_()
182+
tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)}
183+
return super()._parameter_names_() | tag_params
184+
185+
def _resolve_parameters_(
186+
self, resolver: 'cirq.ParamResolver', recursive: bool
187+
) -> 'cirq.FrozenCircuit':
188+
resolved_circuit = super()._resolve_parameters_(resolver, recursive)
189+
resolved_tags = [
190+
protocols.resolve_parameters(tag, resolver, recursive) for tag in self.tags
191+
]
192+
return resolved_circuit.with_tags(*resolved_tags)
138193

139194
def _measurement_key_names_(self) -> FrozenSet[str]:
140195
return self.all_measurement_key_names()
@@ -161,6 +216,20 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
161216
except:
162217
return NotImplemented
163218

219+
def _repr_args(self) -> str:
220+
moments_repr = super()._repr_args()
221+
tag_repr = ','.join(_compat.proper_repr(t) for t in self._tags)
222+
return f'{moments_repr}, tags=[{tag_repr}]' if self.tags else moments_repr
223+
224+
def _json_dict_(self):
225+
attribute_names = ['moments', 'tags'] if self.tags else ['moments']
226+
ret = protocols.obj_to_dict_helper(self, attribute_names)
227+
return ret
228+
229+
@classmethod
230+
def _from_json_dict_(cls, moments, *, tags=(), **kwargs):
231+
return cls(moments, strategy=InsertStrategy.EARLIEST, tags=tags)
232+
164233
def concat_ragged(
165234
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
166235
) -> 'cirq.FrozenCircuit':

cirq-core/cirq/circuits/frozen_circuit_test.py

+32
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
import pytest
20+
import sympy
2021

2122
import cirq
2223

@@ -74,3 +75,34 @@ def test_immutable():
7475
match="(can't set attribute)|(property 'moments' of 'FrozenCircuit' object has no setter)",
7576
):
7677
c.moments = (cirq.Moment(cirq.H(q)), cirq.Moment(cirq.X(q)))
78+
79+
80+
def test_tagged_circuits():
81+
q = cirq.LineQubit(0)
82+
ops = [cirq.X(q), cirq.H(q)]
83+
tags = [sympy.Symbol("a"), "b"]
84+
circuit = cirq.Circuit(ops)
85+
frozen_circuit = cirq.FrozenCircuit(ops)
86+
tagged_circuit = cirq.FrozenCircuit(ops, tags=tags)
87+
# Test equality
88+
assert tagged_circuit.tags == tuple(tags)
89+
assert circuit == frozen_circuit != tagged_circuit
90+
assert cirq.approx_eq(circuit, frozen_circuit)
91+
assert cirq.approx_eq(frozen_circuit, tagged_circuit)
92+
# Test hash
93+
assert hash(frozen_circuit) != hash(tagged_circuit)
94+
# Test _repr_ and _json_ round trips.
95+
cirq.testing.assert_equivalent_repr(tagged_circuit)
96+
cirq.testing.assert_json_roundtrip_works(tagged_circuit)
97+
# Test utility methods and constructors
98+
assert frozen_circuit.with_tags() is frozen_circuit
99+
assert frozen_circuit.with_tags(*tags) == tagged_circuit
100+
assert tagged_circuit.with_tags("c") == cirq.FrozenCircuit(ops, tags=[*tags, "c"])
101+
assert tagged_circuit.untagged == frozen_circuit
102+
assert frozen_circuit.untagged is frozen_circuit
103+
# Test parameterized protocols
104+
assert cirq.is_parameterized(frozen_circuit) is False
105+
assert cirq.is_parameterized(tagged_circuit) is True
106+
assert cirq.parameter_names(tagged_circuit) == {"a"}
107+
# Tags are not propagated to diagrams yet.
108+
assert str(frozen_circuit) == str(tagged_circuit)

0 commit comments

Comments
 (0)