Commit f296a5c

switch from pickle/dill to cloudpickle (#4858)
* use cloudpickle for serialization
* add cloudpickle to requirements
* update tests for cloudpickle
* update release notes with cloudpickle
* update conda envs with cloudpickle
* remove special case serialization for DensityDist.logp
* add pickle import back in for pickle.PickleError
* remove strict error message check in test
1 parent a3ee747 commit f296a5c

17 files changed (+42, -105 lines)
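
The change set above amounts to one behavioral difference between the serialization libraries. A minimal sketch (not part of this commit, assuming only the `pickle` and `cloudpickle` packages are installed): the standard `pickle` module serializes functions by reference to an importable name, so a lambda or a function defined interactively, such as a `DensityDist` logp written in a notebook, cannot be round-tripped, whereas `cloudpickle` serializes it by value.

```python
import pickle

import cloudpickle

# A callable with no importable qualified name, e.g. a logp defined in a notebook.
logp = lambda value: -(value ** 2)

try:
    pickle.dumps(logp)
except Exception as err:
    # Plain pickle refuses: it stores functions as "module.qualname" references.
    print("pickle failed:", err)

# cloudpickle serializes the function body itself, so the round trip works.
restored = cloudpickle.loads(cloudpickle.dumps(logp))
print(restored(3.0))  # -9.0
```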

Diff for: RELEASE-NOTES.md (+1)

@@ -26,6 +26,7 @@
 - Logp method of `Uniform` and `DiscreteUniform` no longer depends on `pymc3.distributions.dist_math.bound` for proper evaluation (see [#4541](https://github.com/pymc-devs/pymc3/pull/4541)).
 - `Model.RV_dims` and `Model.coords` are now read-only properties. To modify the `coords` dictionary use `Model.add_coord`. Also `dims` or coordinate values that are `None` will be auto-completed (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)).
 - The length of `dims` in the model is now tracked symbolically through `Model.dim_lengths` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)).
+- We now include `cloudpickle` as a required dependency, and no longer depend on `dill` (see [#4858](https://github.com/pymc-devs/pymc3/pull/4858)).
 - ...

 ## PyMC3 3.11.2 (14 March 2021)

Diff for: conda-envs/environment-dev-py37.yml (+1, -1)

@@ -6,7 +6,7 @@ dependencies:
 - aesara>=2.0.9
 - arviz>=0.11.2
 - cachetools>=4.2.1
-- dill
+- cloudpickle
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16

Diff for: conda-envs/environment-dev-py38.yml (+1, -1)

@@ -6,7 +6,7 @@ dependencies:
 - aesara>=2.0.9
 - arviz>=0.11.2
 - cachetools>=4.2.1
-- dill
+- cloudpickle
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16

Diff for: conda-envs/environment-dev-py39.yml (+1, -1)

@@ -6,7 +6,7 @@ dependencies:
 - aesara>=2.0.9
 - arviz>=0.11.2
 - cachetools>=4.2.1
-- dill
+- cloudpickle
 - fastprogress>=0.2.0
 - h5py>=2.7
 - ipython>=7.16

Diff for: conda-envs/windows-environment-dev-py38.yml (+1, -1)

@@ -7,7 +7,7 @@ dependencies:
 - aesara>=2.0.9
 - arviz>=0.11.2
 - cachetools>=4.2.1
-- dill
+- cloudpickle
 - fastprogress>=0.2.0
 - h5py>=2.7
 - libpython

Diff for: pymc3/distributions/distribution.py (-22)

@@ -23,7 +23,6 @@

 import aesara
 import aesara.tensor as at
-import dill

 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.var import RandomStateSharedVariable
@@ -533,26 +532,5 @@ def __init__(
         self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
         self.check_shape_in_random = check_shape_in_random

-    def __getstate__(self):
-        # We use dill to serialize the logp function, as this is almost
-        # always defined in the notebook and won't be pickled correctly.
-        # Fix https://github.com/pymc-devs/pymc3/issues/3844
-        try:
-            logp = dill.dumps(self.logp)
-        except RecursionError as err:
-            if type(self.logp) == types.MethodType:
-                raise ValueError(
-                    "logp for DensityDist is a bound method, leading to RecursionError while serializing"
-                ) from err
-            else:
-                raise err
-        vals = self.__dict__.copy()
-        vals["logp"] = logp
-        return vals
-
-    def __setstate__(self, vals):
-        vals["logp"] = dill.loads(vals["logp"])
-        self.__dict__ = vals
-
     def _distr_parameters_for_repr(self):
         return []
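
The removed `__getstate__`/`__setstate__` pair existed only so that `dill` could handle a notebook-defined `logp`; once the surrounding machinery (see `parallel_sampling.py` and `sampling.py` below) serializes whole objects with `cloudpickle`, the default pickling path is enough. A simplified, hypothetical stand-in (not PyMC3 code) illustrating why no custom state hooks are needed:

```python
import cloudpickle


class FakeDensityDist:
    def __init__(self, logp):
        self.logp = logp  # plain attribute; no special-case serialization


def make_dist():
    scale = 2.0

    def logp(value):  # closure defined at call time, like a notebook logp
        return -(value / scale) ** 2

    return FakeDensityDist(logp)


# The whole object graph, including the closure, round-trips with cloudpickle.
dist = cloudpickle.loads(cloudpickle.dumps(make_dist()))
print(dist.logp(4.0))  # -4.0
```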

Diff for: pymc3/parallel_sampling.py (+6, -31)

@@ -16,13 +16,13 @@
 import logging
 import multiprocessing
 import multiprocessing.sharedctypes
-import pickle
 import platform
 import time
 import traceback

 from collections import namedtuple

+import cloudpickle
 import numpy as np

 from fastprogress.fastprogress import progress_bar
@@ -93,7 +93,6 @@ def __init__(
         draws: int,
         tune: int,
         seed,
-        pickle_backend,
     ):
         self._msg_pipe = msg_pipe
         self._step_method = step_method
@@ -103,7 +102,6 @@ def __init__(
         self._at_seed = seed + 1
         self._draws = draws
         self._tune = tune
-        self._pickle_backend = pickle_backend

     def _unpickle_step_method(self):
         unpickle_error = (
@@ -112,22 +110,10 @@ def _unpickle_step_method(self):
             "or forkserver."
         )
         if self._step_method_is_pickled:
-            if self._pickle_backend == "pickle":
-                try:
-                    self._step_method = pickle.loads(self._step_method)
-                except Exception:
-                    raise ValueError(unpickle_error)
-            elif self._pickle_backend == "dill":
-                try:
-                    import dill
-                except ImportError:
-                    raise ValueError("dill must be installed for pickle_backend='dill'.")
-                try:
-                    self._step_method = dill.loads(self._step_method)
-                except Exception:
-                    raise ValueError(unpickle_error)
-            else:
-                raise ValueError("Unknown pickle backend")
+            try:
+                self._step_method = cloudpickle.loads(self._step_method)
+            except Exception:
+                raise ValueError(unpickle_error)

     def run(self):
         try:
@@ -243,7 +229,6 @@ def __init__(
         seed,
         start,
         mp_ctx,
-        pickle_backend,
     ):
         self.chain = chain
         process_name = "worker_chain_%s" % chain
@@ -287,7 +272,6 @@ def __init__(
                 draws,
                 tune,
                 seed,
-                pickle_backend,
             ),
         )
         self._process.start()
@@ -406,7 +390,6 @@ def __init__(
         start_chain_num: int = 0,
         progressbar: bool = True,
         mp_ctx=None,
-        pickle_backend: str = "pickle",
     ):

         if any(len(arg) != chains for arg in [seeds, start_points]):
@@ -420,14 +403,7 @@

         step_method_pickled = None
         if mp_ctx.get_start_method() != "fork":
-            if pickle_backend == "pickle":
-                step_method_pickled = pickle.dumps(step_method, protocol=-1)
-            elif pickle_backend == "dill":
-                try:
-                    import dill
-                except ImportError:
-                    raise ValueError("dill must be installed for pickle_backend='dill'.")
-                step_method_pickled = dill.dumps(step_method, protocol=-1)
+            step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)

         self._samplers = [
             ProcessAdapter(
@@ -439,7 +415,6 @@
                 seed,
                 start,
                 mp_ctx,
-                pickle_backend,
             )
             for chain, seed, start in zip(range(chains), seeds, start_points)
         ]
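
With a single backend there is no selection logic left: whenever the multiprocessing start method is not `fork`, the step method is cloudpickled in the parent and un-cloudpickled in the worker. The sketch below mirrors that pattern with hypothetical names (`worker`, `step_method`); it is not the PyMC3 worker protocol.

```python
import multiprocessing

import cloudpickle


def worker(payload):
    # Single deserialization path, no backend switch.
    step = cloudpickle.loads(payload)
    print("worker result:", step(10))


if __name__ == "__main__":
    # Stand-in for a step method closed over model state.
    step_method = lambda n: n + 1
    payload = cloudpickle.dumps(step_method, protocol=-1)

    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=worker, args=(payload,))
    p.start()
    p.join()
```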

Diff for: pymc3/sampling.py (+3, -10)

@@ -26,6 +26,7 @@
 from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

 import aesara.gradient as tg
+import cloudpickle
 import numpy as np
 import xarray

@@ -268,7 +269,6 @@ def sample(
     return_inferencedata=None,
     idata_kwargs: dict = None,
     mp_ctx=None,
-    pickle_backend: str = "pickle",
     **kwargs,
 ):
     r"""Draw samples from the posterior using the given step methods.
@@ -362,10 +362,6 @@
     mp_ctx : multiprocessing.context.BaseContent
         A multiprocessing context for parallel sampling. See multiprocessing
         documentation for details.
-    pickle_backend : str
-        One of `'pickle'` or `'dill'`. The library used to pickle models
-        in parallel sampling if the multiprocessing context is not of type
-        `fork`.

     Returns
     -------
@@ -548,7 +544,6 @@ def sample(
         "discard_tuned_samples": discard_tuned_samples,
     }
     parallel_args = {
-        "pickle_backend": pickle_backend,
         "mp_ctx": mp_ctx,
     }

@@ -1100,7 +1095,7 @@ def __init__(self, steppers, parallelize, progressbar=True):
                 enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
             ):
                 secondary_end, primary_end = multiprocessing.Pipe()
-                stepper_dumps = pickle.dumps(stepper, protocol=4)
+                stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
                 process = multiprocessing.Process(
                     target=self.__class__._run_secondary,
                     args=(c, stepper_dumps, secondary_end),
@@ -1159,7 +1154,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
         # re-seed each child process to make them unique
         np.random.seed(None)
         try:
-            stepper = pickle.loads(stepper_dumps)
+            stepper = cloudpickle.loads(stepper_dumps)
             # the stepper is not necessarily a PopulationArraySharedStep itself,
             # but rather a CompoundStep. PopulationArrayStepShared.population
             # has to be updated, therefore we identify the substeppers first.
@@ -1418,7 +1413,6 @@ def _mp_sample(
     callback=None,
     discard_tuned_samples=True,
     mp_ctx=None,
-    pickle_backend="pickle",
     **kwargs,
 ):
     """Main iteration for multiprocess sampling.
@@ -1491,7 +1485,6 @@
         chain,
         progressbar,
         mp_ctx=mp_ctx,
-        pickle_backend=pickle_backend,
     )
     try:
         try:
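
From the user's side, the visible change is that `pm.sample()` no longer accepts a `pickle_backend` argument. A usage sketch, modeled on the `DensityDist` test in `test_distributions.py` below and assuming the development version of PyMC3 at this commit:

```python
import pymc3 as pm


def main():
    def logp(value):  # local function, like a notebook-defined logp
        return -0.5 * (value ** 2).sum()

    with pm.Model():
        pm.Normal("x")
        pm.DensityDist("y", logp)
        # previously: pm.sample(..., mp_ctx="spawn", pickle_backend="dill")
        pm.sample(draws=5, tune=1, mp_ctx="spawn")


if __name__ == "__main__":
    main()
```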

Diff for: pymc3/tests/test_distributions.py (+2, -2)

@@ -3109,9 +3109,9 @@ def func(x):
         y = pm.DensityDist("y", func)
         pm.sample(draws=5, tune=1, mp_ctx="spawn")

-    import pickle
+    import cloudpickle

-    pickle.loads(pickle.dumps(y))
+    cloudpickle.loads(cloudpickle.dumps(y))


 def test_distinct_rvs():

Diff for: pymc3/tests/test_minibatches.py (+4, -4)

@@ -13,9 +13,9 @@
 # limitations under the License.

 import itertools
-import pickle

 import aesara
+import cloudpickle
 import numpy as np
 import pytest

@@ -132,10 +132,10 @@ def gen():

     def test_pickling(self, datagen):
         gen = generator(datagen)
-        pickle.loads(pickle.dumps(gen))
+        cloudpickle.loads(cloudpickle.dumps(gen))
         bad_gen = generator(integers())
-        with pytest.raises(Exception):
-            pickle.dumps(bad_gen)
+        with pytest.raises(TypeError):
+            cloudpickle.dumps(bad_gen)

     def test_gen_cloning_with_shape_change(self, datagen):
         gen = generator(datagen)
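
The expected exception is tightened from `Exception` to `TypeError` because a live generator object cannot be serialized by either backend, and cloudpickle surfaces that as a `TypeError`. A small stand-alone sketch (the `integers` helper here is a local stand-in for the test fixture of the same name):

```python
import cloudpickle


def integers():
    i = 0
    while True:
        yield i
        i += 1


try:
    cloudpickle.dumps(integers())  # a running generator cannot be pickled
except TypeError as err:
    print("TypeError:", err)
```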

Diff for: pymc3/tests/test_model.py (+3, -7)

@@ -11,14 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pickle
 import unittest

 from functools import reduce

 import aesara
 import aesara.sparse as sparse
 import aesara.tensor as at
+import cloudpickle
 import numpy as np
 import numpy.ma as ma
 import numpy.testing as npt
@@ -407,9 +407,7 @@ def test_model_pickle(tmpdir):
         x = pm.Normal("x")
         pm.Normal("y", observed=1)

-    file_path = tmpdir.join("model.p")
-    with open(file_path, "wb") as buff:
-        pickle.dump(model, buff)
+    cloudpickle.loads(cloudpickle.dumps(model))


 def test_model_pickle_deterministic(tmpdir):
@@ -420,9 +418,7 @@ def test_model_pickle_deterministic(tmpdir):
         pm.Deterministic("w", x / z)
         pm.Normal("y", observed=1)

-    file_path = tmpdir.join("model.p")
-    with open(file_path, "wb") as buff:
-        pickle.dump(model, buff)
+    cloudpickle.loads(cloudpickle.dumps(model))


 def test_model_vars():

Diff for: pymc3/tests/test_parallel_sampling.py (-8)

@@ -71,12 +71,6 @@ def _crash_remote_process(a, master_pid):
     return 2 * np.array(a)


-def test_dill():
-    with pm.Model():
-        pm.Normal("x")
-        pm.sample(tune=1, draws=1, chains=2, cores=2, pickle_backend="dill", mp_ctx="spawn")
-
-
 def test_remote_pipe_closed():
     master_pid = os.getpid()
     with pm.Model():
@@ -112,7 +106,6 @@ def test_abort():
         mp_ctx=ctx,
         start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
         step_method_pickled=None,
-        pickle_backend="pickle",
     )
     proc.start()
     while True:
@@ -147,7 +140,6 @@ def test_explicit_sample():
         mp_ctx=ctx,
         start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
         step_method_pickled=None,
-        pickle_backend="pickle",
     )
     proc.start()
     while True:

Diff for: pymc3/tests/test_pickling.py (+4, -2)

@@ -15,6 +15,8 @@
 import pickle
 import traceback

+import cloudpickle
+
 from pymc3.tests.models import simple_model


@@ -26,8 +28,8 @@ def test_model_roundtrip(self):
         m = self.model
         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
             try:
-                s = pickle.dumps(m, proto)
-                pickle.loads(s)
+                s = cloudpickle.dumps(m, proto)
+                cloudpickle.loads(s)
             except Exception:
                 raise AssertionError(
                     "Exception while trying roundtrip with pickle protocol %d:\n" % proto
