Skip to content

Commit 2e6549b

Browse files
committed
Use PEP-673 Self type annotations
1 parent e7ef9d4 commit 2e6549b

25 files changed

+125
-164
lines changed

Diff for: cirq-core/cirq/circuits/circuit.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
TypeVar,
4848
Union,
4949
)
50+
from typing_extensions import Self
5051

5152
import networkx
5253
import numpy as np
@@ -236,15 +237,15 @@ def __getitem__(self, key: Tuple[int, Iterable['cirq.Qid']]) -> 'cirq.Moment':
236237
pass
237238

238239
@overload
239-
def __getitem__(self: CIRCUIT_TYPE, key: slice) -> CIRCUIT_TYPE:
240+
def __getitem__(self, key: slice) -> Self:
240241
pass
241242

242243
@overload
243-
def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, 'cirq.Qid']) -> CIRCUIT_TYPE:
244+
def __getitem__(self, key: Tuple[slice, 'cirq.Qid']) -> Self:
244245
pass
245246

246247
@overload
247-
def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, Iterable['cirq.Qid']]) -> CIRCUIT_TYPE:
248+
def __getitem__(self, key: Tuple[slice, Iterable['cirq.Qid']]) -> Self:
248249
pass
249250

250251
def __getitem__(self, key):
@@ -913,9 +914,7 @@ def all_operations(self) -> Iterator['cirq.Operation']:
913914
"""
914915
return (op for moment in self for op in moment.operations)
915916

916-
def map_operations(
917-
self: CIRCUIT_TYPE, func: Callable[['cirq.Operation'], 'cirq.OP_TREE']
918-
) -> CIRCUIT_TYPE:
917+
def map_operations(self, func: Callable[['cirq.Operation'], 'cirq.OP_TREE']) -> Self:
919918
"""Applies the given function to all operations in this circuit.
920919
921920
Args:
@@ -1287,9 +1286,7 @@ def _is_parameterized_(self) -> bool:
12871286
def _parameter_names_(self) -> AbstractSet[str]:
12881287
return {name for op in self.all_operations() for name in protocols.parameter_names(op)}
12891288

1290-
def _resolve_parameters_(
1291-
self: CIRCUIT_TYPE, resolver: 'cirq.ParamResolver', recursive: bool
1292-
) -> CIRCUIT_TYPE:
1289+
def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> Self:
12931290
changed = False
12941291
resolved_moments: List['cirq.Moment'] = []
12951292
for moment in self:
@@ -1540,7 +1537,7 @@ def get_independent_qubit_sets(self) -> List[Set['cirq.Qid']]:
15401537
uf.union(*op.qubits)
15411538
return sorted([qs for qs in uf.to_sets()], key=min)
15421539

1543-
def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]:
1540+
def factorize(self) -> Iterable[Self]:
15441541
"""Factorize circuit into a sequence of independent circuits (factors).
15451542
15461543
Factorization is possible when the circuit's qubits can be divided

Diff for: cirq-core/cirq/circuits/moment.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
Sequence,
3131
Tuple,
3232
TYPE_CHECKING,
33-
TypeVar,
3433
Union,
3534
)
35+
from typing_extensions import Self
3636

3737
import numpy as np
3838

@@ -52,8 +52,6 @@
5252
"text_diagram_drawer", globals(), "cirq.circuits.text_diagram_drawer"
5353
)
5454

55-
TSelf_Moment = TypeVar('TSelf_Moment', bound='Moment')
56-
5755

5856
def _default_breakdown(qid: 'cirq.Qid') -> Tuple[Any, Any]:
5957
# Attempt to convert into a position on the complex plane.
@@ -373,9 +371,8 @@ def _decompose_(self) -> 'cirq.OP_TREE':
373371
return self._operations
374372

375373
def transform_qubits(
376-
self: TSelf_Moment,
377-
qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],
378-
) -> TSelf_Moment:
374+
self, qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']]
375+
) -> Self:
379376
"""Returns the same moment, but with different qubits.
380377
381378
Args:

Diff for: cirq-core/cirq/devices/grid_qubit.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import functools
16-
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TypeVar, TYPE_CHECKING, Union
17-
1815
import abc
16+
import functools
17+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TYPE_CHECKING, Union
18+
from typing_extensions import Self
1919

2020
import numpy as np
2121

@@ -24,8 +24,6 @@
2424
if TYPE_CHECKING:
2525
import cirq
2626

27-
TSelf = TypeVar('TSelf', bound='_BaseGridQid')
28-
2927

3028
@functools.total_ordering
3129
class _BaseGridQid(ops.Qid):
@@ -69,13 +67,13 @@ def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseGridQ
6967
return neighbors
7068

7169
@abc.abstractmethod
72-
def _with_row_col(self: TSelf, row: int, col: int) -> TSelf:
70+
def _with_row_col(self, row: int, col: int) -> Self:
7371
"""Returns a qid with the same type but a different coordinate."""
7472

7573
def __complex__(self) -> complex:
7674
return self.col + 1j * self.row
7775

78-
def __add__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf':
76+
def __add__(self, other: Union[Tuple[int, int], Self]) -> Self:
7977
if isinstance(other, _BaseGridQid):
8078
if self.dimension != other.dimension:
8179
raise TypeError(
@@ -94,7 +92,7 @@ def __add__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf':
9492
)
9593
return self._with_row_col(row=self.row + other[0], col=self.col + other[1])
9694

97-
def __sub__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf':
95+
def __sub__(self, other: Union[Tuple[int, int], Self]) -> Self:
9896
if isinstance(other, _BaseGridQid):
9997
if self.dimension != other.dimension:
10098
raise TypeError(
@@ -113,13 +111,13 @@ def __sub__(self: TSelf, other: Union[Tuple[int, int], TSelf]) -> 'TSelf':
113111
)
114112
return self._with_row_col(row=self.row - other[0], col=self.col - other[1])
115113

116-
def __radd__(self: TSelf, other: Tuple[int, int]) -> 'TSelf':
114+
def __radd__(self, other: Tuple[int, int]) -> Self:
117115
return self + other
118116

119-
def __rsub__(self: TSelf, other: Tuple[int, int]) -> 'TSelf':
117+
def __rsub__(self, other: Tuple[int, int]) -> Self:
120118
return -self + other
121119

122-
def __neg__(self: TSelf) -> 'TSelf':
120+
def __neg__(self) -> Self:
123121
return self._with_row_col(row=-self.row, col=-self.col)
124122

125123

Diff for: cirq-core/cirq/devices/line_qubit.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import functools
16-
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TypeVar, TYPE_CHECKING, Union
17-
1815
import abc
16+
import functools
17+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TYPE_CHECKING, Union
18+
from typing_extensions import Self
1919

2020
from cirq import ops, protocols
2121

2222
if TYPE_CHECKING:
2323
import cirq
2424

25-
TSelf = TypeVar('TSelf', bound='_BaseLineQid')
26-
2725

2826
@functools.total_ordering
2927
class _BaseLineQid(ops.Qid):
@@ -66,10 +64,10 @@ def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseLineQ
6664
return neighbors
6765

6866
@abc.abstractmethod
69-
def _with_x(self: TSelf, x: int) -> TSelf:
67+
def _with_x(self, x: int) -> Self:
7068
"""Returns a qubit with the same type but a different value of `x`."""
7169

72-
def __add__(self: TSelf, other: Union[int, TSelf]) -> TSelf:
70+
def __add__(self, other: Union[int, Self]) -> Self:
7371
if isinstance(other, _BaseLineQid):
7472
if self.dimension != other.dimension:
7573
raise TypeError(
@@ -81,7 +79,7 @@ def __add__(self: TSelf, other: Union[int, TSelf]) -> TSelf:
8179
raise TypeError(f"Can only add ints and {type(self).__name__}. Instead was {other}")
8280
return self._with_x(self.x + other)
8381

84-
def __sub__(self: TSelf, other: Union[int, TSelf]) -> TSelf:
82+
def __sub__(self, other: Union[int, Self]) -> Self:
8583
if isinstance(other, _BaseLineQid):
8684
if self.dimension != other.dimension:
8785
raise TypeError(
@@ -95,13 +93,13 @@ def __sub__(self: TSelf, other: Union[int, TSelf]) -> TSelf:
9593
)
9694
return self._with_x(self.x - other)
9795

98-
def __radd__(self: TSelf, other: int) -> TSelf:
96+
def __radd__(self, other: int) -> Self:
9997
return self + other
10098

101-
def __rsub__(self: TSelf, other: int) -> TSelf:
99+
def __rsub__(self, other: int) -> Self:
102100
return -self + other
103101

104-
def __neg__(self: TSelf) -> TSelf:
102+
def __neg__(self) -> Self:
105103
return self._with_x(-self.x)
106104

107105
def __complex__(self) -> complex:

Diff for: cirq-core/cirq/ops/arithmetic_operation.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
import abc
1717
import itertools
18-
from typing import Union, Iterable, List, Sequence, cast, Tuple, TypeVar, TYPE_CHECKING
18+
from typing import Union, Iterable, List, Sequence, cast, Tuple, TYPE_CHECKING
19+
from typing_extensions import Self
1920

2021
import numpy as np
2122

@@ -25,9 +26,6 @@
2526
import cirq
2627

2728

28-
TSelfGate = TypeVar('TSelfGate', bound='ArithmeticGate')
29-
30-
3129
class ArithmeticGate(Gate, metaclass=abc.ABCMeta):
3230
r"""A helper gate for implementing reversible classical arithmetic.
3331
@@ -55,7 +53,7 @@ class ArithmeticGate(Gate, metaclass=abc.ABCMeta):
5553
...
5654
... def with_registers(
5755
... self, *new_registers: 'Union[int, Sequence[int]]'
58-
... ) -> 'TSelfGate':
56+
... ) -> 'Add':
5957
... return Add(*new_registers)
6058
...
6159
... def apply(self, *register_values: int) -> 'Union[int, Iterable[int]]':
@@ -105,7 +103,7 @@ def registers(self) -> Sequence[Union[int, Sequence[int]]]:
105103
raise NotImplementedError()
106104

107105
@abc.abstractmethod
108-
def with_registers(self: TSelfGate, *new_registers: Union[int, Sequence[int]]) -> TSelfGate:
106+
def with_registers(self, *new_registers: Union[int, Sequence[int]]) -> Self:
109107
"""Returns the same fate targeting different registers.
110108
111109
Args:

Diff for: cirq-core/cirq/ops/common_gates.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
TYPE_CHECKING,
3737
Union,
3838
)
39+
from typing_extensions import Self
3940

4041
import numpy as np
4142
import sympy
@@ -357,7 +358,7 @@ def __init__(self, *, rads: value.TParamVal):
357358
self._rads = rads
358359
super().__init__(exponent=rads / _pi(rads), global_shift=-0.5)
359360

360-
def _with_exponent(self: 'Rx', exponent: value.TParamVal) -> 'Rx':
361+
def _with_exponent(self, exponent: value.TParamVal) -> 'Rx':
361362
return Rx(rads=exponent * _pi(exponent))
362363

363364
def _circuit_diagram_info_(
@@ -541,7 +542,7 @@ def __init__(self, *, rads: value.TParamVal):
541542
self._rads = rads
542543
super().__init__(exponent=rads / _pi(rads), global_shift=-0.5)
543544

544-
def _with_exponent(self: 'Ry', exponent: value.TParamVal) -> 'Ry':
545+
def _with_exponent(self, exponent: value.TParamVal) -> 'Ry':
545546
return Ry(rads=exponent * _pi(exponent))
546547

547548
def _circuit_diagram_info_(
@@ -891,7 +892,7 @@ def __init__(self, *, rads: value.TParamVal):
891892
self._rads = rads
892893
super().__init__(exponent=rads / _pi(rads), global_shift=-0.5)
893894

894-
def _with_exponent(self: 'Rz', exponent: value.TParamVal) -> 'Rz':
895+
def _with_exponent(self, exponent: value.TParamVal) -> 'Rz':
895896
return Rz(rads=exponent * _pi(exponent))
896897

897898
def _circuit_diagram_info_(

0 commit comments

Comments
 (0)