Skip to content

Commit 2f7d732

Browse files
authored
Cache Circuit properties between mutations (#6322)
This caches various computed properties on `Circuit` so that they do not need to be recomputed when accessed if the circuit has not been mutated. Any mutations cause these properties to be invalidated so that they will be recomputed the next time they are accessed.
1 parent 5485227 commit 2f7d732

File tree

3 files changed

+173
-15
lines changed

3 files changed

+173
-15
lines changed

cirq-core/cirq/circuits/circuit.py

+73-15
Original file line numberDiff line numberDiff line change
@@ -188,28 +188,20 @@ def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) ->
188188
def moments(self) -> Sequence['cirq.Moment']:
189189
pass
190190

191+
@abc.abstractmethod
191192
def freeze(self) -> 'cirq.FrozenCircuit':
192193
"""Creates a FrozenCircuit from this circuit.
193194
194195
If 'self' is a FrozenCircuit, the original object is returned.
195196
"""
196-
from cirq.circuits import FrozenCircuit
197-
198-
if isinstance(self, FrozenCircuit):
199-
return self
200-
201-
return FrozenCircuit(self, strategy=InsertStrategy.EARLIEST)
202197

198+
@abc.abstractmethod
203199
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
204200
"""Creates a Circuit from this circuit.
205201
206202
Args:
207203
copy: If True and 'self' is a Circuit, returns a copy that circuit.
208204
"""
209-
if isinstance(self, Circuit):
210-
return Circuit.copy(self) if copy else self
211-
212-
return Circuit(self, strategy=InsertStrategy.EARLIEST)
213205

214206
def __bool__(self):
215207
return bool(self.moments)
@@ -822,6 +814,9 @@ def has_measurements(self):
822814
"""
823815
return protocols.is_measurement(self)
824816

817+
def _is_measurement_(self) -> bool:
818+
return any(protocols.is_measurement(op) for op in self.all_operations())
819+
825820
def are_all_measurements_terminal(self) -> bool:
826821
"""Whether all measurement gates are at the end of the circuit.
827822
@@ -1383,8 +1378,7 @@ def save_qasm(
13831378
self._to_qasm_output(header, precision, qubit_order).save(file_path)
13841379

13851380
def _json_dict_(self):
1386-
ret = protocols.obj_to_dict_helper(self, ['moments'])
1387-
return ret
1381+
return protocols.obj_to_dict_helper(self, ['moments'])
13881382

13891383
@classmethod
13901384
def _from_json_dict_(cls, moments, **kwargs):
@@ -1759,6 +1753,16 @@ def __init__(
17591753
circuit.
17601754
"""
17611755
self._moments: List['cirq.Moment'] = []
1756+
1757+
# Implementation note: the following cached properties are set lazily and then
1758+
# invalidated and reset to None in `self._mutated()`, which is called any time
1759+
# `self._moments` is changed.
1760+
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
1761+
self._frozen: Optional['cirq.FrozenCircuit'] = None
1762+
self._is_measurement: Optional[bool] = None
1763+
self._is_parameterized: Optional[bool] = None
1764+
self._parameter_names: Optional[AbstractSet[str]] = None
1765+
17621766
flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
17631767
if all(isinstance(c, Moment) for c in flattened_contents):
17641768
self._moments[:] = cast(Iterable[Moment], flattened_contents)
@@ -1769,6 +1773,14 @@ def __init__(
17691773
else:
17701774
self.append(flattened_contents, strategy=strategy)
17711775

1776+
def _mutated(self) -> None:
1777+
"""Clear cached properties in response to this circuit being mutated."""
1778+
self._all_qubits = None
1779+
self._frozen = None
1780+
self._is_measurement = None
1781+
self._is_parameterized = None
1782+
self._parameter_names = None
1783+
17721784
@classmethod
17731785
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'Circuit':
17741786
new_circuit = Circuit()
@@ -1831,6 +1843,41 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
18311843
def __copy__(self) -> 'cirq.Circuit':
18321844
return self.copy()
18331845

1846+
def freeze(self) -> 'cirq.FrozenCircuit':
1847+
"""Gets a frozen version of this circuit.
1848+
1849+
Repeated calls to `.freeze()` will return the same FrozenCircuit
1850+
instance as long as this circuit is not mutated.
1851+
"""
1852+
from cirq.circuits.frozen_circuit import FrozenCircuit
1853+
1854+
if self._frozen is None:
1855+
self._frozen = FrozenCircuit.from_moments(*self._moments)
1856+
return self._frozen
1857+
1858+
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
1859+
return self.copy() if copy else self
1860+
1861+
def all_qubits(self) -> FrozenSet['cirq.Qid']:
1862+
if self._all_qubits is None:
1863+
self._all_qubits = super().all_qubits()
1864+
return self._all_qubits
1865+
1866+
def _is_measurement_(self) -> bool:
1867+
if self._is_measurement is None:
1868+
self._is_measurement = super()._is_measurement_()
1869+
return self._is_measurement
1870+
1871+
def _is_parameterized_(self) -> bool:
1872+
if self._is_parameterized is None:
1873+
self._is_parameterized = super()._is_parameterized_()
1874+
return self._is_parameterized
1875+
1876+
def _parameter_names_(self) -> AbstractSet[str]:
1877+
if self._parameter_names is None:
1878+
self._parameter_names = super()._parameter_names_()
1879+
return self._parameter_names
1880+
18341881
def copy(self) -> 'Circuit':
18351882
"""Return a copy of this circuit."""
18361883
copied_circuit = Circuit()
@@ -1856,11 +1903,13 @@ def __setitem__(self, key, value):
18561903
raise TypeError('Can only assign Moments into Circuits.')
18571904

18581905
self._moments[key] = value
1906+
self._mutated()
18591907

18601908
# pylint: enable=function-redefined
18611909

18621910
def __delitem__(self, key: Union[int, slice]):
18631911
del self._moments[key]
1912+
self._mutated()
18641913

18651914
def __iadd__(self, other):
18661915
self.append(other)
@@ -1889,6 +1938,7 @@ def __imul__(self, repetitions: _INT_TYPE):
18891938
if not isinstance(repetitions, (int, np.integer)):
18901939
return NotImplemented
18911940
self._moments *= int(repetitions)
1941+
self._mutated()
18921942
return self
18931943

18941944
def __mul__(self, repetitions: _INT_TYPE):
@@ -2032,6 +2082,7 @@ def _pick_or_create_inserted_op_moment_index(
20322082

20332083
if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE:
20342084
self._moments.insert(splitter_index, Moment())
2085+
self._mutated()
20352086
return splitter_index
20362087

20372088
if strategy is InsertStrategy.INLINE:
@@ -2099,6 +2150,7 @@ def insert(
20992150
k = max(k, p + 1)
21002151
if strategy is InsertStrategy.NEW_THEN_INLINE:
21012152
strategy = InsertStrategy.INLINE
2153+
self._mutated()
21022154
return k
21032155

21042156
def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) -> int:
@@ -2135,6 +2187,7 @@ def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) ->
21352187

21362188
self._moments[i] = self._moments[i].with_operation(op)
21372189
op_index += 1
2190+
self._mutated()
21382191

21392192
if op_index >= len(flat_ops):
21402193
return end
@@ -2180,6 +2233,7 @@ def _push_frontier(
21802233
if n_new_moments > 0:
21812234
insert_index = min(late_frontier.values())
21822235
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
2236+
self._mutated()
21832237
for q in update_qubits:
21842238
if early_frontier.get(q, 0) > insert_index:
21852239
early_frontier[q] += n_new_moments
@@ -2206,13 +2260,12 @@ def _insert_operations(
22062260
if len(operations) != len(insertion_indices):
22072261
raise ValueError('operations and insertion_indices must have the same length.')
22082262
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
2263+
self._mutated()
22092264
moment_to_ops: Dict[int, List['cirq.Operation']] = defaultdict(list)
22102265
for op_index, moment_index in enumerate(insertion_indices):
22112266
moment_to_ops[moment_index].append(operations[op_index])
22122267
for moment_index, new_ops in moment_to_ops.items():
2213-
self._moments[moment_index] = Moment(
2214-
self._moments[moment_index].operations + tuple(new_ops)
2215-
)
2268+
self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops)
22162269

22172270
def insert_at_frontier(
22182271
self,
@@ -2274,6 +2327,7 @@ def batch_remove(self, removals: Iterable[Tuple[int, 'cirq.Operation']]) -> None
22742327
old_op for old_op in copy._moments[i].operations if op != old_op
22752328
)
22762329
self._moments = copy._moments
2330+
self._mutated()
22772331

22782332
def batch_replace(
22792333
self, replacements: Iterable[Tuple[int, 'cirq.Operation', 'cirq.Operation']]
@@ -2298,6 +2352,7 @@ def batch_replace(
22982352
old_op if old_op != op else new_op for old_op in copy._moments[i].operations
22992353
)
23002354
self._moments = copy._moments
2355+
self._mutated()
23012356

23022357
def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
23032358
"""Inserts operations into empty spaces in existing moments.
@@ -2318,6 +2373,7 @@ def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']])
23182373
for i, insertions in insert_intos:
23192374
copy._moments[i] = copy._moments[i].with_operations(insertions)
23202375
self._moments = copy._moments
2376+
self._mutated()
23212377

23222378
def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
23232379
"""Applies a batched insert operation to the circuit.
@@ -2352,6 +2408,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None
23522408
if next_index > insert_index:
23532409
shift += next_index - insert_index
23542410
self._moments = copy._moments
2411+
self._mutated()
23552412

23562413
def append(
23572414
self,
@@ -2382,6 +2439,7 @@ def clear_operations_touching(
23822439
for k in moment_indices:
23832440
if 0 <= k < len(self._moments):
23842441
self._moments[k] = self._moments[k].without_operations_touching(qubits)
2442+
self._mutated()
23852443

23862444
@property
23872445
def moments(self) -> Sequence['cirq.Moment']:

cirq-core/cirq/circuits/circuit_test.py

+94
Original file line numberDiff line numberDiff line change
@@ -4533,6 +4533,100 @@ def test_freeze_not_relocate_moments():
45334533
assert [mc is fc for mc, fc in zip(c, f)] == [True, True]
45344534

45354535

4536+
def test_freeze_is_cached():
4537+
q = cirq.q(0)
4538+
c = cirq.Circuit(cirq.X(q), cirq.measure(q))
4539+
f0 = c.freeze()
4540+
f1 = c.freeze()
4541+
assert f1 is f0
4542+
4543+
c.append(cirq.Y(q))
4544+
f2 = c.freeze()
4545+
f3 = c.freeze()
4546+
assert f2 is not f1
4547+
assert f3 is f2
4548+
4549+
c[-1] = cirq.Moment(cirq.Y(q))
4550+
f4 = c.freeze()
4551+
f5 = c.freeze()
4552+
assert f4 is not f3
4553+
assert f5 is f4
4554+
4555+
4556+
@pytest.mark.parametrize(
4557+
"circuit, mutate",
4558+
[
4559+
(
4560+
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
4561+
lambda c: c.__setitem__(0, cirq.Moment(cirq.Y(cirq.q(0)))),
4562+
),
4563+
(cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__delitem__(0)),
4564+
(cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2)),
4565+
(
4566+
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
4567+
lambda c: c.insert(1, cirq.Y(cirq.q(0))),
4568+
),
4569+
(
4570+
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
4571+
lambda c: c.insert_into_range([cirq.Y(cirq.q(1)), cirq.M(cirq.q(1))], 0, 2),
4572+
),
4573+
(
4574+
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
4575+
lambda c: c.insert_at_frontier([cirq.Y(cirq.q(0)), cirq.Y(cirq.q(1))], 1),
4576+
),
4577+
(
4578+
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
4579+
lambda c: c.batch_replace([(0, cirq.X(cirq.q(0)), cirq.Y(cirq.q(0)))]),
4580+
),
4581+
(
4582+
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0), cirq.q(1))),
4583+
lambda c: c.batch_insert_into([(0, cirq.X(cirq.q(1)))]),
4584+
),
4585+
(
4586+
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
4587+
lambda c: c.batch_insert([(1, cirq.Y(cirq.q(0)))]),
4588+
),
4589+
(
4590+
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
4591+
lambda c: c.clear_operations_touching([cirq.q(0)], [0]),
4592+
),
4593+
],
4594+
)
4595+
def test_mutation_clears_cached_attributes(circuit, mutate):
4596+
cached_attributes = [
4597+
"_all_qubits",
4598+
"_frozen",
4599+
"_is_measurement",
4600+
"_is_parameterized",
4601+
"_parameter_names",
4602+
]
4603+
4604+
for attr in cached_attributes:
4605+
assert getattr(circuit, attr) is None, f"{attr=} is not None"
4606+
4607+
# Check that attributes are cached after getting them.
4608+
qubits = circuit.all_qubits()
4609+
frozen = circuit.freeze()
4610+
is_measurement = cirq.is_measurement(circuit)
4611+
is_parameterized = cirq.is_parameterized(circuit)
4612+
parameter_names = cirq.parameter_names(circuit)
4613+
4614+
for attr in cached_attributes:
4615+
assert getattr(circuit, attr) is not None, f"{attr=} is None"
4616+
4617+
# Check that getting again returns same object.
4618+
assert circuit.all_qubits() is qubits
4619+
assert circuit.freeze() is frozen
4620+
assert cirq.is_measurement(circuit) is is_measurement
4621+
assert cirq.is_parameterized(circuit) is is_parameterized
4622+
assert cirq.parameter_names(circuit) is parameter_names
4623+
4624+
# Check that attributes are cleared after mutation.
4625+
mutate(circuit)
4626+
for attr in cached_attributes:
4627+
assert getattr(circuit, attr) is None, f"{attr=} is not None"
4628+
4629+
45364630
def test_factorize_one_factor():
45374631
circuit = cirq.Circuit()
45384632
q0, q1, q2 = cirq.LineQubit.range(3)

cirq-core/cirq/circuits/frozen_circuit.py

+6
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
7979
def moments(self) -> Sequence['cirq.Moment']:
8080
return self._moments
8181

82+
def freeze(self) -> 'cirq.FrozenCircuit':
83+
return self
84+
85+
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
86+
return Circuit.from_moments(*self)
87+
8288
@property
8389
def tags(self) -> Tuple[Hashable, ...]:
8490
"""Returns a tuple of the Circuit's tags."""

0 commit comments

Comments
 (0)