Skip to content

Commit ea30a49

Browse files
committed
Add type-hints to adaptive/notebook_integration.py
1 parent 97c85b9 commit ea30a49

File tree

5 files changed

+28
-9
lines changed

5 files changed

+28
-9
lines changed

adaptive/notebook_integration.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import datetime
35
import importlib
@@ -76,7 +78,7 @@ def ensure_plotly():
7678
raise RuntimeError("plotly is not installed; plotting is disabled.")
7779

7880

79-
def in_ipynb():
81+
def in_ipynb() -> bool:
8082
try:
8183
# If we are running in IPython, then `get_ipython()` is always a global
8284
return get_ipython().__class__.__name__ == "ZMQInteractiveShell"

adaptive/tests/test_average_learner1d.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
from itertools import chain
2+
13
import numpy as np
2-
import pandas as pd
3-
from pandas.testing import assert_series_equal
44

55
from adaptive import AverageLearner1D
66
from adaptive.tests.test_learners import (
@@ -11,13 +11,27 @@
1111

1212

1313
def almost_equal_dicts(a, b):
14-
assert_series_equal(pd.Series(sorted(a.items())), pd.Series(sorted(b.items())))
14+
assert a.keys() == b.keys()
15+
for k, v1 in a.items():
16+
v2 = b[k]
17+
if (
18+
v1 is None
19+
or v2 is None
20+
or isinstance(v1, (tuple, list))
21+
and any(x is None for x in chain(v1, v2))
22+
):
23+
assert v1 == v2
24+
else:
25+
try:
26+
np.testing.assert_almost_equal(v1, v2)
27+
except TypeError:
28+
raise AssertionError(f"{v1} != {v2}")
1529

1630

1731
def test_tell_many_at_point():
1832
f = generate_random_parametrization(noisy_peak)
1933
learner = AverageLearner1D(f, bounds=(-2, 2))
20-
control = AverageLearner1D(f, bounds=(-2, 2))
34+
control = learner.new()
2135
learner._recompute_losses_factor = 1
2236
control._recompute_losses_factor = 1
2337
simple_run(learner, 100)

adaptive/tests/test_learners.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import flaky
1414
import numpy as np
15-
import pandas
1615
import pytest
1716
import scipy.spatial
1817

@@ -28,6 +27,7 @@
2827
LearnerND,
2928
SequenceLearner,
3029
)
30+
from adaptive.learner.learner1D import with_pandas
3131
from adaptive.runner import simple
3232

3333
try:
@@ -708,6 +708,7 @@ def wrapper(*args, **kwargs):
708708
return wrapper
709709

710710

711+
@pytest.mark.skipif(not with_pandas, reason="pandas is not installed")
711712
@run_with(
712713
Learner1D,
713714
Learner2D,
@@ -719,6 +720,8 @@ def wrapper(*args, **kwargs):
719720
with_all_loss_functions=False,
720721
)
721722
def test_to_dataframe(learner_type, f, learner_kwargs):
723+
import pandas
724+
722725
if learner_type is LearnerND:
723726
kw = {"point_names": tuple("xyz")[: len(learner_kwargs["bounds"])]}
724727
else:

adaptive/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def wrapper(*args, **kwargs):
4545
return wrapper
4646

4747

48-
def save(fname: str, data: Any, compress: bool = True) -> None:
48+
def save(fname: str, data: Any, compress: bool = True) -> bool:
4949
fname = os.path.expanduser(fname)
5050
dirname = os.path.dirname(fname)
5151
if dirname:
@@ -74,7 +74,7 @@ def save(fname: str, data: Any, compress: bool = True) -> None:
7474
return True
7575

7676

77-
def load(fname: str, compress: bool = True):
77+
def load(fname: str, compress: bool = True) -> Any:
7878
fname = os.path.expanduser(fname)
7979
_open = gzip.open if compress else open
8080
with _open(fname, "rb") as f:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def get_version_and_cmdclass(package_name):
4242
"holoviews>=1.9.1",
4343
"ipywidgets",
4444
"bokeh",
45+
"pandas",
4546
"matplotlib",
4647
"plotly",
4748
],
@@ -52,7 +53,6 @@ def get_version_and_cmdclass(package_name):
5253
"pytest-randomly",
5354
"pytest-timeout",
5455
"pre_commit",
55-
"pandas",
5656
"typeguard",
5757
],
5858
"other": [

0 commit comments

Comments
 (0)