Skip to content

Commit 291502b

Browse files
authored
Use @cached_method for FrozenCircuit properties (#5707)
This also includes an optimization in the `cached_method` decorator for the case of methods that take no arguments other than self (as is often the case with cirq protocol methods, for example).
1 parent 479398f commit 291502b

File tree

3 files changed

+48
-50
lines changed

3 files changed

+48
-50
lines changed

cirq-core/cirq/_compat.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import dataclasses
1818
import functools
1919
import importlib
20+
import inspect
2021
import os
2122
import re
2223
import sys
@@ -67,7 +68,19 @@ def bar(self, name: str) -> int:
6768
"""
6869

6970
def decorator(func):
70-
cache_name = f'_{func.__name__}_cache'
71+
cache_name = _method_cache_name(func)
72+
signature = inspect.signature(func)
73+
74+
if len(signature.parameters) == 1:
75+
# Optimization in the case where the method takes no arguments other than `self`.
76+
77+
@functools.wraps(func)
78+
def wrapped_no_args(self):
79+
if not hasattr(self, cache_name):
80+
object.__setattr__(self, cache_name, func(self))
81+
return getattr(self, cache_name)
82+
83+
return wrapped_no_args
7184

7285
@functools.wraps(func)
7386
def wrapped(self, *args, **kwargs):
@@ -87,6 +100,11 @@ def cached_func(*args, **kwargs):
87100
return decorator if method is None else decorator(method)
88101

89102

103+
def _method_cache_name(func: Callable) -> str:
104+
# Use single-underscore prefix to avoid name mangling (for tests).
105+
return f'_method_cache_{func.__name__}'
106+
107+
90108
def proper_repr(value: Any) -> str:
91109
"""Overrides sympy and numpy returning repr strings that don't parse."""
92110

cirq-core/cirq/circuits/circuit_operation_test.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import cirq
2222
import cirq.circuits.circuit_operation as circuit_operation
23+
from cirq import _compat
2324
from cirq.circuits.circuit_operation import _full_join_string_lists
2425

2526
ALL_SIMULATORS = (cirq.Simulator(), cirq.DensityMatrixSimulator(), cirq.CliffordSimulator())
@@ -90,10 +91,11 @@ def test_is_measurement_memoization():
9091
a = cirq.LineQubit(0)
9192
circuit = cirq.FrozenCircuit(cirq.measure(a, key='m'))
9293
c_op = cirq.CircuitOperation(circuit)
93-
assert circuit._has_measurements is None
94-
# Memoize `_has_measurements` in the circuit.
94+
cache_name = _compat._method_cache_name(circuit._is_measurement_)
95+
assert not hasattr(circuit, cache_name)
96+
# Memoize `_is_measurement_` in the circuit.
9597
assert cirq.is_measurement(c_op)
96-
assert circuit._has_measurements is True
98+
assert hasattr(circuit, cache_name)
9799

98100

99101
def test_invalid_measurement_keys():

cirq-core/cirq/circuits/frozen_circuit.py

+24-46
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""An immutable version of the Circuit data structure."""
15-
from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Optional, Sequence, Tuple, Union
15+
from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Sequence, Tuple, Union
1616

1717
import numpy as np
1818

19-
from cirq import ops, protocols
19+
from cirq import protocols, _compat
2020
from cirq.circuits import AbstractCircuit, Alignment, Circuit
2121
from cirq.circuits.insert_strategy import InsertStrategy
2222
from cirq.type_workarounds import NotImplementedType
@@ -51,81 +51,59 @@ def __init__(
5151
base = Circuit(contents, strategy=strategy)
5252
self._moments = tuple(base.moments)
5353

54-
# These variables are memoized when first requested.
55-
self._num_qubits: Optional[int] = None
56-
self._unitary: Optional[Union[np.ndarray, NotImplementedType]] = None
57-
self._qid_shape: Optional[Tuple[int, ...]] = None
58-
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
59-
self._all_operations: Optional[Tuple[ops.Operation, ...]] = None
60-
self._has_measurements: Optional[bool] = None
61-
self._all_measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None
62-
self._are_all_measurements_terminal: Optional[bool] = None
63-
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None
64-
6554
@property
6655
def moments(self) -> Sequence['cirq.Moment']:
6756
return self._moments
6857

6958
def __hash__(self):
7059
return hash((self.moments,))
7160

72-
# Memoized methods for commonly-retrieved properties.
73-
61+
@_compat.cached_method
7462
def _num_qubits_(self) -> int:
75-
if self._num_qubits is None:
76-
self._num_qubits = len(self.all_qubits())
77-
return self._num_qubits
63+
return len(self.all_qubits())
7864

65+
@_compat.cached_method
7966
def _qid_shape_(self) -> Tuple[int, ...]:
80-
if self._qid_shape is None:
81-
self._qid_shape = super()._qid_shape_()
82-
return self._qid_shape
67+
return super()._qid_shape_()
8368

69+
@_compat.cached_method
8470
def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
85-
if self._unitary is None:
86-
self._unitary = super()._unitary_()
87-
return self._unitary
71+
return super()._unitary_()
8872

73+
@_compat.cached_method
8974
def _is_measurement_(self) -> bool:
90-
if self._has_measurements is None:
91-
self._has_measurements = protocols.is_measurement(self.unfreeze())
92-
return self._has_measurements
75+
return protocols.is_measurement(self.unfreeze())
9376

77+
@_compat.cached_method
9478
def all_qubits(self) -> FrozenSet['cirq.Qid']:
95-
if self._all_qubits is None:
96-
self._all_qubits = super().all_qubits()
97-
return self._all_qubits
79+
return super().all_qubits()
80+
81+
@_compat.cached_property
82+
def _all_operations(self) -> Tuple['cirq.Operation', ...]:
83+
return tuple(super().all_operations())
9884

9985
def all_operations(self) -> Iterator['cirq.Operation']:
100-
if self._all_operations is None:
101-
self._all_operations = tuple(super().all_operations())
10286
return iter(self._all_operations)
10387

10488
def has_measurements(self) -> bool:
105-
if self._has_measurements is None:
106-
self._has_measurements = super().has_measurements()
107-
return self._has_measurements
89+
return self._is_measurement_()
10890

91+
@_compat.cached_method
10992
def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
110-
if self._all_measurement_key_objs is None:
111-
self._all_measurement_key_objs = super().all_measurement_key_objs()
112-
return self._all_measurement_key_objs
93+
return super().all_measurement_key_objs()
11394

11495
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
11596
return self.all_measurement_key_objs()
11697

98+
@_compat.cached_method
11799
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
118-
if self._control_keys is None:
119-
self._control_keys = super()._control_keys_()
120-
return self._control_keys
100+
return super()._control_keys_()
121101

102+
@_compat.cached_method
122103
def are_all_measurements_terminal(self) -> bool:
123-
if self._are_all_measurements_terminal is None:
124-
self._are_all_measurements_terminal = super().are_all_measurements_terminal()
125-
return self._are_all_measurements_terminal
126-
127-
# End of memoized methods.
104+
return super().are_all_measurements_terminal()
128105

106+
@_compat.cached_method
129107
def all_measurement_key_names(self) -> FrozenSet[str]:
130108
return frozenset(str(key) for key in self.all_measurement_key_objs())
131109

0 commit comments

Comments
 (0)