Skip to content

Commit cc8be65

Browse files
committed
Typehint SequenceLearner
1 parent 914495c commit cc8be65

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

adaptive/learner/sequence_learner.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from copy import copy
4+
from typing import Any, Tuple
45

56
import cloudpickle
67
from sortedcontainers import SortedDict, SortedSet
@@ -16,6 +17,14 @@
1617
except ModuleNotFoundError:
1718
with_pandas = False
1819

20+
try:
21+
from typing import TypeAlias
22+
except ImportError:
23+
from typing_extensions import TypeAlias
24+
25+
26+
PointType: TypeAlias = Tuple[int, Any]
27+
1928

2029
class _IgnoreFirstArgument:
2130
"""Remove the first argument from the call signature.
@@ -30,7 +39,7 @@ class _IgnoreFirstArgument:
3039
def __init__(self, function):
3140
self.function = function
3241

33-
def __call__(self, index_point, *args, **kwargs):
42+
def __call__(self, index_point: PointType, *args, **kwargs):
3443
index, point = index_point
3544
return self.function(point, *args, **kwargs)
3645

@@ -77,7 +86,9 @@ def __init__(self, function, sequence):
7786
self.data = SortedDict()
7887
self.pending_points = set()
7988

80-
def ask(self, n, tell_pending=True):
89+
def ask(
90+
self, n: int, tell_pending: bool = True
91+
) -> tuple[list[PointType], list[float]]:
8192
indices = []
8293
points = []
8394
loss_improvements = []
@@ -95,40 +106,40 @@ def ask(self, n, tell_pending=True):
95106

96107
return points, loss_improvements
97108

98-
def loss(self, real=True):
109+
def loss(self, real: bool = True) -> float:
99110
if not (self._to_do_indices or self.pending_points):
100-
return 0
111+
return 0.0
101112
else:
102113
npoints = self.npoints + (0 if real else len(self.pending_points))
103114
return (self._ntotal - npoints) / self._ntotal
104115

105-
def remove_unfinished(self):
116+
def remove_unfinished(self) -> None:
106117
for i in self.pending_points:
107118
self._to_do_indices.add(i)
108119
self.pending_points = set()
109120

110-
def tell(self, point, value):
121+
def tell(self, point: PointType, value: Any) -> None:
111122
index, point = point
112123
self.data[index] = value
113124
self.pending_points.discard(index)
114125
self._to_do_indices.discard(index)
115126

116-
def tell_pending(self, point):
127+
def tell_pending(self, point: PointType) -> None:
117128
index, point = point
118129
self.pending_points.add(index)
119130
self._to_do_indices.discard(index)
120131

121-
def done(self):
132+
def done(self) -> bool:
122133
return not self._to_do_indices and not self.pending_points
123134

124-
def result(self):
135+
def result(self) -> list[Any]:
125136
"""Get the function values in the same order as ``sequence``."""
126137
if not self.done():
127138
raise Exception("Learner is not yet complete.")
128139
return list(self.data.values())
129140

130141
@property
131-
def npoints(self):
142+
def npoints(self) -> int:
132143
return len(self.data)
133144

134145
def to_dataframe(
@@ -215,10 +226,10 @@ def load_dataframe(
215226
self.function, df, function_prefix
216227
)
217228

218-
def _get_data(self):
229+
def _get_data(self) -> dict[int, Any]:
219230
return self.data
220231

221-
def _set_data(self, data):
232+
def _set_data(self, data: dict[int, Any]) -> None:
222233
if data:
223234
indices, values = zip(*data.items())
224235
# the points aren't used by tell, so we can safely pass None

0 commit comments

Comments
 (0)