Skip to content

Typehint SequenceLearner #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 11, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from copy import copy
from typing import Any, Tuple

import cloudpickle
from sortedcontainers import SortedDict, SortedSet

from adaptive.learner.base_learner import BaseLearner
from adaptive.types import Int
from adaptive.utils import assign_defaults, partial_function_from_dataframe

try:
Expand All @@ -16,6 +18,14 @@
except ModuleNotFoundError:
with_pandas = False

try:
from typing import TypeAlias
except ImportError:
from typing_extensions import TypeAlias


PointType: TypeAlias = Tuple[Int, Any]


class _IgnoreFirstArgument:
"""Remove the first argument from the call signature.
Expand All @@ -30,7 +40,7 @@ class _IgnoreFirstArgument:
def __init__(self, function):
self.function = function

def __call__(self, index_point, *args, **kwargs):
def __call__(self, index_point: PointType, *args, **kwargs):
index, point = index_point
return self.function(point, *args, **kwargs)

Expand Down Expand Up @@ -81,7 +91,9 @@ def new(self) -> SequenceLearner:
"""Return a new `~adaptive.SequenceLearner` without the data."""
return SequenceLearner(self._original_function, self.sequence)

def ask(self, n, tell_pending=True):
def ask(
self, n: int, tell_pending: bool = True
) -> tuple[list[PointType], list[float]]:
indices = []
points = []
loss_improvements = []
Expand All @@ -99,40 +111,40 @@ def ask(self, n, tell_pending=True):

return points, loss_improvements

def loss(self, real=True):
def loss(self, real: bool = True) -> float:
if not (self._to_do_indices or self.pending_points):
return 0
return 0.0
else:
npoints = self.npoints + (0 if real else len(self.pending_points))
return (self._ntotal - npoints) / self._ntotal

def remove_unfinished(self):
def remove_unfinished(self) -> None:
for i in self.pending_points:
self._to_do_indices.add(i)
self.pending_points = set()

def tell(self, point, value):
def tell(self, point: PointType, value: Any) -> None:
index, point = point
self.data[index] = value
self.pending_points.discard(index)
self._to_do_indices.discard(index)

def tell_pending(self, point):
def tell_pending(self, point: PointType) -> None:
index, point = point
self.pending_points.add(index)
self._to_do_indices.discard(index)

def done(self):
def done(self) -> bool:
return not self._to_do_indices and not self.pending_points

def result(self):
def result(self) -> list[Any]:
"""Get the function values in the same order as ``sequence``."""
if not self.done():
raise Exception("Learner is not yet complete.")
return list(self.data.values())

@property
def npoints(self):
def npoints(self) -> int:
return len(self.data)

def to_dataframe(
Expand Down Expand Up @@ -213,16 +225,18 @@ def load_dataframe(
y_name : str, optional
The ``y_name`` used in ``to_dataframe``, by default "y"
"""
self.tell_many(df[[index_name, x_name]].values, df[y_name].values)
indices = df[index_name].values
xs = df[x_name].values
self.tell_many(zip(indices, xs), df[y_name].values)
if with_default_function_args:
self.function = partial_function_from_dataframe(
self._original_function, df, function_prefix
)

def _get_data(self):
def _get_data(self) -> dict[int, Any]:
return self.data

def _set_data(self, data):
def _set_data(self, data: dict[int, Any]) -> None:
if data:
indices, values = zip(*data.items())
# the points aren't used by tell, so we can safely pass None
Expand Down