From 9ad260c4d3bbcb8b227d006857c288577b1ca416 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Thu, 23 Jun 2022 11:09:37 -0700 Subject: [PATCH 1/2] Use `@cached_method` for FrozenCircuit properties --- cirq-core/cirq/_compat.py | 14 +++++ cirq-core/cirq/circuits/frozen_circuit.py | 70 ++++++++--------------- 2 files changed, 38 insertions(+), 46 deletions(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 7cc28cdf559..76e9e9656d0 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -17,6 +17,7 @@ import dataclasses import functools import importlib +import inspect import os import re import sys @@ -69,6 +70,19 @@ def bar(self, name: str) -> int: def decorator(func): cache_name = f'_{func.__name__}_cache' + signature = inspect.signature(func) + + if len(signature.parameters) == 1: + # Optimization in the case where the method takes no arguments other than `self`. + + @functools.wraps(func) + def wrapped_no_args(self): + if not hasattr(self, cache_name): + object.__setattr__(self, cache_name, func(self)) + return getattr(self, cache_name) + + return wrapped_no_args + @functools.wraps(func) def wrapped(self, *args, **kwargs): cached = getattr(self, cache_name, None) diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index bc735439619..1aae526c3b1 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """An immutable version of the Circuit data structure.""" -from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Sequence, Tuple, Union import numpy as np -from cirq import ops, protocols +from cirq import protocols, _compat from cirq.circuits import AbstractCircuit, Alignment, Circuit from cirq.circuits.insert_strategy import InsertStrategy from cirq.type_workarounds import NotImplementedType @@ -51,17 +51,6 @@ def __init__( base = Circuit(contents, strategy=strategy) self._moments = tuple(base.moments) - # These variables are memoized when first requested. - self._num_qubits: Optional[int] = None - self._unitary: Optional[Union[np.ndarray, NotImplementedType]] = None - self._qid_shape: Optional[Tuple[int, ...]] = None - self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None - self._all_operations: Optional[Tuple[ops.Operation, ...]] = None - self._has_measurements: Optional[bool] = None - self._all_measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None - self._are_all_measurements_terminal: Optional[bool] = None - self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None - @property def moments(self) -> Sequence['cirq.Moment']: return self._moments @@ -69,63 +58,52 @@ def moments(self) -> Sequence['cirq.Moment']: def __hash__(self): return hash((self.moments,)) - # Memoized methods for commonly-retrieved properties. - + @_compat.cached_method def _num_qubits_(self) -> int: - if self._num_qubits is None: - self._num_qubits = len(self.all_qubits()) - return self._num_qubits + return len(self.all_qubits()) + @_compat.cached_method def _qid_shape_(self) -> Tuple[int, ...]: - if self._qid_shape is None: - self._qid_shape = super()._qid_shape_() - return self._qid_shape + return super()._qid_shape_() + @_compat.cached_method def _unitary_(self) -> Union[np.ndarray, NotImplementedType]: - if self._unitary is None: - self._unitary = super()._unitary_() - return self._unitary + return super()._unitary_() + @_compat.cached_method def _is_measurement_(self) -> bool: - if self._has_measurements is None: - self._has_measurements = protocols.is_measurement(self.unfreeze()) - return self._has_measurements + return protocols.is_measurement(self.unfreeze()) + @_compat.cached_method def all_qubits(self) -> FrozenSet['cirq.Qid']: - if self._all_qubits is None: - self._all_qubits = super().all_qubits() - return self._all_qubits + return super().all_qubits() + + @_compat.cached_property + def _all_operations(self) -> Tuple['cirq.Operation', ...]: + return tuple(super().all_operations()) def all_operations(self) -> Iterator['cirq.Operation']: - if self._all_operations is None: - self._all_operations = tuple(super().all_operations()) return iter(self._all_operations) def has_measurements(self) -> bool: - if self._has_measurements is None: - self._has_measurements = super().has_measurements() - return self._has_measurements + return self._is_measurement_() + @_compat.cached_method def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: - if self._all_measurement_key_objs is None: - self._all_measurement_key_objs = super().all_measurement_key_objs() - return self._all_measurement_key_objs + return super().all_measurement_key_objs() def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: return self.all_measurement_key_objs() + @_compat.cached_method def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: - if self._control_keys is None: - self._control_keys = super()._control_keys_() - return self._control_keys + return super()._control_keys_() + @_compat.cached_method def are_all_measurements_terminal(self) -> bool: - if self._are_all_measurements_terminal is None: - self._are_all_measurements_terminal = super().are_all_measurements_terminal() - return self._are_all_measurements_terminal - - # End of memoized methods. + return super().are_all_measurements_terminal() + @_compat.cached_method def all_measurement_key_names(self) -> FrozenSet[str]: return frozenset(str(key) for key in self.all_measurement_key_objs()) From 2cd2de93662c0f2ac9ada3752c6a550de1cdce6d Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Mon, 11 Jul 2022 09:11:55 -0700 Subject: [PATCH 2/2] Fix test of _is_measurement_ memoization --- cirq-core/cirq/_compat.py | 8 ++++++-- cirq-core/cirq/circuits/circuit_operation_test.py | 8 +++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 76e9e9656d0..cf123d0408b 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -68,8 +68,7 @@ def bar(self, name: str) -> int: """ def decorator(func): - cache_name = f'_{func.__name__}_cache' - + cache_name = _method_cache_name(func) signature = inspect.signature(func) if len(signature.parameters) == 1: @@ -101,6 +100,11 @@ def cached_func(*args, **kwargs): return decorator if method is None else decorator(method) +def _method_cache_name(func: Callable) -> str: + # Use single-underscore prefix to avoid name mangling (for tests). + return f'_method_cache_{func.__name__}' + + def proper_repr(value: Any) -> str: """Overrides sympy and numpy returning repr strings that don't parse.""" diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index c8cc0be2e70..0c50be992a9 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -20,6 +20,7 @@ import cirq import cirq.circuits.circuit_operation as circuit_operation +from cirq import _compat from cirq.circuits.circuit_operation import _full_join_string_lists ALL_SIMULATORS = (cirq.Simulator(), cirq.DensityMatrixSimulator(), cirq.CliffordSimulator()) @@ -90,10 +91,11 @@ def test_is_measurement_memoization(): a = cirq.LineQubit(0) circuit = cirq.FrozenCircuit(cirq.measure(a, key='m')) c_op = cirq.CircuitOperation(circuit) - assert circuit._has_measurements is None - # Memoize `_has_measurements` in the circuit. + cache_name = _compat._method_cache_name(circuit._is_measurement_) + assert not hasattr(circuit, cache_name) + # Memoize `_is_measurement_` in the circuit. assert cirq.is_measurement(c_op) - assert circuit._has_measurements is True + assert hasattr(circuit, cache_name) def test_invalid_measurement_keys():