Skip to content

Commit 9ea4c55

Browse files
committed
change DataPoint to dict before saving
1 parent 0b9f216 commit 9ea4c55

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

adaptive/learner/average1D.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from adaptive.notebook_integration import ensure_holoviews
99
from adaptive.learner.learner1D import Learner1D
10-
from adaptive.learner.average_mixin import add_average_mixin
10+
from adaptive.learner.average_mixin import add_average_mixin, DataPoint
1111

1212

1313
@add_average_mixin
@@ -65,7 +65,8 @@ def tell_pending(self, x_seed):
6565
def tell_many(self, xs, ys):
6666
# `super().tell_many(xs, ys)` will not work.
6767
for x, y in zip(xs, ys):
68-
self.tell(x, y)
68+
for seed, value in y.items():
69+
self.tell((x, seed), value)
6970

7071
def remove_unfinished(self):
7172
self.pending_points = {}
@@ -87,3 +88,8 @@ def plot(self, *, with_sem=True):
8788
return scatter * spread
8889
else:
8990
return scatter
91+
92+
def _set_data(self, data):
93+
# change dict -> DataPoint, because the points are saved using dicts
94+
data = {k: DataPoint(v) for k, v in data.items()}
95+
self.tell_many(data.keys(), data.values())

adaptive/learner/average2D.py

+5
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,8 @@ def tell(self, point, value):
146146
if not point_exists:
147147
self._ip = None
148148
self._stack.pop(point, None)
149+
150+
def _set_data(self, data):
151+
# change dict -> DataPoint, because the points are saved using dicts
152+
data = {k: DataPoint(v) for k, v in data.items()}
153+
super()._set_data(data)

adaptive/learner/average_mixin.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,19 @@ def needs_more_data(p):
129129
return [inf if needs_more_data(p) else loss
130130
for (p, seed), loss in zip(points, loss_improvements)]
131131

132+
def _get_data(self):
133+
# change DataPoint -> dict for saving
134+
return {k: dict(v) for k, v in self._data.items()}
135+
132136

133137
def add_average_mixin(cls):
134138
names = ('data', 'data_sem', 'mean_values_per_point',
135139
'get_seed', 'loss_per_existing_point', '_add_to_pending',
136140
'_remove_from_to_pending', '_add_to_data', 'ask', 'n_values',
137141
'_normalize_new_points_loss_improvements',
138142
'_normalize_existing_points_loss_improvements',
139-
'_mean_values_per_neighbor')
143+
'_mean_values_per_neighbor',
144+
'_get_data')
140145

141146
for name in names:
142147
setattr(cls, name, getattr(AverageMixin, name))

0 commit comments

Comments
 (0)