40
40
from .parallel_sampling import _cpu_count
41
41
from pymc3 .step_methods .hmc import quadpotential
42
42
import pymc3 as pm
43
- from tqdm import tqdm
43
+ from fastprogress import progress_bar
44
44
45
45
46
46
import sys
@@ -568,11 +568,17 @@ def _sample_population(
568
568
# create the generator that iterates all chains in parallel
569
569
chains = [chain + c for c in range (chains )]
570
570
sampling = _prepare_iter_population (
571
- draws , chains , step , start , parallelize , tune = tune , model = model , random_seed = random_seed
571
+ draws ,
572
+ chains ,
573
+ step ,
574
+ start ,
575
+ parallelize ,
576
+ tune = tune ,
577
+ model = model ,
578
+ random_seed = random_seed ,
572
579
)
573
580
574
- if progressbar :
575
- sampling = tqdm (sampling , total = draws )
581
+ sampling = progress_bar (sampling , total = draws , display = progressbar )
576
582
577
583
latest_traces = None
578
584
for it , traces in enumerate (sampling ):
@@ -596,23 +602,20 @@ def _sample(
596
602
597
603
sampling = _iter_sample (draws , step , start , trace , chain , tune , model , random_seed )
598
604
_pbar_data = None
599
- if progressbar :
600
- _pbar_data = { " chain" : chain , " divergences" : 0 }
601
- _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
602
- sampling = tqdm ( sampling , total = draws , desc = _desc .format (** _pbar_data ) )
605
+ _pbar_data = { "chain" : chain , "divergences" : 0 }
606
+ _desc = "Sampling chain { chain:d}, { divergences:,d} divergences"
607
+ sampling = progress_bar ( sampling , total = draws , display = progressbar )
608
+ sampling . comment = _desc .format (** _pbar_data )
603
609
try :
604
610
strace = None
605
611
for it , (strace , diverging ) in enumerate (sampling ):
606
612
if it >= skip_first :
607
613
trace = MultiTrace ([strace ])
608
614
if diverging and _pbar_data is not None :
609
615
_pbar_data ["divergences" ] += 1
610
- sampling .set_description ( _desc .format (** _pbar_data ) )
616
+ sampling .comment = _desc .format (** _pbar_data )
611
617
except KeyboardInterrupt :
612
618
pass
613
- finally :
614
- if progressbar :
615
- sampling .close ()
616
619
return strace
617
620
618
621
@@ -753,7 +756,7 @@ def __init__(self, steppers, parallelize):
753
756
)
754
757
import multiprocessing
755
758
756
- for c , stepper in enumerate (tqdm (steppers )):
759
+ for c , stepper in enumerate (progress_bar (steppers )):
757
760
slave_end , master_end = multiprocessing .Pipe ()
758
761
stepper_dumps = pickle .dumps (stepper , protocol = 4 )
759
762
process = multiprocessing .Process (
@@ -1235,9 +1238,13 @@ def sample_posterior_predictive(
1235
1238
nchain = 1
1236
1239
1237
1240
if keep_size and samples is not None :
1238
- raise IncorrectArgumentsError ("Should not specify both keep_size and samples argukments" )
1241
+ raise IncorrectArgumentsError (
1242
+ "Should not specify both keep_size and samples argukments"
1243
+ )
1239
1244
if keep_size and size is not None :
1240
- raise IncorrectArgumentsError ("Should not specify both keep_size and size argukments" )
1245
+ raise IncorrectArgumentsError (
1246
+ "Should not specify both keep_size and size argukments"
1247
+ )
1241
1248
1242
1249
if samples is None :
1243
1250
samples = sum (len (v ) for v in trace ._straces .values ())
@@ -1253,7 +1260,9 @@ def sample_posterior_predictive(
1253
1260
1254
1261
if var_names is not None :
1255
1262
if vars is not None :
1256
- raise IncorrectArgumentsError ("Should not specify both vars and var_names arguments." )
1263
+ raise IncorrectArgumentsError (
1264
+ "Should not specify both vars and var_names arguments."
1265
+ )
1257
1266
else :
1258
1267
vars = [model [x ] for x in var_names ]
1259
1268
elif vars is not None : # var_names is None, and vars is not.
@@ -1266,8 +1275,7 @@ def sample_posterior_predictive(
1266
1275
1267
1276
indices = np .arange (samples )
1268
1277
1269
- if progressbar :
1270
- indices = tqdm (indices , total = samples )
1278
+ indices = progress_bar (indices , total = samples , display = progressbar )
1271
1279
1272
1280
ppc_trace_t = _DefaultTrace (samples )
1273
1281
try :
@@ -1285,10 +1293,6 @@ def sample_posterior_predictive(
1285
1293
except KeyboardInterrupt :
1286
1294
pass
1287
1295
1288
- finally :
1289
- if progressbar :
1290
- indices .close ()
1291
-
1292
1296
ppc_trace = ppc_trace_t .trace_dict
1293
1297
if keep_size :
1294
1298
for k , ary in ppc_trace .items ():
@@ -1411,8 +1415,7 @@ def sample_posterior_predictive_w(
1411
1415
1412
1416
indices = np .random .randint (0 , len_trace , samples )
1413
1417
1414
- if progressbar :
1415
- indices = tqdm (indices , total = samples )
1418
+ indices = progress_bar (indices , total = samples , display = progressbar )
1416
1419
1417
1420
try :
1418
1421
ppc = defaultdict (list )
@@ -1426,10 +1429,6 @@ def sample_posterior_predictive_w(
1426
1429
except KeyboardInterrupt :
1427
1430
pass
1428
1431
1429
- finally :
1430
- if progressbar :
1431
- indices .close ()
1432
-
1433
1432
return {k : np .asarray (v ) for k , v in ppc .items ()}
1434
1433
1435
1434
0 commit comments