Skip to content

Commit 65a951f

Browse files
authored
Merge pull request #8 from pymc-devs/master
Use fastprogress instead of tqdm progressbar (pymc-devs#3693)
2 parents 46e6983 + 1c30a6f commit 65a951f

File tree

7 files changed

+273
-228
lines changed

7 files changed

+273
-228
lines changed

RELEASE-NOTES.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Release Notes
22

3-
## PyMC3 3.8 (on deck)
3+
## PyMC3 3.9 (On deck)
4+
5+
### New features
6+
- use [fastprogress](https://github.com/fastai/fastprogress) instead of tqdm [#3693](https://github.com/pymc-devs/pymc3/pull/3693)
7+
8+
## PyMC3 3.8 (November 29 2019)
49

510
### New features
611
- Implemented robust u turn check in NUTS (similar to stan-dev/stan#2800). See PR [#3605]

pymc3/parallel_sampling.py

+37-38
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import errno
1010

1111
import numpy as np
12+
from fastprogress import progress_bar
1213

1314
from . import theanof
1415

@@ -17,28 +18,31 @@
1718

1819
def _get_broken_pipe_exception():
1920
import sys
20-
if sys.platform == 'win32':
21-
return RuntimeError("The communication pipe between the main process "
22-
"and its spawned children is broken.\n"
23-
"In Windows OS, this usually means that the child "
24-
"process raised an exception while it was being "
25-
"spawned, before it was setup to communicate to "
26-
"the main process.\n"
27-
"The exceptions raised by the child process while "
28-
"spawning cannot be caught or handled from the "
29-
"main process, and when running from an IPython or "
30-
"jupyter notebook interactive kernel, the child's "
31-
"exception and traceback appears to be lost.\n"
32-
"A known way to see the child's error, and try to "
33-
"fix or handle it, is to run the problematic code "
34-
"as a batch script from a system's Command Prompt. "
35-
"The child's exception will be printed to the "
36-
"Command Promt's stderr, and it should be visible "
37-
"above this error and traceback.\n"
38-
"Note that if running a jupyter notebook that was "
39-
"invoked from a Command Prompt, the child's "
40-
"exception should have been printed to the Command "
41-
"Prompt on which the notebook is running.")
21+
22+
if sys.platform == "win32":
23+
return RuntimeError(
24+
"The communication pipe between the main process "
25+
"and its spawned children is broken.\n"
26+
"In Windows OS, this usually means that the child "
27+
"process raised an exception while it was being "
28+
"spawned, before it was setup to communicate to "
29+
"the main process.\n"
30+
"The exceptions raised by the child process while "
31+
"spawning cannot be caught or handled from the "
32+
"main process, and when running from an IPython or "
33+
"jupyter notebook interactive kernel, the child's "
34+
"exception and traceback appears to be lost.\n"
35+
"A known way to see the child's error, and try to "
36+
"fix or handle it, is to run the problematic code "
37+
"as a batch script from a system's Command Prompt. "
38+
"The child's exception will be printed to the "
39+
"Command Promt's stderr, and it should be visible "
40+
"above this error and traceback.\n"
41+
"Note that if running a jupyter notebook that was "
42+
"invoked from a Command Prompt, the child's "
43+
"exception should have been printed to the Command "
44+
"Prompt on which the notebook is running."
45+
)
4246
else:
4347
return None
4448

@@ -237,7 +241,6 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
237241
tune,
238242
seed,
239243
)
240-
# We fork right away, so that the main process can start tqdm threads
241244
try:
242245
self._process.start()
243246
except IOError as e:
@@ -346,8 +349,6 @@ def __init__(
346349
start_chain_num=0,
347350
progressbar=True,
348351
):
349-
if progressbar:
350-
from tqdm import tqdm
351352

352353
if any(len(arg) != chains for arg in [seeds, start_points]):
353354
raise ValueError("Number of seeds and start_points must be %s." % chains)
@@ -369,14 +370,13 @@ def __init__(
369370

370371
self._progress = None
371372
self._divergences = 0
373+
self._total_draws = 0
372374
self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
373375
self._chains = chains
374-
if progressbar:
375-
self._progress = tqdm(
376-
total=chains * (draws + tune),
377-
unit="draws",
378-
desc=self._desc.format(self)
379-
)
376+
self._progress = progress_bar(
377+
range(chains * (draws + tune)), display=progressbar, auto_update=False
378+
)
379+
self._progress.comment = self._desc.format(self)
380380

381381
def _make_active(self):
382382
while self._inactive and len(self._active) < self._max_active:
@@ -393,11 +393,11 @@ def __iter__(self):
393393
while self._active:
394394
draw = ProcessAdapter.recv_draw(self._active)
395395
proc, is_last, draw, tuning, stats, warns = draw
396-
if self._progress is not None:
397-
if not tuning and stats and stats[0].get('diverging'):
398-
self._divergences += 1
399-
self._progress.set_description(self._desc.format(self))
400-
self._progress.update()
396+
self._total_draws += 1
397+
if not tuning and stats and stats[0].get("diverging"):
398+
self._divergences += 1
399+
self._progress.comment = self._desc.format(self)
400+
self._progress.update(self._total_draws)
401401

402402
if is_last:
403403
proc.join()
@@ -423,8 +423,7 @@ def __enter__(self):
423423

424424
def __exit__(self, *args):
425425
ProcessAdapter.terminate_all(self._samplers)
426-
if self._progress is not None:
427-
self._progress.close()
426+
428427

429428
def _cpu_count():
430429
"""Try to guess the number of CPUs in the system.

pymc3/sampling.py

+27-28
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from .parallel_sampling import _cpu_count
4141
from pymc3.step_methods.hmc import quadpotential
4242
import pymc3 as pm
43-
from tqdm import tqdm
43+
from fastprogress import progress_bar
4444

4545

4646
import sys
@@ -568,11 +568,17 @@ def _sample_population(
568568
# create the generator that iterates all chains in parallel
569569
chains = [chain + c for c in range(chains)]
570570
sampling = _prepare_iter_population(
571-
draws, chains, step, start, parallelize, tune=tune, model=model, random_seed=random_seed
571+
draws,
572+
chains,
573+
step,
574+
start,
575+
parallelize,
576+
tune=tune,
577+
model=model,
578+
random_seed=random_seed,
572579
)
573580

574-
if progressbar:
575-
sampling = tqdm(sampling, total=draws)
581+
sampling = progress_bar(sampling, total=draws, display=progressbar)
576582

577583
latest_traces = None
578584
for it, traces in enumerate(sampling):
@@ -596,23 +602,20 @@ def _sample(
596602

597603
sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
598604
_pbar_data = None
599-
if progressbar:
600-
_pbar_data = {"chain": chain, "divergences": 0}
601-
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
602-
sampling = tqdm(sampling, total=draws, desc=_desc.format(**_pbar_data))
605+
_pbar_data = {"chain": chain, "divergences": 0}
606+
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
607+
sampling = progress_bar(sampling, total=draws, display=progressbar)
608+
sampling.comment = _desc.format(**_pbar_data)
603609
try:
604610
strace = None
605611
for it, (strace, diverging) in enumerate(sampling):
606612
if it >= skip_first:
607613
trace = MultiTrace([strace])
608614
if diverging and _pbar_data is not None:
609615
_pbar_data["divergences"] += 1
610-
sampling.set_description(_desc.format(**_pbar_data))
616+
sampling.comment = _desc.format(**_pbar_data)
611617
except KeyboardInterrupt:
612618
pass
613-
finally:
614-
if progressbar:
615-
sampling.close()
616619
return strace
617620

618621

@@ -753,7 +756,7 @@ def __init__(self, steppers, parallelize):
753756
)
754757
import multiprocessing
755758

756-
for c, stepper in enumerate(tqdm(steppers)):
759+
for c, stepper in enumerate(progress_bar(steppers)):
757760
slave_end, master_end = multiprocessing.Pipe()
758761
stepper_dumps = pickle.dumps(stepper, protocol=4)
759762
process = multiprocessing.Process(
@@ -1235,9 +1238,13 @@ def sample_posterior_predictive(
12351238
nchain = 1
12361239

12371240
if keep_size and samples is not None:
1238-
raise IncorrectArgumentsError("Should not specify both keep_size and samples argukments")
1241+
raise IncorrectArgumentsError(
1242+
"Should not specify both keep_size and samples argukments"
1243+
)
12391244
if keep_size and size is not None:
1240-
raise IncorrectArgumentsError("Should not specify both keep_size and size argukments")
1245+
raise IncorrectArgumentsError(
1246+
"Should not specify both keep_size and size argukments"
1247+
)
12411248

12421249
if samples is None:
12431250
samples = sum(len(v) for v in trace._straces.values())
@@ -1253,7 +1260,9 @@ def sample_posterior_predictive(
12531260

12541261
if var_names is not None:
12551262
if vars is not None:
1256-
raise IncorrectArgumentsError("Should not specify both vars and var_names arguments.")
1263+
raise IncorrectArgumentsError(
1264+
"Should not specify both vars and var_names arguments."
1265+
)
12571266
else:
12581267
vars = [model[x] for x in var_names]
12591268
elif vars is not None: # var_names is None, and vars is not.
@@ -1266,8 +1275,7 @@ def sample_posterior_predictive(
12661275

12671276
indices = np.arange(samples)
12681277

1269-
if progressbar:
1270-
indices = tqdm(indices, total=samples)
1278+
indices = progress_bar(indices, total=samples, display=progressbar)
12711279

12721280
ppc_trace_t = _DefaultTrace(samples)
12731281
try:
@@ -1285,10 +1293,6 @@ def sample_posterior_predictive(
12851293
except KeyboardInterrupt:
12861294
pass
12871295

1288-
finally:
1289-
if progressbar:
1290-
indices.close()
1291-
12921296
ppc_trace = ppc_trace_t.trace_dict
12931297
if keep_size:
12941298
for k, ary in ppc_trace.items():
@@ -1411,8 +1415,7 @@ def sample_posterior_predictive_w(
14111415

14121416
indices = np.random.randint(0, len_trace, samples)
14131417

1414-
if progressbar:
1415-
indices = tqdm(indices, total=samples)
1418+
indices = progress_bar(indices, total=samples, display=progressbar)
14161419

14171420
try:
14181421
ppc = defaultdict(list)
@@ -1426,10 +1429,6 @@ def sample_posterior_predictive_w(
14261429
except KeyboardInterrupt:
14271430
pass
14281431

1429-
finally:
1430-
if progressbar:
1431-
indices.close()
1432-
14331432
return {k: np.asarray(v) for k, v in ppc.items()}
14341433

14351434

0 commit comments

Comments
 (0)