Skip to content

Include n_tune, n_draws and t_sampling in SamplerReport #3827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- `DEMetropolisZ`, an improved variant of `DEMetropolis` brings better parallelization and higher efficiency with fewer chains with a slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
- Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
- Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))

### Maintenance
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
Expand Down
25 changes: 24 additions & 1 deletion pymc3/backends/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import namedtuple
import logging
import enum
import typing
from ..util import is_transformed_name, get_untransformed_name


Expand Down Expand Up @@ -51,11 +52,15 @@ class WarningType(enum.Enum):


class SamplerReport:
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
def __init__(self):
self._chain_warnings = {}
self._global_warnings = []
self._ess = None
self._rhat = None
self._n_tune = None
self._n_draws = None
self._t_sampling = None

@property
def _warnings(self):
Expand All @@ -68,6 +73,25 @@ def ok(self):
return all(_LEVELS[warn.level] < _LEVELS['warn']
for warn in self._warnings)

@property
def n_tune(self) -> typing.Optional[int]:
"""Number of tune iterations - not necessarily kept in trace!"""
return self._n_tune

@property
def n_draws(self) -> typing.Optional[int]:
"""Number of draw iterations."""
return self._n_draws

@property
def t_sampling(self) -> typing.Optional[float]:
"""
Number of seconds that the sampling procedure took.

(Includes parallelization overhead.)
"""
return self._t_sampling

def raise_ok(self, level='error'):
errors = [warn for warn in self._warnings
if _LEVELS[warn.level] >= _LEVELS[level]]
Expand Down Expand Up @@ -151,7 +175,6 @@ def _add_warnings(self, warnings, chain=None):
warn_list.extend(warnings)

def _log_summary(self):

def log_warning(warn):
level = _LEVELS[warn.level]
logger.log(level, warn.message)
Expand Down
34 changes: 32 additions & 2 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from copy import copy
import pickle
import logging
import time
import warnings

import numpy as np
Expand Down Expand Up @@ -488,6 +489,7 @@ def sample(
)

parallel = cores > 1 and chains > 1 and not has_population_samplers
t_start = time.time()
if parallel:
_log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
_print_step_hierarchy(step)
Expand Down Expand Up @@ -533,8 +535,36 @@ def sample(
_print_step_hierarchy(step)
trace = _sample_many(**sample_args)

discard = tune if discard_tuned_samples else 0
trace = trace[discard:]
t_sampling = time.time() - t_start
# count the number of tune/draw iterations that happened
# ideally via the "tune" statistic, but not all samplers record it!
if 'tune' in trace.stat_names:
stat = trace.get_sampler_stats('tune', chains=0)
# when CompoundStep is used, the stat is 2 dimensional!
if len(stat.shape) == 2:
stat = stat[:,0]
stat = tuple(stat)
n_tune = stat.count(True)
n_draws = stat.count(False)
else:
# these may be wrong when KeyboardInterrupt happened, but they're better than nothing
n_tune = min(tune, len(trace))
n_draws = max(0, len(trace) - n_tune)

if discard_tuned_samples:
trace = trace[n_tune:]

# save metadata in SamplerReport
trace.report._n_tune = n_tune
trace.report._n_draws = n_draws
trace.report._t_sampling = t_sampling

n_chains = len(trace.chains)
_log.info(
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
f'({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) '
f'took {trace.report.t_sampling:.0f} seconds.'
)

if compute_convergence_checks:
if draws - tune < 100:
Expand Down
16 changes: 16 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,22 @@ def test_sample_tune_len(self):
trace = pm.sample(draws=100, tune=50, cores=4)
assert len(trace) == 100

@pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice])
@pytest.mark.parametrize("discard", [True, False])
def test_trace_report(self, step_cls, discard):
    """SamplerReport must expose n_tune/n_draws/t_sampling for every step
    method, whether or not tuning samples are discarded from the trace."""
    with self.model:
        # add more variables, because stats are 2D with CompoundStep!
        pm.Uniform('uni')
        trace = pm.sample(
            draws=100, tune=50, cores=1,
            discard_tuned_samples=discard,
            step=step_cls()
        )
        assert trace.report.n_tune == 50
        assert trace.report.n_draws == 100
        assert isinstance(trace.report.t_sampling, float)

@pytest.mark.parametrize('cores', [1, 2])
def test_sampler_stat_tune(self, cores):
with self.model:
Expand Down