Skip to content

Commit f88b750

Browse files
committed
Cache more circuit properties
1 parent d511136 commit f88b750

File tree

1 file changed

+53
-22
lines changed

1 file changed

+53
-22
lines changed

cirq-core/cirq/circuits/circuit.py

+53-22
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,9 @@ def has_measurements(self):
814814
"""
815815
return protocols.is_measurement(self)
816816

817+
def _is_measurement_(self) -> bool:
818+
return any(protocols.is_measurement(op) for op in self.all_operations())
819+
817820
def are_all_measurements_terminal(self) -> bool:
818821
"""Whether all measurement gates are at the end of the circuit.
819822
@@ -1375,8 +1378,7 @@ def save_qasm(
13751378
self._to_qasm_output(header, precision, qubit_order).save(file_path)
13761379

13771380
def _json_dict_(self):
1378-
ret = protocols.obj_to_dict_helper(self, ['moments'])
1379-
return ret
1381+
return protocols.obj_to_dict_helper(self, ['moments'])
13801382

13811383
@classmethod
13821384
def _from_json_dict_(cls, moments, **kwargs):
@@ -1750,10 +1752,17 @@ def __init__(
17501752
together. This option does not affect later insertions into the
17511753
circuit.
17521754
"""
1753-
# Implementation note: we set self._frozen = None any time self._moments
1754-
# is mutated, to "invalidate" the frozen instance.
17551755
self._moments: List['cirq.Moment'] = []
1756+
1757+
# Implementation note: the following cached properties are set lazily and then
1758+
# invalidated and reset to None in `self._mutated()`, which is called any time
1759+
# `self._moments` is changed.
1760+
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
17561761
self._frozen: Optional['cirq.FrozenCircuit'] = None
1762+
self._is_measurement: Optional[bool] = None
1763+
self._is_parameterized: Optional[bool] = None
1764+
self._parameter_names: Optional[AbstractSet[str]] = None
1765+
17571766
flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
17581767
if all(isinstance(c, Moment) for c in flattened_contents):
17591768
self._moments[:] = cast(Iterable[Moment], flattened_contents)
@@ -1764,6 +1773,12 @@ def __init__(
17641773
else:
17651774
self.append(flattened_contents, strategy=strategy)
17661775

1776+
def _mutated(self) -> None:
1777+
"""Clear cached properties in response to this circuit being mutated."""
1778+
self._all_qubits = None
1779+
self._frozen = None
1780+
self._is_parameterized = None
1781+
17671782
@classmethod
17681783
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'Circuit':
17691784
new_circuit = Circuit()
@@ -1820,10 +1835,9 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
18201835
for i in range(length):
18211836
if i in moments_by_index:
18221837
self._moments.append(moments_by_index[i].with_operations(op_lists_by_index[i]))
1823-
self._frozen = None
18241838
else:
18251839
self._moments.append(Moment(op_lists_by_index[i]))
1826-
self._frozen = None
1840+
self._mutated()
18271841

18281842
def __copy__(self) -> 'cirq.Circuit':
18291843
return self.copy()
@@ -1843,6 +1857,26 @@ def freeze(self) -> 'cirq.FrozenCircuit':
18431857
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
18441858
return self.copy() if copy else self
18451859

1860+
def all_qubits(self) -> FrozenSet['cirq.Qid']:
1861+
if self._all_qubits is None:
1862+
self._all_qubits = super().all_qubits()
1863+
return self._all_qubits
1864+
1865+
def _is_measurement_(self) -> bool:
1866+
if self._is_measurement is None:
1867+
self._is_measurement = super()._is_measurement_()
1868+
return self._is_measurement
1869+
1870+
def _is_parameterized_(self) -> bool:
1871+
if self._is_parameterized is None:
1872+
self._is_parameterized = super()._is_parameterized_()
1873+
return self._is_parameterized
1874+
1875+
def _parameter_names_(self) -> AbstractSet[str]:
1876+
if self._parameter_names is None:
1877+
self._parameter_names = super()._parameter_names_()
1878+
return self._parameter_names
1879+
18461880
def copy(self) -> 'Circuit':
18471881
"""Return a copy of this circuit."""
18481882
copied_circuit = Circuit()
@@ -1868,13 +1902,13 @@ def __setitem__(self, key, value):
18681902
raise TypeError('Can only assign Moments into Circuits.')
18691903

18701904
self._moments[key] = value
1871-
self._frozen = None
1905+
self._mutated()
18721906

18731907
# pylint: enable=function-redefined
18741908

18751909
def __delitem__(self, key: Union[int, slice]):
18761910
del self._moments[key]
1877-
self._frozen = None
1911+
self._mutated()
18781912

18791913
def __iadd__(self, other):
18801914
self.append(other)
@@ -1903,7 +1937,7 @@ def __imul__(self, repetitions: _INT_TYPE):
19031937
if not isinstance(repetitions, (int, np.integer)):
19041938
return NotImplemented
19051939
self._moments *= int(repetitions)
1906-
self._frozen = None
1940+
self._mutated()
19071941
return self
19081942

19091943
def __mul__(self, repetitions: _INT_TYPE):
@@ -2047,7 +2081,7 @@ def _pick_or_create_inserted_op_moment_index(
20472081

20482082
if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE:
20492083
self._moments.insert(splitter_index, Moment())
2050-
self._frozen = None
2084+
self._mutated()
20512085
return splitter_index
20522086

20532087
if strategy is InsertStrategy.INLINE:
@@ -2105,19 +2139,17 @@ def insert(
21052139
for moment_or_op in list(ops.flatten_to_ops_or_moments(moment_or_operation_tree)):
21062140
if isinstance(moment_or_op, Moment):
21072141
self._moments.insert(k, moment_or_op)
2108-
self._frozen = None
21092142
k += 1
21102143
else:
21112144
op = moment_or_op
21122145
p = self._pick_or_create_inserted_op_moment_index(k, op, strategy)
21132146
while p >= len(self._moments):
21142147
self._moments.append(Moment())
2115-
self._frozen = None
21162148
self._moments[p] = self._moments[p].with_operation(op)
2117-
self._frozen = None
21182149
k = max(k, p + 1)
21192150
if strategy is InsertStrategy.NEW_THEN_INLINE:
21202151
strategy = InsertStrategy.INLINE
2152+
self._mutated()
21212153
return k
21222154

21232155
def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) -> int:
@@ -2153,8 +2185,8 @@ def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) ->
21532185
break
21542186

21552187
self._moments[i] = self._moments[i].with_operation(op)
2156-
self._frozen = None
21572188
op_index += 1
2189+
self._mutated()
21582190

21592191
if op_index >= len(flat_ops):
21602192
return end
@@ -2200,7 +2232,7 @@ def _push_frontier(
22002232
if n_new_moments > 0:
22012233
insert_index = min(late_frontier.values())
22022234
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
2203-
self._frozen = None
2235+
self._mutated()
22042236
for q in update_qubits:
22052237
if early_frontier.get(q, 0) > insert_index:
22062238
early_frontier[q] += n_new_moments
@@ -2227,13 +2259,12 @@ def _insert_operations(
22272259
if len(operations) != len(insertion_indices):
22282260
raise ValueError('operations and insertion_indices must have the same length.')
22292261
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
2230-
self._frozen = None
2262+
self._mutated()
22312263
moment_to_ops: Dict[int, List['cirq.Operation']] = defaultdict(list)
22322264
for op_index, moment_index in enumerate(insertion_indices):
22332265
moment_to_ops[moment_index].append(operations[op_index])
22342266
for moment_index, new_ops in moment_to_ops.items():
22352267
self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops)
2236-
self._frozen = None
22372268

22382269
def insert_at_frontier(
22392270
self,
@@ -2295,7 +2326,7 @@ def batch_remove(self, removals: Iterable[Tuple[int, 'cirq.Operation']]) -> None
22952326
old_op for old_op in copy._moments[i].operations if op != old_op
22962327
)
22972328
self._moments = copy._moments
2298-
self._frozen = None
2329+
self._mutated()
22992330

23002331
def batch_replace(
23012332
self, replacements: Iterable[Tuple[int, 'cirq.Operation', 'cirq.Operation']]
@@ -2320,7 +2351,7 @@ def batch_replace(
23202351
old_op if old_op != op else new_op for old_op in copy._moments[i].operations
23212352
)
23222353
self._moments = copy._moments
2323-
self._frozen = None
2354+
self._mutated()
23242355

23252356
def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
23262357
"""Inserts operations into empty spaces in existing moments.
@@ -2341,7 +2372,7 @@ def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']])
23412372
for i, insertions in insert_intos:
23422373
copy._moments[i] = copy._moments[i].with_operations(insertions)
23432374
self._moments = copy._moments
2344-
self._frozen = None
2375+
self._mutated()
23452376

23462377
def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
23472378
"""Applies a batched insert operation to the circuit.
@@ -2376,7 +2407,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None
23762407
if next_index > insert_index:
23772408
shift += next_index - insert_index
23782409
self._moments = copy._moments
2379-
self._frozen = None
2410+
self._mutated()
23802411

23812412
def append(
23822413
self,
@@ -2407,7 +2438,7 @@ def clear_operations_touching(
24072438
for k in moment_indices:
24082439
if 0 <= k < len(self._moments):
24092440
self._moments[k] = self._moments[k].without_operations_touching(qubits)
2410-
self._frozen = None
2441+
self._mutated()
24112442

24122443
@property
24132444
def moments(self) -> Sequence['cirq.Moment']:

0 commit comments

Comments
 (0)