Skip to content

Commit 9757434

Browse files
committed
Register custom overloads in all processes
1 parent 65f0b1e commit 9757434

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

pymc/smc/sampling.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
from pymc.backends.arviz import dict_to_dataset, to_inference_data
3333
from pymc.backends.base import MultiTrace
34+
from pymc.distributions.distribution import _support_point
35+
from pymc.logprob.abstract import _logcdf, _logprob
3436
from pymc.model import Model, modelcontext
3537
from pymc.sampling.parallel import _cpu_count
3638
from pymc.smc.kernels import IMH
@@ -375,11 +377,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
375377
# main process and our worker functions
376378
_progress = manager.dict()
377379

380+
# check if model contains CustomDistributions defined without dist argument
381+
custom_methods = _find_custom_methods(params[3])
382+
378383
# "manually" (de)serialize params before/after multiprocessing
379384
params = tuple(cloudpickle.dumps(p) for p in params)
380385
kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
381386

382-
with ProcessPoolExecutor(max_workers=cores) as executor:
387+
with ProcessPoolExecutor(
388+
max_workers=cores,
389+
initializer=_register_custom_methods,
390+
initargs=(custom_methods,),
391+
) as executor:
383392
for c in range(chains): # iterate over the jobs we need to run
384393
# set visible false so we don't have a lot of bars all at once:
385394
task_id = progress.add_task(
@@ -406,3 +415,25 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
406415
progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id)
407416

408417
return tuple(cloudpickle.loads(r.result()) for r in futures)
418+
419+
420+
def _find_custom_methods(model):
421+
custom_methods = {}
422+
for rv in model.free_RVs + model.observed_RVs:
423+
cls = rv.owner.op.__class__
424+
if hasattr(cls, "_random_fn"):
425+
custom_methods[cloudpickle.dumps(cls)] = (
426+
cloudpickle.dumps(_logprob.registry[cls]),
427+
cloudpickle.dumps(_logcdf.registry[cls]),
428+
cloudpickle.dumps(_support_point.registry[cls]),
429+
)
430+
431+
return custom_methods
432+
433+
434+
def _register_custom_methods(custom_methods):
435+
for cls, (logprob, logcdf, support_point) in custom_methods.items():
436+
cls = cloudpickle.loads(cls)
437+
_logprob.register(cls, cloudpickle.loads(logprob))
438+
_logcdf.register(cls, cloudpickle.loads(logcdf))
439+
_support_point.register(cls, cloudpickle.loads(support_point))

0 commit comments

Comments
 (0)