Skip to content

Commit 682a1f0

Browse files
authored
Merge pull request #378 from python-adaptive/type-hints
Add type-hints to tests and misc
2 parents f32346a + ea30a49 commit 682a1f0

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

adaptive/notebook_integration.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import datetime
35
import importlib
@@ -76,7 +78,7 @@ def ensure_plotly():
7678
raise RuntimeError("plotly is not installed; plotting is disabled.")
7779

7880

79-
def in_ipynb():
81+
def in_ipynb() -> bool:
8082
try:
8183
# If we are running in IPython, then `get_ipython()` is always a global
8284
return get_ipython().__class__.__name__ == "ZMQInteractiveShell"

adaptive/tests/test_learner1d.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ def f_vec(x, offset=0.123214):
277277
def assert_equal_dicts(d1, d2):
278278
xs1, ys1 = zip(*sorted(d1.items()))
279279
xs2, ys2 = zip(*sorted(d2.items()))
280-
ys1 = np.array(ys1, dtype=np.float)
281-
ys2 = np.array(ys2, dtype=np.float)
280+
ys1 = np.array(ys1, dtype=np.float64)
281+
ys2 = np.array(ys2, dtype=np.float64)
282282
np.testing.assert_almost_equal(xs1, xs2)
283283
np.testing.assert_almost_equal(ys1, ys2)
284284

adaptive/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
Float: TypeAlias = Union[float, np.float_]
1212
Int: TypeAlias = Union[int, np.int_]
1313
Real: TypeAlias = Union[Float, Int]
14+
Bool: TypeAlias = Union[bool, np.bool_]

adaptive/utils.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,27 @@
1+
from __future__ import annotations
2+
13
import abc
24
import functools
35
import gzip
46
import inspect
57
import os
68
import pickle
79
import warnings
8-
from contextlib import contextmanager
10+
from contextlib import _GeneratorContextManager, contextmanager
911
from itertools import product
12+
from typing import Any, Callable, Mapping, Sequence
1013

1114
import cloudpickle
1215

1316

14-
def named_product(**items):
17+
def named_product(**items: Mapping[str, Sequence[Any]]):
1518
names = items.keys()
1619
vals = items.values()
1720
return [dict(zip(names, res)) for res in product(*vals)]
1821

1922

2023
@contextmanager
21-
def restore(*learners):
24+
def restore(*learners) -> _GeneratorContextManager:
2225
states = [learner.__getstate__() for learner in learners]
2326
try:
2427
yield
@@ -27,7 +30,7 @@ def restore(*learners):
2730
learner.__setstate__(state)
2831

2932

30-
def cache_latest(f):
33+
def cache_latest(f: Callable) -> Callable:
3134
"""Cache the latest return value of the function and add it
3235
as 'self._cache[f.__name__]'."""
3336

@@ -42,7 +45,7 @@ def wrapper(*args, **kwargs):
4245
return wrapper
4346

4447

45-
def save(fname, data, compress=True):
48+
def save(fname: str, data: Any, compress: bool = True) -> bool:
4649
fname = os.path.expanduser(fname)
4750
dirname = os.path.dirname(fname)
4851
if dirname:
@@ -71,14 +74,14 @@ def save(fname, data, compress=True):
7174
return True
7275

7376

74-
def load(fname, compress=True):
77+
def load(fname: str, compress: bool = True) -> Any:
7578
fname = os.path.expanduser(fname)
7679
_open = gzip.open if compress else open
7780
with _open(fname, "rb") as f:
7881
return cloudpickle.load(f)
7982

8083

81-
def copy_docstring_from(other):
84+
def copy_docstring_from(other: Callable) -> Callable:
8285
def decorator(method):
8386
return functools.wraps(other)(method)
8487

0 commit comments

Comments
 (0)