31
31
32
32
from pymc .backends .arviz import dict_to_dataset , to_inference_data
33
33
from pymc .backends .base import MultiTrace
34
+ from pymc .distributions .distribution import _support_point
35
+ from pymc .logprob .abstract import _logcdf , _logprob
34
36
from pymc .model import Model , modelcontext
35
37
from pymc .sampling .parallel import _cpu_count
36
38
from pymc .smc .kernels import IMH
@@ -375,11 +377,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
375
377
# main process and our worker functions
376
378
_progress = manager .dict ()
377
379
380
+ # check if model contains CustomDistributions defined without dist argument
381
+ custom_methods = _find_custom_methods (params [3 ])
382
+
378
383
# "manually" (de)serialize params before/after multiprocessing
379
384
params = tuple (cloudpickle .dumps (p ) for p in params )
380
385
kernel_kwargs = {key : cloudpickle .dumps (value ) for key , value in kernel_kwargs .items ()}
381
386
382
- with ProcessPoolExecutor (max_workers = cores ) as executor :
387
+ with ProcessPoolExecutor (
388
+ max_workers = cores ,
389
+ initializer = _register_custom_methods ,
390
+ initargs = (custom_methods ,),
391
+ ) as executor :
383
392
for c in range (chains ): # iterate over the jobs we need to run
384
393
# set visible false so we don't have a lot of bars all at once:
385
394
task_id = progress .add_task (
@@ -406,3 +415,25 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
406
415
progress .update (status = f"Stage: { stage } Beta: { beta :.3f} " , task_id = task_id )
407
416
408
417
return tuple (cloudpickle .loads (r .result ()) for r in futures )
418
+
419
+
420
def _find_custom_methods(model):
    """Collect the pickled dispatch implementations for CustomDist ops in *model*.

    Returns a dict mapping each pickled op class to a tuple of its pickled
    ``(logprob, logcdf, support_point)`` implementations, so they can be
    re-registered inside freshly spawned worker processes.
    """
    methods = {}
    for variable in [*model.free_RVs, *model.observed_RVs]:
        op_cls = variable.owner.op.__class__
        # Per the caller's note, only CustomDistributions defined without a
        # `dist` argument need this treatment; they carry a `_random_fn`
        # attribute on the op class.
        if not hasattr(op_cls, "_random_fn"):
            continue
        methods[cloudpickle.dumps(op_cls)] = tuple(
            cloudpickle.dumps(dispatcher.registry[op_cls])
            for dispatcher in (_logprob, _logcdf, _support_point)
        )
    return methods
432
+
433
+
434
def _register_custom_methods(custom_methods):
    """Worker-process initializer: restore pickled dispatch registrations.

    *custom_methods* maps a pickled op class to a tuple of its pickled
    ``(logprob, logcdf, support_point)`` implementations, as produced on
    the parent-process side before serialization.
    """
    dispatchers = (_logprob, _logcdf, _support_point)
    for pickled_cls, pickled_impls in custom_methods.items():
        op_cls = cloudpickle.loads(pickled_cls)
        for dispatcher, pickled_impl in zip(dispatchers, pickled_impls):
            dispatcher.register(op_cls, cloudpickle.loads(pickled_impl))
0 commit comments