Skip to content

Commit c912ea3

Browse files
authored
Document CIRCUIT_TYPE and hide other typevars/aliases in circuits.py (#5229)
Fixes #5150 (assuming this renders nicely on the docsite; how can I check that locally?) This adds an underscore prefix to hide some type aliases and type vars that are not part of the public interface of the module. Also adds a docstring to the `CIRCUIT_TYPE` variable, which is used in a few other places.
1 parent ca581d5 commit c912ea3

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

Diff for: cirq-core/cirq/circuits/circuit.py

+38-17
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252

5353
import cirq._version
5454
from cirq import _compat, devices, ops, protocols, qis
55+
from cirq._doc import document
5556
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
5657
from cirq.circuits.circuit_operation import CircuitOperation
5758
from cirq.circuits.insert_strategy import InsertStrategy
@@ -65,9 +66,30 @@
6566
if TYPE_CHECKING:
6667
import cirq
6768

68-
T_DESIRED_GATE_TYPE = TypeVar('T_DESIRED_GATE_TYPE', bound='ops.Gate')
69+
70+
_TGate = TypeVar('_TGate', bound='cirq.Gate')
71+
6972
CIRCUIT_TYPE = TypeVar('CIRCUIT_TYPE', bound='AbstractCircuit')
70-
INT_TYPE = Union[int, np.integer]
73+
document(
74+
CIRCUIT_TYPE,
75+
"""Type variable for an AbstractCircuit.
76+
77+
This can be used when writing generic functions that operate on circuits.
78+
For example, suppose we define the following function:
79+
80+
def foo(circuit: CIRCUIT_TYPE) -> CIRCUIT_TYPE:
81+
...
82+
83+
This lets the typechecker know that this function takes any kind of circuit
84+
and returns the same type, that is, if passed a `cirq.Circuit` it will return
85+
`cirq.Circuit`, and similarly if passed `cirq.FrozenCircuit` it will return
86+
`cirq.FrozenCircuit`. This is particularly useful for things like the
87+
transformer API, since it can preserve more type information than if we typed
88+
the function as taking and returning `cirq.AbstractCircuit`.
89+
""",
90+
)
91+
92+
_INT_TYPE = Union[int, np.integer]
7193

7294
_DEVICE_DEP_MESSAGE = 'Attaching devices to circuits will no longer be supported.'
7395

@@ -752,8 +774,8 @@ def findall_operations(
752774
yield index, op
753775

754776
def findall_operations_with_gate_type(
755-
self, gate_type: Type[T_DESIRED_GATE_TYPE]
756-
) -> Iterable[Tuple[int, 'cirq.GateOperation', T_DESIRED_GATE_TYPE]]:
777+
self, gate_type: Type[_TGate]
778+
) -> Iterable[Tuple[int, 'cirq.GateOperation', _TGate]]:
757779
"""Find the locations of all gate operations of a given type.
758780
759781
Args:
@@ -767,7 +789,7 @@ def findall_operations_with_gate_type(
767789
result = self.findall_operations(lambda operation: isinstance(operation.gate, gate_type))
768790
for index, op in result:
769791
gate_op = cast(ops.GateOperation, op)
770-
yield index, gate_op, cast(T_DESIRED_GATE_TYPE, gate_op.gate)
792+
yield index, gate_op, cast(_TGate, gate_op.gate)
771793

772794
def has_measurements(self):
773795
return protocols.is_measurement(self)
@@ -1818,20 +1840,20 @@ def __radd__(self, other):
18181840
# Needed for numpy to handle multiplication by np.int64 correctly.
18191841
__array_priority__ = 10000
18201842

1821-
def __imul__(self, repetitions: INT_TYPE):
1843+
def __imul__(self, repetitions: _INT_TYPE):
18221844
if not isinstance(repetitions, (int, np.integer)):
18231845
return NotImplemented
18241846
self._moments *= int(repetitions)
18251847
return self
18261848

1827-
def __mul__(self, repetitions: INT_TYPE):
1849+
def __mul__(self, repetitions: _INT_TYPE):
18281850
if not isinstance(repetitions, (int, np.integer)):
18291851
return NotImplemented
18301852
if self._device == cirq.UNCONSTRAINED_DEVICE:
18311853
return Circuit(self._moments * int(repetitions))
18321854
return Circuit(self._moments * int(repetitions), device=self._device)
18331855

1834-
def __rmul__(self, repetitions: INT_TYPE):
1856+
def __rmul__(self, repetitions: _INT_TYPE):
18351857
if not isinstance(repetitions, (int, np.integer)):
18361858
return NotImplemented
18371859
return self * int(repetitions)
@@ -2750,27 +2772,26 @@ def _list_repr_with_indented_item_lines(items: Sequence[Any]) -> str:
27502772
return f'[\n{indented}\n]'
27512773

27522774

2753-
TIn = TypeVar('TIn')
2754-
TOut = TypeVar('TOut')
2755-
TKey = TypeVar('TKey')
2775+
_TIn = TypeVar('_TIn')
2776+
_TOut = TypeVar('_TOut')
2777+
_TKey = TypeVar('_TKey')
27562778

27572779

27582780
@overload
27592781
def _group_until_different(
2760-
items: Iterable[TIn],
2761-
key: Callable[[TIn], TKey],
2762-
) -> Iterable[Tuple[TKey, List[TIn]]]:
2782+
items: Iterable[_TIn], key: Callable[[_TIn], _TKey]
2783+
) -> Iterable[Tuple[_TKey, List[_TIn]]]:
27632784
pass
27642785

27652786

27662787
@overload
27672788
def _group_until_different(
2768-
items: Iterable[TIn], key: Callable[[TIn], TKey], val: Callable[[TIn], TOut]
2769-
) -> Iterable[Tuple[TKey, List[TOut]]]:
2789+
items: Iterable[_TIn], key: Callable[[_TIn], _TKey], val: Callable[[_TIn], _TOut]
2790+
) -> Iterable[Tuple[_TKey, List[_TOut]]]:
27702791
pass
27712792

27722793

2773-
def _group_until_different(items: Iterable[TIn], key: Callable[[TIn], TKey], val=lambda e: e):
2794+
def _group_until_different(items: Iterable[_TIn], key: Callable[[_TIn], _TKey], val=lambda e: e):
27742795
"""Groups runs of items that are identical according to a keying function.
27752796
27762797
Args:

0 commit comments

Comments
 (0)