1
1
from __future__ import annotations
2
2
3
3
from copy import copy
4
+ from typing import Any , Tuple
4
5
5
6
import cloudpickle
6
7
from sortedcontainers import SortedDict , SortedSet
16
17
except ModuleNotFoundError :
17
18
with_pandas = False
18
19
20
+ try :
21
+ from typing import TypeAlias
22
+ except ImportError :
23
+ from typing_extensions import TypeAlias
24
+
25
+
26
+ PointType : TypeAlias = Tuple [int , Any ]
27
+
19
28
20
29
class _IgnoreFirstArgument :
21
30
"""Remove the first argument from the call signature.
@@ -30,7 +39,7 @@ class _IgnoreFirstArgument:
30
39
def __init__ (self , function ):
31
40
self .function = function
32
41
33
- def __call__ (self , index_point , * args , ** kwargs ):
42
+ def __call__ (self , index_point : PointType , * args , ** kwargs ):
34
43
index , point = index_point
35
44
return self .function (point , * args , ** kwargs )
36
45
@@ -77,7 +86,9 @@ def __init__(self, function, sequence):
77
86
self .data = SortedDict ()
78
87
self .pending_points = set ()
79
88
80
- def ask (self , n , tell_pending = True ):
89
+ def ask (
90
+ self , n : int , tell_pending : bool = True
91
+ ) -> tuple [list [PointType ], list [float ]]:
81
92
indices = []
82
93
points = []
83
94
loss_improvements = []
@@ -95,40 +106,40 @@ def ask(self, n, tell_pending=True):
95
106
96
107
return points , loss_improvements
97
108
98
- def loss (self , real = True ):
109
+ def loss (self , real : bool = True ) -> float :
99
110
if not (self ._to_do_indices or self .pending_points ):
100
- return 0
111
+ return 0.0
101
112
else :
102
113
npoints = self .npoints + (0 if real else len (self .pending_points ))
103
114
return (self ._ntotal - npoints ) / self ._ntotal
104
115
105
- def remove_unfinished (self ):
116
+ def remove_unfinished (self ) -> None :
106
117
for i in self .pending_points :
107
118
self ._to_do_indices .add (i )
108
119
self .pending_points = set ()
109
120
110
- def tell (self , point , value ) :
121
+ def tell (self , point : PointType , value : Any ) -> None :
111
122
index , point = point
112
123
self .data [index ] = value
113
124
self .pending_points .discard (index )
114
125
self ._to_do_indices .discard (index )
115
126
116
- def tell_pending (self , point ) :
127
+ def tell_pending (self , point : PointType ) -> None :
117
128
index , point = point
118
129
self .pending_points .add (index )
119
130
self ._to_do_indices .discard (index )
120
131
121
- def done (self ):
132
+ def done (self ) -> bool :
122
133
return not self ._to_do_indices and not self .pending_points
123
134
124
- def result (self ):
135
+ def result (self ) -> list [ Any ] :
125
136
"""Get the function values in the same order as ``sequence``."""
126
137
if not self .done ():
127
138
raise Exception ("Learner is not yet complete." )
128
139
return list (self .data .values ())
129
140
130
141
@property
131
- def npoints (self ):
142
+ def npoints (self ) -> int :
132
143
return len (self .data )
133
144
134
145
def to_dataframe (
@@ -215,10 +226,10 @@ def load_dataframe(
215
226
self .function , df , function_prefix
216
227
)
217
228
218
- def _get_data (self ):
229
+ def _get_data (self ) -> dict [ int , Any ] :
219
230
return self .data
220
231
221
- def _set_data (self , data ) :
232
+ def _set_data (self , data : dict [ int , Any ]) -> None :
222
233
if data :
223
234
indices , values = zip (* data .items ())
224
235
# the points aren't used by tell, so we can safely pass None
0 commit comments