Skip to content

Use @cached_method for FrozenCircuit properties #5707

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion cirq-core/cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dataclasses
import functools
import importlib
import inspect
import os
import re
import sys
Expand Down Expand Up @@ -67,7 +68,19 @@ def bar(self, name: str) -> int:
"""

def decorator(func):
cache_name = f'_{func.__name__}_cache'
cache_name = _method_cache_name(func)
signature = inspect.signature(func)

if len(signature.parameters) == 1:
# Optimization in the case where the method takes no arguments other than `self`.

@functools.wraps(func)
def wrapped_no_args(self):
if not hasattr(self, cache_name):
object.__setattr__(self, cache_name, func(self))
return getattr(self, cache_name)

return wrapped_no_args

@functools.wraps(func)
def wrapped(self, *args, **kwargs):
Expand All @@ -87,6 +100,11 @@ def cached_func(*args, **kwargs):
return decorator if method is None else decorator(method)


def _method_cache_name(func: Callable) -> str:
# Use single-underscore prefix to avoid name mangling (for tests).
return f'_method_cache_{func.__name__}'


def proper_repr(value: Any) -> str:
"""Overrides sympy and numpy returning repr strings that don't parse."""

Expand Down
8 changes: 5 additions & 3 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import cirq
import cirq.circuits.circuit_operation as circuit_operation
from cirq import _compat
from cirq.circuits.circuit_operation import _full_join_string_lists

ALL_SIMULATORS = (cirq.Simulator(), cirq.DensityMatrixSimulator(), cirq.CliffordSimulator())
Expand Down Expand Up @@ -90,10 +91,11 @@ def test_is_measurement_memoization():
a = cirq.LineQubit(0)
circuit = cirq.FrozenCircuit(cirq.measure(a, key='m'))
c_op = cirq.CircuitOperation(circuit)
assert circuit._has_measurements is None
# Memoize `_has_measurements` in the circuit.
cache_name = _compat._method_cache_name(circuit._is_measurement_)
assert not hasattr(circuit, cache_name)
# Memoize `_is_measurement_` in the circuit.
assert cirq.is_measurement(c_op)
assert circuit._has_measurements is True
assert hasattr(circuit, cache_name)


def test_invalid_measurement_keys():
Expand Down
70 changes: 24 additions & 46 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""An immutable version of the Circuit data structure."""
from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Sequence, Tuple, Union

import numpy as np

from cirq import ops, protocols
from cirq import protocols, _compat
from cirq.circuits import AbstractCircuit, Alignment, Circuit
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.type_workarounds import NotImplementedType
Expand Down Expand Up @@ -51,81 +51,59 @@ def __init__(
base = Circuit(contents, strategy=strategy)
self._moments = tuple(base.moments)

# These variables are memoized when first requested.
self._num_qubits: Optional[int] = None
self._unitary: Optional[Union[np.ndarray, NotImplementedType]] = None
self._qid_shape: Optional[Tuple[int, ...]] = None
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._all_operations: Optional[Tuple[ops.Operation, ...]] = None
self._has_measurements: Optional[bool] = None
self._all_measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None
self._are_all_measurements_terminal: Optional[bool] = None
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None

@property
def moments(self) -> Sequence['cirq.Moment']:
return self._moments

def __hash__(self):
return hash((self.moments,))

# Memoized methods for commonly-retrieved properties.

@_compat.cached_method
def _num_qubits_(self) -> int:
if self._num_qubits is None:
self._num_qubits = len(self.all_qubits())
return self._num_qubits
return len(self.all_qubits())

@_compat.cached_method
def _qid_shape_(self) -> Tuple[int, ...]:
if self._qid_shape is None:
self._qid_shape = super()._qid_shape_()
return self._qid_shape
return super()._qid_shape_()

@_compat.cached_method
def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
if self._unitary is None:
self._unitary = super()._unitary_()
return self._unitary
return super()._unitary_()

@_compat.cached_method
def _is_measurement_(self) -> bool:
if self._has_measurements is None:
self._has_measurements = protocols.is_measurement(self.unfreeze())
return self._has_measurements
return protocols.is_measurement(self.unfreeze())

@_compat.cached_method
def all_qubits(self) -> FrozenSet['cirq.Qid']:
if self._all_qubits is None:
self._all_qubits = super().all_qubits()
return self._all_qubits
return super().all_qubits()

@_compat.cached_property
def _all_operations(self) -> Tuple['cirq.Operation', ...]:
return tuple(super().all_operations())

def all_operations(self) -> Iterator['cirq.Operation']:
if self._all_operations is None:
self._all_operations = tuple(super().all_operations())
return iter(self._all_operations)

def has_measurements(self) -> bool:
if self._has_measurements is None:
self._has_measurements = super().has_measurements()
return self._has_measurements
return self._is_measurement_()

@_compat.cached_method
def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
if self._all_measurement_key_objs is None:
self._all_measurement_key_objs = super().all_measurement_key_objs()
return self._all_measurement_key_objs
return super().all_measurement_key_objs()

def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
return self.all_measurement_key_objs()

@_compat.cached_method
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
if self._control_keys is None:
self._control_keys = super()._control_keys_()
return self._control_keys
return super()._control_keys_()

@_compat.cached_method
def are_all_measurements_terminal(self) -> bool:
if self._are_all_measurements_terminal is None:
self._are_all_measurements_terminal = super().are_all_measurements_terminal()
return self._are_all_measurements_terminal

# End of memoized methods.
return super().are_all_measurements_terminal()

@_compat.cached_method
def all_measurement_key_names(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self.all_measurement_key_objs())

Expand Down