Skip to content

Commit 75a3274

Browse files
authored
always serialize the function using cloudpickle (#281)
This was a suggestion of @jbweston some time ago. Using this, the leaners are picklable regardless of the serialization package chosen.
1 parent e091a12 commit 75a3274

File tree

6 files changed

+16
-6
lines changed

6 files changed

+16
-6
lines changed

adaptive/learner/average_learner.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from math import sqrt
22

3+
import cloudpickle
34
import numpy as np
45

56
from adaptive.learner.base_learner import BaseLearner
@@ -151,7 +152,7 @@ def _set_data(self, data):
151152

152153
def __getstate__(self):
153154
return (
154-
self.function,
155+
cloudpickle.dumps(self.function),
155156
self.atol,
156157
self.rtol,
157158
self.min_npoints,
@@ -160,5 +161,6 @@ def __getstate__(self):
160161

161162
def __setstate__(self, state):
162163
function, atol, rtol, min_npoints, data = state
164+
function = cloudpickle.loads(function)
163165
self.__init__(function, atol, rtol, min_npoints)
164166
self._set_data(data)

adaptive/learner/integrator_learner.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from math import sqrt
66
from operator import attrgetter
77

8+
import cloudpickle
89
import numpy as np
910
from scipy.linalg import norm
1011
from sortedcontainers import SortedSet
@@ -594,13 +595,14 @@ def _set_data(self, data):
594595

595596
def __getstate__(self):
596597
return (
597-
self.function,
598+
cloudpickle.dumps(self.function),
598599
self.bounds,
599600
self.tol,
600601
self._get_data(),
601602
)
602603

603604
def __setstate__(self, state):
604605
function, bounds, tol, data = state
606+
function = cloudpickle.loads(function)
605607
self.__init__(function, bounds, tol)
606608
self._set_data(data)

adaptive/learner/learner1D.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Iterable
44
from copy import deepcopy
55

6+
import cloudpickle
67
import numpy as np
78
import sortedcollections
89
import sortedcontainers
@@ -634,7 +635,7 @@ def _set_data(self, data):
634635

635636
def __getstate__(self):
636637
return (
637-
self.function,
638+
cloudpickle.dumps(self.function),
638639
self.bounds,
639640
self.loss_per_interval,
640641
dict(self.losses), # SortedDict cannot be pickled
@@ -644,6 +645,7 @@ def __getstate__(self):
644645

645646
def __setstate__(self, state):
646647
function, bounds, loss_per_interval, losses, losses_combined, data = state
648+
function = cloudpickle.loads(function)
647649
self.__init__(function, bounds, loss_per_interval)
648650
self._set_data(data)
649651
self.losses.update(losses)

adaptive/learner/learner2D.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from copy import copy
55
from math import sqrt
66

7+
import cloudpickle
78
import numpy as np
89
from scipy import interpolate
910

@@ -709,7 +710,7 @@ def _set_data(self, data):
709710

710711
def __getstate__(self):
711712
return (
712-
self.function,
713+
cloudpickle.dumps(self.function),
713714
self.bounds,
714715
self.loss_per_triangle,
715716
self._stack,
@@ -718,6 +719,7 @@ def __getstate__(self):
718719

719720
def __setstate__(self, state):
720721
function, bounds, loss_per_triangle, _stack, data = state
722+
function = cloudpickle.loads(function)
721723
self.__init__(function, bounds, loss_per_triangle)
722724
self._set_data(data)
723725
self._stack = _stack

adaptive/learner/sequence_learner.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import copy
22

3+
import cloudpickle
34
from sortedcontainers import SortedDict, SortedSet
45

56
from adaptive.learner.base_learner import BaseLearner
@@ -131,12 +132,13 @@ def _set_data(self, data):
131132

132133
def __getstate__(self):
133134
return (
134-
self._original_function,
135+
cloudpickle.dumps(self._original_function),
135136
self.sequence,
136137
self._get_data(),
137138
)
138139

139140
def __setstate__(self, state):
140141
function, sequence, data = state
142+
function = cloudpickle.loads(function)
141143
self.__init__(function, sequence)
142144
self._set_data(data)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def get_version_and_cmdclass(package_name):
2929
"sortedcollections >= 1.1",
3030
"sortedcontainers >= 2.0",
3131
"atomicwrites",
32+
"cloudpickle",
3233
]
3334

3435
extras_require = {
@@ -51,7 +52,6 @@ def get_version_and_cmdclass(package_name):
5152
"pre_commit",
5253
],
5354
"other": [
54-
"cloudpickle",
5555
"dill",
5656
"distributed",
5757
"ipyparallel>=6.2.5", # because of https://github.com/ipython/ipyparallel/issues/404

0 commit comments

Comments
 (0)