Skip to content

Commit f736079

Browse files
committed
Return same FrozenCircuit instance from Circuit.freeze() until mutated
This simplifies working with frozen circuits and circuit identity by storing the `FrozenCircuit` instance returned by `Circuit.freeze()` and returning the same instance until it is "invalidated" by any mutations of the Circuit itself. Note that if we make additional changes to the Circuit implementation, such as adding new mutating methods, we will need to remember to add `self._frozen = None` statements to invalidate the frozen representation in those places as well. To reduce risk, I have put these invalidation statements immediately after any mutation operations on `self._moments`, even in places where some invalidations could be elided or pushed to the end of a method; it seems safer to keep these close together in the source code. I have also added an implementation note calling out this detail and reminding future developers to invalidate when needed if adding other `Circuit` mutations.
1 parent 96b3842 commit f736079

File tree

3 files changed

+65
-13
lines changed

3 files changed

+65
-13
lines changed

cirq-core/cirq/circuits/circuit.py

+39-13
Original file line numberDiff line numberDiff line change
@@ -173,28 +173,20 @@ def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) ->
173173
def moments(self) -> Sequence['cirq.Moment']:
174174
pass
175175

176+
@abc.abstractmethod
176177
def freeze(self) -> 'cirq.FrozenCircuit':
177178
"""Creates a FrozenCircuit from this circuit.
178179
179180
If 'self' is a FrozenCircuit, the original object is returned.
180181
"""
181-
from cirq.circuits import FrozenCircuit
182-
183-
if isinstance(self, FrozenCircuit):
184-
return self
185-
186-
return FrozenCircuit(self, strategy=InsertStrategy.EARLIEST)
187182

183+
@abc.abstractmethod
188184
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
189185
"""Creates a Circuit from this circuit.
190186
191187
Args:
192188
copy: If True and 'self' is a Circuit, returns a copy that circuit.
193189
"""
194-
if isinstance(self, Circuit):
195-
return Circuit.copy(self) if copy else self
196-
197-
return Circuit(self, strategy=InsertStrategy.EARLIEST)
198190

199191
def __bool__(self):
200192
return bool(self.moments)
@@ -1743,7 +1735,10 @@ def __init__(
17431735
together. This option does not affect later insertions into the
17441736
circuit.
17451737
"""
1738+
# Implementation note: we set self._frozen = None any time self._moments
1739+
# is mutated, to "invalidate" the frozen instance.
17461740
self._moments: List['cirq.Moment'] = []
1741+
self._frozen: Optional['cirq.FrozenCircuit'] = None
17471742
flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
17481743
if all(isinstance(c, Moment) for c in flattened_contents):
17491744
self._moments[:] = cast(Iterable[Moment], flattened_contents)
@@ -1810,12 +1805,29 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
18101805
for i in range(length):
18111806
if i in moments_by_index:
18121807
self._moments.append(moments_by_index[i].with_operations(op_lists_by_index[i]))
1808+
self._frozen = None
18131809
else:
18141810
self._moments.append(Moment(op_lists_by_index[i]))
1811+
self._frozen = None
18151812

18161813
def __copy__(self) -> 'cirq.Circuit':
18171814
return self.copy()
18181815

1816+
def freeze(self) -> 'cirq.FrozenCircuit':
1817+
"""Gets a frozen version of this circuit.
1818+
1819+
Repeated calls to `.freeze()` will return the same FrozenCircuit
1820+
instance as long as this circuit is not mutated.
1821+
"""
1822+
from cirq.circuits import FrozenCircuit
1823+
1824+
if self._frozen is None:
1825+
self._frozen = FrozenCircuit.from_moments(*self._moments)
1826+
return self._frozen
1827+
1828+
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
1829+
return self.copy() if copy else self
1830+
18191831
def copy(self) -> 'Circuit':
18201832
"""Return a copy of this circuit."""
18211833
copied_circuit = Circuit()
@@ -1841,11 +1853,13 @@ def __setitem__(self, key, value):
18411853
raise TypeError('Can only assign Moments into Circuits.')
18421854

18431855
self._moments[key] = value
1856+
self._frozen = None
18441857

18451858
# pylint: enable=function-redefined
18461859

18471860
def __delitem__(self, key: Union[int, slice]):
18481861
del self._moments[key]
1862+
self._frozen = None
18491863

18501864
def __iadd__(self, other):
18511865
self.append(other)
@@ -1874,6 +1888,7 @@ def __imul__(self, repetitions: _INT_TYPE):
18741888
if not isinstance(repetitions, (int, np.integer)):
18751889
return NotImplemented
18761890
self._moments *= int(repetitions)
1891+
self._frozen = None
18771892
return self
18781893

18791894
def __mul__(self, repetitions: _INT_TYPE):
@@ -2017,6 +2032,7 @@ def _pick_or_create_inserted_op_moment_index(
20172032

20182033
if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE:
20192034
self._moments.insert(splitter_index, Moment())
2035+
self._frozen = None
20202036
return splitter_index
20212037

20222038
if strategy is InsertStrategy.INLINE:
@@ -2074,13 +2090,16 @@ def insert(
20742090
for moment_or_op in list(ops.flatten_to_ops_or_moments(moment_or_operation_tree)):
20752091
if isinstance(moment_or_op, Moment):
20762092
self._moments.insert(k, moment_or_op)
2093+
self._frozen = None
20772094
k += 1
20782095
else:
20792096
op = moment_or_op
20802097
p = self._pick_or_create_inserted_op_moment_index(k, op, strategy)
20812098
while p >= len(self._moments):
20822099
self._moments.append(Moment())
2100+
self._frozen = None
20832101
self._moments[p] = self._moments[p].with_operation(op)
2102+
self._frozen = None
20842103
k = max(k, p + 1)
20852104
if strategy is InsertStrategy.NEW_THEN_INLINE:
20862105
strategy = InsertStrategy.INLINE
@@ -2119,6 +2138,7 @@ def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) ->
21192138
break
21202139

21212140
self._moments[i] = self._moments[i].with_operation(op)
2141+
self._frozen = None
21222142
op_index += 1
21232143

21242144
if op_index >= len(flat_ops):
@@ -2165,6 +2185,7 @@ def _push_frontier(
21652185
if n_new_moments > 0:
21662186
insert_index = min(late_frontier.values())
21672187
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
2188+
self._frozen = None
21682189
for q in update_qubits:
21692190
if early_frontier.get(q, 0) > insert_index:
21702191
early_frontier[q] += n_new_moments
@@ -2191,13 +2212,13 @@ def _insert_operations(
21912212
if len(operations) != len(insertion_indices):
21922213
raise ValueError('operations and insertion_indices must have the same length.')
21932214
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
2215+
self._frozen = None
21942216
moment_to_ops: Dict[int, List['cirq.Operation']] = defaultdict(list)
21952217
for op_index, moment_index in enumerate(insertion_indices):
21962218
moment_to_ops[moment_index].append(operations[op_index])
21972219
for moment_index, new_ops in moment_to_ops.items():
2198-
self._moments[moment_index] = Moment(
2199-
self._moments[moment_index].operations + tuple(new_ops)
2200-
)
2220+
self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops)
2221+
self._frozen = None
22012222

22022223
def insert_at_frontier(
22032224
self,
@@ -2259,6 +2280,7 @@ def batch_remove(self, removals: Iterable[Tuple[int, 'cirq.Operation']]) -> None
22592280
old_op for old_op in copy._moments[i].operations if op != old_op
22602281
)
22612282
self._moments = copy._moments
2283+
self._frozen = None
22622284

22632285
def batch_replace(
22642286
self, replacements: Iterable[Tuple[int, 'cirq.Operation', 'cirq.Operation']]
@@ -2283,6 +2305,7 @@ def batch_replace(
22832305
old_op if old_op != op else new_op for old_op in copy._moments[i].operations
22842306
)
22852307
self._moments = copy._moments
2308+
self._frozen = None
22862309

22872310
def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
22882311
"""Inserts operations into empty spaces in existing moments.
@@ -2303,6 +2326,7 @@ def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']])
23032326
for i, insertions in insert_intos:
23042327
copy._moments[i] = copy._moments[i].with_operations(insertions)
23052328
self._moments = copy._moments
2329+
self._frozen = None
23062330

23072331
def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
23082332
"""Applies a batched insert operation to the circuit.
@@ -2337,6 +2361,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None
23372361
if next_index > insert_index:
23382362
shift += next_index - insert_index
23392363
self._moments = copy._moments
2364+
self._frozen = None
23402365

23412366
def append(
23422367
self,
@@ -2367,6 +2392,7 @@ def clear_operations_touching(
23672392
for k in moment_indices:
23682393
if 0 <= k < len(self._moments):
23692394
self._moments[k] = self._moments[k].without_operations_touching(qubits)
2395+
self._frozen = None
23702396

23712397
@property
23722398
def moments(self) -> Sequence['cirq.Moment']:

cirq-core/cirq/circuits/circuit_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -4521,6 +4521,26 @@ def test_freeze_not_relocate_moments():
45214521
assert [mc is fc for mc, fc in zip(c, f)] == [True, True]
45224522

45234523

4524+
def test_freeze_returns_same_instance_if_not_mutated():
4525+
q = cirq.q(0)
4526+
c = cirq.Circuit(cirq.X(q), cirq.measure(q))
4527+
f0 = c.freeze()
4528+
f1 = c.freeze()
4529+
assert f1 is f0
4530+
4531+
c.append(cirq.Y(q))
4532+
f2 = c.freeze()
4533+
f3 = c.freeze()
4534+
assert f2 is not f1
4535+
assert f3 is f2
4536+
4537+
c[-1] = cirq.Moment(cirq.Y(q))
4538+
f4 = c.freeze()
4539+
f5 = c.freeze()
4540+
assert f4 is not f3
4541+
assert f5 is f4
4542+
4543+
45244544
def test_factorize_one_factor():
45254545
circuit = cirq.Circuit()
45264546
q0, q1, q2 = cirq.LineQubit.range(3)

cirq-core/cirq/circuits/frozen_circuit.py

+6
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
7979
def moments(self) -> Sequence['cirq.Moment']:
8080
return self._moments
8181

82+
def freeze(self) -> 'cirq.FrozenCircuit':
83+
return self
84+
85+
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
86+
return Circuit.from_moments(*self)
87+
8288
@property
8389
def tags(self) -> Tuple[Hashable, ...]:
8490
"""Returns a tuple of the Circuit's tags."""

0 commit comments

Comments
 (0)