Skip to content

Commit 9aede58

Browse files
authored
Sanitize type annotations in cirq.circuits (#4776)
Follow-up to the cirq.sim one. Note this reduces things that cause recursive dependencies as well, so watch out for this in code reviews.
1 parent ff671ae commit 9aede58

File tree

7 files changed

+71
-65
lines changed

7 files changed

+71
-65
lines changed

cirq-core/cirq/circuits/circuit.py

+20-18
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def moments(self) -> Sequence['cirq.Moment']:
122122

123123
@property
124124
@abc.abstractmethod
125-
def device(self) -> devices.Device:
125+
def device(self) -> 'cirq.Device':
126126
pass
127127

128128
def freeze(self) -> 'cirq.FrozenCircuit':
@@ -589,7 +589,7 @@ def findall_operations_until_blocked(
589589
self,
590590
start_frontier: Dict['cirq.Qid', int],
591591
is_blocker: Callable[['cirq.Operation'], bool] = lambda op: False,
592-
) -> List[Tuple[int, ops.Operation]]:
592+
) -> List[Tuple[int, 'cirq.Operation']]:
593593
"""Finds all operations until a blocking operation is hit.
594594
595595
An operation is considered blocking if
@@ -740,7 +740,7 @@ def findall_operations(
740740

741741
def findall_operations_with_gate_type(
742742
self, gate_type: Type[T_DESIRED_GATE_TYPE]
743-
) -> Iterable[Tuple[int, ops.GateOperation, T_DESIRED_GATE_TYPE]]:
743+
) -> Iterable[Tuple[int, 'cirq.GateOperation', T_DESIRED_GATE_TYPE]]:
744744
"""Find the locations of all gate operations of a given type.
745745
746746
Args:
@@ -852,7 +852,7 @@ def all_qubits(self) -> FrozenSet['cirq.Qid']:
852852
"""Returns the qubits acted upon by Operations in this circuit."""
853853
return frozenset(q for m in self.moments for q in m.qubits)
854854

855-
def all_operations(self) -> Iterator[ops.Operation]:
855+
def all_operations(self) -> Iterator['cirq.Operation']:
856856
"""Iterates over the operations applied by this circuit.
857857
858858
Operations from earlier moments will be iterated over first. Operations
@@ -1162,7 +1162,7 @@ def to_text_diagram_drawer(
11621162
get_circuit_diagram_info: Optional[
11631163
Callable[['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'], 'cirq.CircuitDiagramInfo']
11641164
] = None,
1165-
) -> TextDiagramDrawer:
1165+
) -> 'cirq.TextDiagramDrawer':
11661166
"""Returns a TextDiagramDrawer with the circuit drawn into it.
11671167
11681168
Args:
@@ -1250,7 +1250,7 @@ def _to_qasm_output(
12501250
header: Optional[str] = None,
12511251
precision: int = 10,
12521252
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT,
1253-
) -> QasmOutput:
1253+
) -> 'cirq.QasmOutput':
12541254
"""Returns a QASM object equivalent to the circuit.
12551255
12561256
Args:
@@ -1273,7 +1273,7 @@ def _to_qasm_output(
12731273

12741274
def _to_quil_output(
12751275
self, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT
1276-
) -> QuilOutput:
1276+
) -> 'cirq.QuilOutput':
12771277
qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
12781278
return QuilOutput(operations=self.all_operations(), qubits=qubits)
12791279

@@ -1697,23 +1697,23 @@ def __init__(
16971697
self.append(contents, strategy=strategy)
16981698

16991699
@property
1700-
def device(self) -> devices.Device:
1700+
def device(self) -> 'cirq.Device':
17011701
return self._device
17021702

17031703
@device.setter
17041704
def device(self, new_device: 'cirq.Device') -> None:
17051705
new_device.validate_circuit(self)
17061706
self._device = new_device
17071707

1708-
def __copy__(self) -> 'Circuit':
1708+
def __copy__(self) -> 'cirq.Circuit':
17091709
return self.copy()
17101710

1711-
def copy(self) -> 'Circuit':
1711+
def copy(self) -> 'cirq.Circuit':
17121712
copied_circuit = Circuit(device=self._device)
17131713
copied_circuit._moments = self._moments[:]
17141714
return copied_circuit
17151715

1716-
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'Circuit':
1716+
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'cirq.Circuit':
17171717
new_circuit = Circuit(device=self.device)
17181718
new_circuit._moments = list(moments)
17191719
return new_circuit
@@ -1793,7 +1793,7 @@ def __rmul__(self, repetitions: INT_TYPE):
17931793
return NotImplemented
17941794
return self * int(repetitions)
17951795

1796-
def __pow__(self, exponent: int) -> 'Circuit':
1796+
def __pow__(self, exponent: int) -> 'cirq.Circuit':
17971797
"""A circuit raised to a power, only valid for exponent -1, the inverse.
17981798
17991799
This will fail if anything other than -1 is passed to the Circuit by
@@ -1819,7 +1819,7 @@ def with_device(
18191819
self,
18201820
new_device: 'cirq.Device',
18211821
qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e,
1822-
) -> 'Circuit':
1822+
) -> 'cirq.Circuit':
18231823
"""Maps the current circuit onto a new device, and validates.
18241824
18251825
Args:
@@ -2296,7 +2296,9 @@ def clear_operations_touching(
22962296
if 0 <= k < len(self._moments):
22972297
self._moments[k] = self._moments[k].without_operations_touching(qubits)
22982298

2299-
def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Circuit':
2299+
def _resolve_parameters_(
2300+
self, resolver: 'cirq.ParamResolver', recursive: bool
2301+
) -> 'cirq.Circuit':
23002302
resolved_moments = []
23012303
for moment in self:
23022304
resolved_operations = _resolve_operations(moment.operations, resolver, recursive)
@@ -2391,7 +2393,7 @@ def _draw_moment_annotations(
23912393
col: int,
23922394
use_unicode_characters: bool,
23932395
label_map: Dict['cirq.LabelEntity', int],
2394-
out_diagram: TextDiagramDrawer,
2396+
out_diagram: 'cirq.TextDiagramDrawer',
23952397
precision: Optional[int],
23962398
get_circuit_diagram_info: Callable[
23972399
['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'], 'cirq.CircuitDiagramInfo'
@@ -2421,7 +2423,7 @@ def _draw_moment_in_diagram(
24212423
moment: 'cirq.Moment',
24222424
use_unicode_characters: bool,
24232425
label_map: Dict['cirq.LabelEntity', int],
2424-
out_diagram: TextDiagramDrawer,
2426+
out_diagram: 'cirq.TextDiagramDrawer',
24252427
precision: Optional[int],
24262428
moment_groups: List[Tuple[int, int]],
24272429
get_circuit_diagram_info: Optional[
@@ -2542,7 +2544,7 @@ def _formatted_phase(coefficient: complex, unicode: bool, precision: Optional[in
25422544
def _draw_moment_groups_in_diagram(
25432545
moment_groups: List[Tuple[int, int]],
25442546
use_unicode_characters: bool,
2545-
out_diagram: TextDiagramDrawer,
2547+
out_diagram: 'cirq.TextDiagramDrawer',
25462548
):
25472549
out_diagram.insert_empty_rows(0)
25482550
h = out_diagram.height()
@@ -2572,7 +2574,7 @@ def _draw_moment_groups_in_diagram(
25722574

25732575

25742576
def _apply_unitary_circuit(
2575-
circuit: AbstractCircuit,
2577+
circuit: 'cirq.AbstractCircuit',
25762578
state: np.ndarray,
25772579
qubits: Tuple['cirq.Qid', ...],
25782580
dtype: Type[np.number],

cirq-core/cirq/circuits/circuit_dag.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
self,
7575
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
7676
incoming_graph_data: Any = None,
77-
device: devices.Device = devices.UNCONSTRAINED_DEVICE,
77+
device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE,
7878
) -> None:
7979
"""Initializes a CircuitDag.
8080
@@ -100,7 +100,7 @@ def make_node(op: 'cirq.Operation') -> Unique:
100100

101101
@staticmethod
102102
def from_circuit(
103-
circuit: circuit.Circuit,
103+
circuit: 'cirq.Circuit',
104104
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
105105
) -> 'CircuitDag':
106106
return CircuitDag.from_ops(
@@ -111,7 +111,7 @@ def from_circuit(
111111
def from_ops(
112112
*operations: 'cirq.OP_TREE',
113113
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
114-
device: devices.Device = devices.UNCONSTRAINED_DEVICE,
114+
device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE,
115115
) -> 'CircuitDag':
116116
dag = CircuitDag(can_reorder=can_reorder, device=device)
117117
for op in ops.flatten_op_tree(operations):
@@ -147,21 +147,21 @@ def __ne__(self, other):
147147

148148
__hash__ = None # type: ignore
149149

150-
def ordered_nodes(self) -> Iterator[Unique[ops.Operation]]:
150+
def ordered_nodes(self) -> Iterator[Unique['cirq.Operation']]:
151151
if not self.nodes():
152152
return
153153
g = self.copy()
154154

155-
def get_root_node(some_node: Unique[ops.Operation]) -> Unique[ops.Operation]:
155+
def get_root_node(some_node: Unique['cirq.Operation']) -> Unique['cirq.Operation']:
156156
pred = g.pred
157157
while pred[some_node]:
158158
some_node = next(iter(pred[some_node]))
159159
return some_node
160160

161-
def get_first_node() -> Unique[ops.Operation]:
161+
def get_first_node() -> Unique['cirq.Operation']:
162162
return get_root_node(next(iter(g.nodes())))
163163

164-
def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique[ops.Operation]:
164+
def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique['cirq.Operation']:
165165
if succ:
166166
return get_root_node(next(iter(succ)))
167167

@@ -178,20 +178,20 @@ def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique[ops.Oper
178178

179179
node = get_next_node(succ)
180180

181-
def all_operations(self) -> Iterator[ops.Operation]:
181+
def all_operations(self) -> Iterator['cirq.Operation']:
182182
return (node.val for node in self.ordered_nodes())
183183

184184
def all_qubits(self):
185185
return frozenset(q for node in self.nodes for q in node.val.qubits)
186186

187-
def to_circuit(self) -> circuit.Circuit:
187+
def to_circuit(self) -> 'cirq.Circuit':
188188
return circuit.Circuit(
189189
self.all_operations(), strategy=circuit.InsertStrategy.EARLIEST, device=self.device
190190
)
191191

192192
def findall_nodes_until_blocked(
193-
self, is_blocker: Callable[[ops.Operation], bool]
194-
) -> Iterator[Unique[ops.Operation]]:
193+
self, is_blocker: Callable[['cirq.Operation'], bool]
194+
) -> Iterator[Unique['cirq.Operation']]:
195195
"""Finds all nodes before blocking ones.
196196
197197
Args:

cirq-core/cirq/circuits/circuit_operation.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,14 @@ def __post_init__(self):
143143
# Ensure that param_resolver is converted to an actual ParamResolver.
144144
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))
145145

146-
def base_operation(self) -> 'CircuitOperation':
146+
def base_operation(self) -> 'cirq.CircuitOperation':
147147
"""Returns a copy of this operation with only the wrapped circuit.
148148
149149
Key and qubit mappings, parameter values, and repetitions are not copied.
150150
"""
151151
return CircuitOperation(self.circuit)
152152

153-
def replace(self, **changes) -> 'CircuitOperation':
153+
def replace(self, **changes) -> 'cirq.CircuitOperation':
154154
"""Returns a copy of this operation with the specified changes."""
155155
return dataclasses.replace(self, **changes)
156156

@@ -435,7 +435,7 @@ def repeat(
435435

436436
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)
437437

438-
def __pow__(self, power: int) -> 'CircuitOperation':
438+
def __pow__(self, power: int) -> 'cirq.CircuitOperation':
439439
return self.repeat(power)
440440

441441
def _with_key_path_(self, path: Tuple[str, ...]):
@@ -462,13 +462,13 @@ def _with_rescoped_keys_(
462462
def with_key_path(self, path: Tuple[str, ...]):
463463
return self._with_key_path_(path)
464464

465-
def with_repetition_ids(self, repetition_ids: List[str]) -> 'CircuitOperation':
465+
def with_repetition_ids(self, repetition_ids: List[str]) -> 'cirq.CircuitOperation':
466466
return self.replace(repetition_ids=repetition_ids)
467467

468468
def with_qubit_mapping(
469469
self,
470470
qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],
471-
) -> 'CircuitOperation':
471+
) -> 'cirq.CircuitOperation':
472472
"""Returns a copy of this operation with an updated qubit mapping.
473473
474474
Users should pass either 'qubit_map' or 'transform' to this method.
@@ -509,7 +509,7 @@ def with_qubit_mapping(
509509
)
510510
return new_op
511511

512-
def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'CircuitOperation':
512+
def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.CircuitOperation':
513513
"""Returns a copy of this operation with an updated qubit mapping.
514514
515515
Args:
@@ -529,7 +529,7 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'CircuitOperation':
529529
raise ValueError(f'Expected {expected} qubits, got {len(new_qubits)}.')
530530
return self.with_qubit_mapping(dict(zip(self.qubits, new_qubits)))
531531

532-
def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOperation':
532+
def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation':
533533
"""Returns a copy of this operation with an updated key mapping.
534534
535535
Args:
@@ -563,10 +563,12 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
563563
)
564564
return new_op
565565

566-
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'CircuitOperation':
566+
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation':
567567
return self.with_measurement_key_mapping(key_map)
568568

569-
def with_params(self, param_values: study.ParamResolverOrSimilarType) -> 'CircuitOperation':
569+
def with_params(
570+
self, param_values: 'cirq.ParamResolverOrSimilarType'
571+
) -> 'cirq.CircuitOperation':
570572
"""Returns a copy of this operation with an updated ParamResolver.
571573
572574
Note that any resulting parameter mappings with no corresponding
@@ -592,7 +594,7 @@ def with_params(self, param_values: study.ParamResolverOrSimilarType) -> 'Circui
592594
# TODO: handle recursive parameter resolution gracefully
593595
def _resolve_parameters_(
594596
self, resolver: 'cirq.ParamResolver', recursive: bool
595-
) -> 'CircuitOperation':
597+
) -> 'cirq.CircuitOperation':
596598
if recursive:
597599
raise ValueError(
598600
'Recursive resolution of CircuitOperation parameters is prohibited. '

cirq-core/cirq/circuits/frozen_circuit.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def moments(self) -> Sequence['cirq.Moment']:
8383
return self._moments
8484

8585
@property
86-
def device(self) -> devices.Device:
86+
def device(self) -> 'cirq.Device':
8787
return self._device
8888

8989
def __hash__(self):
@@ -116,7 +116,7 @@ def all_qubits(self) -> FrozenSet['cirq.Qid']:
116116
self._all_qubits = super().all_qubits()
117117
return self._all_qubits
118118

119-
def all_operations(self) -> Iterator[ops.Operation]:
119+
def all_operations(self) -> Iterator['cirq.Operation']:
120120
if self._all_operations is None:
121121
self._all_operations = tuple(super().all_operations())
122122
return iter(self._all_operations)
@@ -152,29 +152,29 @@ def all_measurement_key_names(self) -> AbstractSet[str]:
152152
def _measurement_key_names_(self) -> AbstractSet[str]:
153153
return self.all_measurement_key_names()
154154

155-
def __add__(self, other) -> 'FrozenCircuit':
155+
def __add__(self, other) -> 'cirq.FrozenCircuit':
156156
return (self.unfreeze() + other).freeze()
157157

158-
def __radd__(self, other) -> 'FrozenCircuit':
158+
def __radd__(self, other) -> 'cirq.FrozenCircuit':
159159
return (other + self.unfreeze()).freeze()
160160

161161
# Needed for numpy to handle multiplication by np.int64 correctly.
162162
__array_priority__ = 10000
163163

164164
# TODO: handle multiplication / powers differently?
165-
def __mul__(self, other) -> 'FrozenCircuit':
165+
def __mul__(self, other) -> 'cirq.FrozenCircuit':
166166
return (self.unfreeze() * other).freeze()
167167

168-
def __rmul__(self, other) -> 'FrozenCircuit':
168+
def __rmul__(self, other) -> 'cirq.FrozenCircuit':
169169
return (other * self.unfreeze()).freeze()
170170

171-
def __pow__(self, other) -> 'FrozenCircuit':
171+
def __pow__(self, other) -> 'cirq.FrozenCircuit':
172172
try:
173173
return (self.unfreeze() ** other).freeze()
174174
except:
175175
return NotImplemented
176176

177-
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
177+
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'cirq.FrozenCircuit':
178178
new_circuit = FrozenCircuit(device=self.device)
179179
new_circuit._moments = tuple(moments)
180180
return new_circuit
@@ -183,12 +183,12 @@ def with_device(
183183
self,
184184
new_device: 'cirq.Device',
185185
qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e,
186-
) -> 'FrozenCircuit':
186+
) -> 'cirq.FrozenCircuit':
187187
return self.unfreeze().with_device(new_device, qubit_mapping).freeze()
188188

189189
def _resolve_parameters_(
190190
self, resolver: 'cirq.ParamResolver', recursive: bool
191-
) -> 'FrozenCircuit':
191+
) -> 'cirq.FrozenCircuit':
192192
return self.unfreeze()._resolve_parameters_(resolver, recursive).freeze()
193193

194194
def tetris_concat(

0 commit comments

Comments
 (0)