Skip to content

Document CIRCUIT_TYPE and hide other typevars/aliases in circuits.py #5229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 8, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

import cirq._version
from cirq import _compat, devices, ops, protocols, qis
from cirq._doc import document
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
from cirq.circuits.circuit_operation import CircuitOperation
from cirq.circuits.insert_strategy import InsertStrategy
Expand All @@ -65,9 +66,30 @@
if TYPE_CHECKING:
import cirq

T_DESIRED_GATE_TYPE = TypeVar('T_DESIRED_GATE_TYPE', bound='ops.Gate')

_TGate = TypeVar('_TGate', bound='cirq.Gate')

CIRCUIT_TYPE = TypeVar('CIRCUIT_TYPE', bound='AbstractCircuit')
INT_TYPE = Union[int, np.integer]
document(
CIRCUIT_TYPE,
"""Type variable for an AbstractCircuit.

This can be used when writing generic functions that operate on circuits.
For example, suppose we define the following function:

def foo(circuit: CIRCUIT_TYPE) -> CIRCUIT_TYPE:
...

This lets the typechecker know that this function takes any kind of circuit
and returns the same type, that is, if passed a `cirq.Circuit` it will return
`cirq.Circuit`, and similarly if passed `cirq.FrozenCircuit` it will return
`cirq.FrozenCircuit`. This is particularly useful for things like the
transformer API, since it can preserve more type information than if we typed
the function as taking and returning `cirq.AbstractCircuit`.
""",
)

_INT_TYPE = Union[int, np.integer]

_DEVICE_DEP_MESSAGE = 'Attaching devices to circuits will no longer be supported.'

Expand Down Expand Up @@ -752,8 +774,8 @@ def findall_operations(
yield index, op

def findall_operations_with_gate_type(
self, gate_type: Type[T_DESIRED_GATE_TYPE]
) -> Iterable[Tuple[int, 'cirq.GateOperation', T_DESIRED_GATE_TYPE]]:
self, gate_type: Type[_TGate]
) -> Iterable[Tuple[int, 'cirq.GateOperation', _TGate]]:
"""Find the locations of all gate operations of a given type.

Args:
Expand All @@ -767,7 +789,7 @@ def findall_operations_with_gate_type(
result = self.findall_operations(lambda operation: isinstance(operation.gate, gate_type))
for index, op in result:
gate_op = cast(ops.GateOperation, op)
yield index, gate_op, cast(T_DESIRED_GATE_TYPE, gate_op.gate)
yield index, gate_op, cast(_TGate, gate_op.gate)

def has_measurements(self):
return protocols.is_measurement(self)
Expand Down Expand Up @@ -1818,20 +1840,20 @@ def __radd__(self, other):
# Needed for numpy to handle multiplication by np.int64 correctly.
__array_priority__ = 10000

def __imul__(self, repetitions: INT_TYPE):
def __imul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
self._moments *= int(repetitions)
return self

def __mul__(self, repetitions: INT_TYPE):
def __mul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
if self._device == cirq.UNCONSTRAINED_DEVICE:
return Circuit(self._moments * int(repetitions))
return Circuit(self._moments * int(repetitions), device=self._device)

def __rmul__(self, repetitions: INT_TYPE):
def __rmul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
return self * int(repetitions)
Expand Down Expand Up @@ -2750,27 +2772,26 @@ def _list_repr_with_indented_item_lines(items: Sequence[Any]) -> str:
return f'[\n{indented}\n]'


TIn = TypeVar('TIn')
TOut = TypeVar('TOut')
TKey = TypeVar('TKey')
_TIn = TypeVar('_TIn')
_TOut = TypeVar('_TOut')
_TKey = TypeVar('_TKey')


@overload
def _group_until_different(
items: Iterable[TIn],
key: Callable[[TIn], TKey],
) -> Iterable[Tuple[TKey, List[TIn]]]:
items: Iterable[_TIn], key: Callable[[_TIn], _TKey]
) -> Iterable[Tuple[_TKey, List[_TIn]]]:
pass


@overload
def _group_until_different(
items: Iterable[TIn], key: Callable[[TIn], TKey], val: Callable[[TIn], TOut]
) -> Iterable[Tuple[TKey, List[TOut]]]:
items: Iterable[_TIn], key: Callable[[_TIn], _TKey], val: Callable[[_TIn], _TOut]
) -> Iterable[Tuple[_TKey, List[_TOut]]]:
pass


def _group_until_different(items: Iterable[TIn], key: Callable[[TIn], TKey], val=lambda e: e):
def _group_until_different(items: Iterable[_TIn], key: Callable[[_TIn], _TKey], val=lambda e: e):
"""Groups runs of items that are identical according to a keying function.

Args:
Expand Down