diff --git a/pymc/__init__.py b/pymc/__init__.py index 6612800fe4..09314aa5c3 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -72,6 +72,7 @@ def __set_compiler_flags(): from pymc.stats import * from pymc.step_methods import * from pymc.tuning import * +from pymc.util import drop_warning_stat from pymc.variational import * from pymc.vartypes import * diff --git a/pymc/backends/base.py b/pymc/backends/base.py index cfd334ade4..ef6ee1cd64 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -78,10 +78,6 @@ def __init__(self, name, model=None, vars=None, test_point=None): self.chain = None self._is_base_setup = False self.sampler_vars = None - self._warnings = [] - - def _add_warnings(self, warnings): - self._warnings.extend(warnings) # Sampling methods @@ -288,9 +284,6 @@ def __init__(self, straces): self._straces[strace.chain] = strace self._report = SamplerReport() - for strace in straces: - if hasattr(strace, "_warnings"): - self._report._add_warnings(strace._warnings, strace.chain) def __repr__(self): template = "<{}: {} chains, {} iterations, {} variables>" diff --git a/pymc/parallel_sampling.py b/pymc/parallel_sampling.py index d9af68dca0..98c5c59639 100644 --- a/pymc/parallel_sampling.py +++ b/pymc/parallel_sampling.py @@ -40,12 +40,9 @@ class ParallelSamplingError(Exception): - def __init__(self, message, chain, warnings=None): + def __init__(self, message, chain): super().__init__(message) - if warnings is None: - warnings = [] self._chain = chain - self._warnings = warnings # Taken from https://hg.python.org/cpython/rev/c4f92b597074 @@ -74,8 +71,8 @@ def rebuild_exc(exc, tb): # Messages -# ('writing_done', is_last, sample_idx, tuning, stats, warns) -# ('error', warnings, *exception_info) +# ('writing_done', is_last, sample_idx, tuning, stats) +# ('error', *exception_info) # ('abort', reason) # ('write_next',) @@ -133,7 +130,7 @@ def run(self): e = ExceptionWithTraceback(e, e.__traceback__) # Send is not blocking so we have to force a wait for the abort # message - self._msg_pipe.send(("error", None, e)) + self._msg_pipe.send(("error", e)) self._wait_for_abortion() finally: self._msg_pipe.close() @@ -181,9 +178,8 @@ def _start_loop(self): try: point, stats = self._compute_point() except SamplingError as e: - warns = self._collect_warnings() e = ExceptionWithTraceback(e, e.__traceback__) - self._msg_pipe.send(("error", warns, e)) + self._msg_pipe.send(("error", e)) else: return @@ -193,11 +189,7 @@ def _start_loop(self): elif msg[0] == "write_next": self._write_point(point) is_last = draw + 1 == self._draws + self._tune - if is_last: - warns = self._collect_warnings() - else: - warns = None - self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats, warns)) + self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats)) draw += 1 else: raise ValueError("Unknown message " + msg[0]) @@ -210,12 +202,6 @@ def _compute_point(self): stats = None return point, stats - def _collect_warnings(self): - if hasattr(self._step_method, "warnings"): - return self._step_method.warnings() - else: - return [] - def _run_process(*args): _Process(*args).run() @@ -308,11 +294,13 @@ def _send(self, msg, *args): except Exception: pass if message is not None and message[0] == "error": - warns, old_error = message[1:] - if warns is not None: - error = ParallelSamplingError(str(old_error), self.chain, warns) + old_error = message[1] + if old_error is not None: + error = ParallelSamplingError( + f"Chain {self.chain} failed with: {old_error}", self.chain + ) else: - error = RuntimeError("Chain %s failed." % self.chain) + error = RuntimeError(f"Chain {self.chain} failed.") raise error from old_error raise @@ -345,11 +333,13 @@ def recv_draw(processes, timeout=3600): msg = ready[0].recv() if msg[0] == "error": - warns, old_error = msg[1:] - if warns is not None: - error = ParallelSamplingError(str(old_error), proc.chain, warns) + old_error = msg[1] + if old_error is not None: + error = ParallelSamplingError( + f"Chain {proc.chain} failed with: {old_error}", proc.chain + ) else: - error = RuntimeError("Chain %s failed." % proc.chain) + error = RuntimeError(f"Chain {proc.chain} failed.") raise error from old_error elif msg[0] == "writing_done": proc._readable = True @@ -383,7 +373,7 @@ def terminate_all(processes, patience=2): process.join() -Draw = namedtuple("Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point", "warnings"]) +Draw = namedtuple("Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point"]) class ParallelSampler: @@ -466,7 +456,7 @@ def __iter__(self): while self._active: draw = ProcessAdapter.recv_draw(self._active) - proc, is_last, draw, tuning, stats, warns = draw + proc, is_last, draw, tuning, stats = draw self._total_draws += 1 if not tuning and stats and stats[0].get("diverging"): self._divergences += 1 @@ -491,7 +481,7 @@ def __iter__(self): if not is_last: proc.write_next() - yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns) + yield Draw(proc.chain, is_last, draw, tuning, stats, point) def __enter__(self): self._in_context = True diff --git a/pymc/sampling.py b/pymc/sampling.py index 41ac56921b..8e31ff989e 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -70,12 +70,13 @@ ) from pymc.model import Model, modelcontext from pymc.parallel_sampling import Draw, _cpu_count -from pymc.stats.convergence import run_convergence_checks +from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks from pymc.step_methods import NUTS, CompoundStep, DEMetropolis from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( dataset_to_point_list, + drop_warning_stat, get_default_varnames, get_untransformed_name, is_transformed_name, @@ -323,6 +324,7 @@ def sample( jitter_max_retries: int = 10, *, return_inferencedata: bool = True, + keep_warning_stat: bool = False, idata_kwargs: dict = None, mp_ctx=None, **kwargs, @@ -393,6 +395,13 @@ def sample( `MultiTrace` (False). Defaults to `True`. idata_kwargs : dict, optional Keyword arguments for :func:`pymc.to_inference_data` + keep_warning_stat : bool + If ``True`` the "warning" stat emitted by, for example, HMC samplers will be kept + in the returned ``idata.sample_stat`` group. + This leads to the ``idata`` not supporting ``.to_netcdf()`` or ``.to_zarr()`` and + should only be set to ``True`` if you intend to use the "warning" objects right away. + Defaults to ``False`` such that ``pm.drop_warning_stat`` is applied automatically, + making the ``InferenceData`` compatible with saving. mp_ctx : multiprocessing.context.BaseContent A multiprocessing context for parallel sampling. See multiprocessing documentation for details. @@ -699,6 +708,10 @@ def sample( mtrace.report._add_warnings(convergence_warnings) if return_inferencedata: + # By default we drop the "warning" stat which contains `SamplerWarning` + # objects that can not be stored with `.to_netcdf()`. + if not keep_warning_stat: + return drop_warning_stat(idata) return idata return mtrace @@ -1048,32 +1061,26 @@ def _iter_sample( if step.generates_stats: point, stats = step.step(point) strace.record(point, stats) + log_warning_stats(stats) diverging = i > tune and stats and stats[0].get("diverging") else: point = step.step(point) strace.record(point) if callback is not None: - warns = getattr(step, "warnings", None) callback( trace=strace, - draw=Draw(chain, i == draws, i, i < tune, stats, point, warns), + draw=Draw(chain, i == draws, i, i < tune, stats, point), ) yield strace, diverging except KeyboardInterrupt: strace.close() - if hasattr(step, "warnings"): - warns = step.warnings() - strace._add_warnings(warns) raise except BaseException: strace.close() raise else: strace.close() - if hasattr(step, "warnings"): - warns = step.warnings() - strace._add_warnings(warns) class PopulationStepper: @@ -1356,6 +1363,7 @@ def _iter_population( if steppers[c].generates_stats: points[c], stats = updates[c] strace.record(points[c], stats) + log_warning_stats(stats) else: points[c] = updates[c] strace.record(points[c]) @@ -1513,21 +1521,16 @@ def _mp_sample( with sampler: for draw in sampler: strace = traces[draw.chain] - if draw.stats is not None: - strace.record(draw.point, draw.stats) - else: - strace.record(draw.point) + strace.record(draw.point, draw.stats) + log_warning_stats(draw.stats) if draw.is_last: strace.close() - if draw.warnings is not None: - strace._add_warnings(draw.warnings) if callback is not None: callback(trace=trace, draw=draw) except ps.ParallelSamplingError as error: strace = traces[error._chain] - strace._add_warnings(error._warnings) for strace in traces: strace.close() @@ -1546,6 +1549,22 @@ def _mp_sample( strace.close() +def log_warning_stats(stats: Sequence[Dict[str, Any]]): + """Logs 'warning' stats if present.""" + if stats is None: + return + + for sts in stats: + warn = sts.get("warning", None) + if warn is None: + continue + if isinstance(warn, SamplerWarning): + log_warning(warn) + else: + _log.warning(warn) + return + + def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTrace], int]: """ Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized. diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py index 3288f5e881..e39beff573 100644 --- a/pymc/stats/convergence.py +++ b/pymc/stats/convergence.py @@ -68,7 +68,7 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWar warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info") return [warn] - warnings = [] + warnings: List[SamplerWarning] = [] valid_name = [rv.name for rv in model.free_RVs + model.deterministics] varnames = [] for rv in model.free_RVs: @@ -104,11 +104,60 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWar warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess) warnings.append(warn) + warnings += warn_divergences(idata) + warnings += warn_treedepth(idata) + + return warnings + + +def warn_divergences(idata: arviz.InferenceData) -> List[SamplerWarning]: + """Checks sampler stats and creates a list of warnings about divergences.""" + sampler_stats = idata.get("sample_stats", None) + if sampler_stats is None: + return [] + + diverging = sampler_stats.get("diverging", None) + if diverging is None: + return [] + + # Warn about divergences + n_div = int(diverging.sum()) + if n_div == 0: + return [] + warning = SamplerWarning( + WarningType.DIVERGENCES, + f"There were {n_div} divergences after tuning. Increase `target_accept` or reparameterize.", + "error", + ) + return [warning] + + +def warn_treedepth(idata: arviz.InferenceData) -> List[SamplerWarning]: + """Checks sampler stats and creates a list of warnings about tree depth.""" + sampler_stats = idata.get("sample_stats", None) + if sampler_stats is None: + return [] + + treedepth = sampler_stats.get("tree_depth", None) + if treedepth is None: + return [] + + warnings = [] + for c in treedepth.chain: + if sum(treedepth.sel(chain=c)) / treedepth.sizes["draw"] > 0.05: + warnings.append( + SamplerWarning( + WarningType.TREEDEPTH, + f"Chain {c} reached the maximum tree depth." + " Increase `max_treedepth`, increase `target_accept` or reparameterize.", + "warn", + ) + ) return warnings def log_warning(warn: SamplerWarning): - level = _LEVELS[warn.level] + level = _LEVELS.get(warn.level, logging.WARNING) logger.log(level, warn.message) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d824045c7e..4bb86056ed 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -59,13 +59,6 @@ def step(self, point): point = method.step(point) return point - def warnings(self): - warns = [] - for method in self.methods: - if hasattr(method, "warnings"): - warns.extend(method.warnings()) - return warns - def stop_tuning(self): for method in self.methods: method.stop_tuning() diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 863246100d..a08dd86071 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -17,6 +17,7 @@ from abc import abstractmethod from collections import namedtuple +from typing import Optional import numpy as np @@ -134,8 +135,6 @@ def __init__( self.integrator = integration.CpuLeapfrogIntegrator(self.potential, self._logp_dlogp_func) self._step_rand = step_rand - self._warnings = [] - self._samples_after_tune = 0 self._num_divs_sample = 0 @abstractmethod @@ -173,8 +172,7 @@ def astep(self, q0): "critical", self.iter_count, ) - self._warnings.append(warning) - raise SamplingError("Bad initial energy") + raise SamplingError(f"Bad initial energy: {warning}") adapt_step = self.tune and self.adapt_step_size step_size = self.step_adapt.current(adapt_step) @@ -190,6 +188,7 @@ def astep(self, q0): self.step_adapt.update(hmc_step.accept_stat, adapt_step) self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune) + warning: Optional[SamplerWarning] = None if hmc_step.divergence_info: info = hmc_step.divergence_info point = None @@ -205,7 +204,7 @@ def astep(self, q0): point = DictToArrayBijection.rmap(info.state.q) if self._num_divs_sample < 100 and info.state_div is not None: - point = DictToArrayBijection.rmap(info.state_div.q) + point_dest = DictToArrayBijection.rmap(info.state_div.q) if self._num_divs_sample < 100: info_store = info @@ -220,11 +219,7 @@ def astep(self, q0): divergence_info=info_store, ) - self._warnings.append(warning) - self.iter_count += 1 - if not self.tune: - self._samples_after_tune += 1 stats = { "tune": self.tune, @@ -232,6 +227,7 @@ def astep(self, q0): "perf_counter_diff": perf_end - perf_start, "process_time_diff": process_end - process_start, "perf_counter_start": perf_start, + "warning": warning, } stats.update(hmc_step.stats) @@ -247,32 +243,3 @@ def reset_tuning(self, start=None): def reset(self, start=None): self.tune = True self.potential.reset() - - def warnings(self): - # list.copy() is not available in python2 - warnings = self._warnings[:] - - # Generate a global warning for divergences - message = "" - n_divs = self._num_divs_sample - if n_divs and self._samples_after_tune == n_divs: - message = ( - "The chain contains only diverging samples. The model " "is probably misspecified." - ) - elif n_divs == 1: - message = ( - "There was 1 divergence after tuning. Increase " - "`target_accept` or reparameterize." - ) - elif n_divs > 1: - message = ( - "There were %s divergences after tuning. Increase " - "`target_accept` or reparameterize." % n_divs - ) - - if message: - warning = SamplerWarning(WarningType.DIVERGENCES, message, "error") - warnings.append(warning) - - warnings.extend(self.step_adapt.warnings()) - return warnings diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 6f1522596e..4ed192ac99 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -14,6 +14,7 @@ import numpy as np +from pymc.stats.convergence import SamplerWarning from pymc.step_methods.arraystep import Competence from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData from pymc.step_methods.hmc.integration import IntegrationError, State @@ -53,6 +54,7 @@ class HamiltonianMC(BaseHMC): "perf_counter_start": np.float64, "largest_eigval": np.float64, "smallest_eigval": np.float64, + "warning": SamplerWarning, } ] diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 5dd2231a8e..0eed003da1 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -18,7 +18,7 @@ from pymc.aesaraf import floatX from pymc.math import logbern -from pymc.stats.convergence import SamplerWarning, WarningType +from pymc.stats.convergence import SamplerWarning from pymc.step_methods.arraystep import Competence from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData from pymc.step_methods.hmc.integration import IntegrationError @@ -114,6 +114,8 @@ class NUTS(BaseHMC): "largest_eigval": np.float64, "smallest_eigval": np.float64, "index_in_trajectory": np.int64, + "reached_max_treedepth": bool, + "warning": SamplerWarning, } ] @@ -189,6 +191,7 @@ def _hamiltonian_step(self, start, p0, step_size): tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax) + reached_max_treedepth = False for _ in range(max_treedepth): direction = logbern(np.log(0.5)) * 2 - 1 divergence_info, turning = tree.extend(direction) @@ -196,11 +199,11 @@ def _hamiltonian_step(self, start, p0, step_size): if divergence_info or turning: break else: - if not self.tune: - self._reached_max_treedepth += 1 + reached_max_treedepth = not self.tune stats = tree.stats() accept_stat = stats["mean_tree_accept"] + stats["reached_max_treedepth"] = reached_max_treedepth return HMCStepData(tree.proposal, accept_stat, divergence_info, stats) @staticmethod @@ -212,20 +215,6 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.INCOMPATIBLE - def warnings(self): - warnings = super().warnings() - n_samples = self._samples_after_tune - n_treedepth = self._reached_max_treedepth - - if n_samples > 0 and n_treedepth / float(n_samples) > 0.05: - msg = ( - "The chain reached the maximum tree depth. Increase " - "max_treedepth, increase target_accept or reparameterize." - ) - warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn") - warnings.append(warn) - return warnings - # A proposal for the next position Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory") diff --git a/pymc/tests/stats/test_convergence.py b/pymc/tests/stats/test_convergence.py new file mode 100644 index 0000000000..796731953a --- /dev/null +++ b/pymc/tests/stats/test_convergence.py @@ -0,0 +1,29 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import arviz +import numpy as np + +from pymc.stats import convergence + + +def test_warn_divergences(): + idata = arviz.from_dict( + sample_stats={ + "diverging": np.array([[1, 0, 1, 0], [0, 0, 0, 0]]).astype(bool), + } + ) + warns = convergence.warn_divergences(idata) + assert len(warns) == 1 + assert "2 divergences after tuning" in warns[0].message diff --git a/pymc/tests/step_methods/hmc/test_nuts.py b/pymc/tests/step_methods/hmc/test_nuts.py index fb24c0e8d3..89dba215cf 100644 --- a/pymc/tests/step_methods/hmc/test_nuts.py +++ b/pymc/tests/step_methods/hmc/test_nuts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import sys import warnings @@ -121,29 +122,17 @@ def test_bad_init_parallel(self): pm.sample(cores=2, random_seed=1) error.match("Initial evaluation") - def test_linalg(self, caplog): + def test_emits_energy_warnings(self, caplog): with pm.Model(): a = pm.Normal("a", size=2, initval=floatX(np.zeros(2))) a = at.switch(a > 0, np.inf, a) b = at.slinalg.solve(floatX(np.eye(2)), a, check_finite=False) pm.Normal("c", mu=b, size=2, initval=floatX(np.r_[0.0, 0.0])) caplog.clear() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - trace = pm.sample(20, tune=5, chains=2, return_inferencedata=False, random_seed=526) - warns = [msg.msg for msg in caplog.records] - assert np.any(trace["diverging"]) - assert ( - any("divergence after tuning" in warn for warn in warns) - or any("divergences after tuning" in warn for warn in warns) - or any("only diverging samples" in warn for warn in warns) - ) - - with pytest.raises(ValueError) as error: - trace.report.raise_ok() - error.match("issues during sampling") - - assert not trace.report.ok + # The logger name must be specified for DEBUG level capturing to work + with caplog.at_level(logging.DEBUG, logger="pymc"): + idata = pm.sample(20, tune=5, chains=2, random_seed=526) + assert any("Energy change" in w.msg for w in caplog.records) def test_sampler_stats(self): with pm.Model() as model: @@ -168,12 +157,19 @@ def test_sampler_stats(self): "perf_counter_diff", "perf_counter_start", "process_time_diff", + "reached_max_treedepth", "index_in_trajectory", "largest_eigval", "smallest_eigval", + "warning", } assert trace.stat_names == expected_stat_names for varname in trace.stat_names: + if varname == "warning": + # Warnings don't squeeze reliably. + # But once we stop squeezing alltogether that's going to be OK. + # See https://github.com/pymc-devs/pymc/issues/6207 + continue assert trace.get_sampler_stats(varname).shape == (10,) # Assert model logp is computed correctly: computing post-sampling diff --git a/pymc/tests/test_parallel_sampling.py b/pymc/tests/test_parallel_sampling.py index 2032e7e071..2883acd297 100644 --- a/pymc/tests/test_parallel_sampling.py +++ b/pymc/tests/test_parallel_sampling.py @@ -14,6 +14,7 @@ import multiprocessing import os import platform +import sys import warnings import aesara @@ -74,7 +75,7 @@ def test_bad_unpickle(): @as_op([at_vector, at.iscalar], [at_vector]) def _crash_remote_process(a, master_pid): if os.getpid() != master_pid: - os.exit(0) + sys.exit(0) return 2 * np.array(a) @@ -86,7 +87,7 @@ def test_remote_pipe_closed(): pm.Normal("y", mu=_crash_remote_process(x, at_pid), shape=2) step = pm.Metropolis() - with pytest.raises(RuntimeError, match="Chain [0-9] failed"): + with pytest.raises(ps.ParallelSamplingError, match="Chain [0-9] failed with") as ex: pm.sample(step=step, mp_ctx="spawn", tune=2, draws=2, cores=2, chains=2) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 5efb34b89a..80f290e906 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -50,6 +50,7 @@ compile_forward_sampling_function, get_vars_in_point_list, ) +from pymc.stats.convergence import SamplerWarning, WarningType from pymc.step_methods import ( NUTS, BinaryGibbsMetropolis, @@ -1737,6 +1738,109 @@ def test_step_args(): npt.assert_allclose(idata1.sample_stats.scaling, 0) +def test_log_warning_stats(caplog): + s1 = dict(warning="Temperature too low!") + s2 = dict(warning="Temperature too high!") + stats = [s1, s2] + + with caplog.at_level(logging.WARNING): + pm.sampling.log_warning_stats(stats) + + # We have a list of stats dicts, because there might be several samplers involved. + assert "too low" in caplog.records[0].message + assert "too high" in caplog.records[1].message + + +def test_log_warning_stats_knows_SamplerWarning(caplog): + """Checks that SamplerWarning "warning" stats get special treatment.""" + stats = [dict(warning=SamplerWarning(WarningType.BAD_ENERGY, "Not that interesting", "debug"))] + + with caplog.at_level(logging.DEBUG, logger="pymc"): + pm.sampling.log_warning_stats(stats) + + assert "Not that interesting" in caplog.records[0].message + + +class ApolypticMetropolis(pm.Metropolis): + """A stepper that warns in every iteration.""" + + stats_dtypes = [ + { + **pm.Metropolis.stats_dtypes[0], + "warning": SamplerWarning, + } + ] + + def astep(self, q0): + draw, stats = super().astep(q0) + stats[0]["warning"] = SamplerWarning( + WarningType.BAD_ENERGY, + "Asteroid incoming!", + "warn", + ) + return draw, stats + + +@pytest.mark.parametrize("cores", [1, 2]) +def test_logs_sampler_warnings(caplog, cores): + """Asserts that "warning" sampler stats are logged during sampling.""" + with pm.Model(): + pm.Normal("n") + with caplog.at_level(logging.WARNING): + idata = pm.sample( + tune=2, + draws=3, + cores=cores, + chains=cores, + step=ApolypticMetropolis(), + compute_convergence_checks=False, + discard_tuned_samples=False, + keep_warning_stat=True, + ) + + # Sampler warnings should be logged + nwarns = sum("Asteroid" in rec.message for rec in caplog.records) + assert nwarns == (2 + 3) * cores + + +@pytest.mark.parametrize("keep_warning_stat", [None, True]) +def test_keep_warning_stat_setting(keep_warning_stat): + """The ``keep_warning_stat`` stat (aka "Adrian's kwarg) enables users + to keep the ``SamplerWarning`` objects from the ``sample_stats.warning`` group. + This breaks ``idata.to_netcdf()`` which is why it defaults to ``False``. + """ + sample_kwargs = dict( + tune=2, + draws=3, + chains=1, + compute_convergence_checks=False, + discard_tuned_samples=False, + keep_warning_stat=keep_warning_stat, + ) + if keep_warning_stat: + sample_kwargs["keep_warning_stat"] = True + with pm.Model(): + pm.Normal("n") + idata = pm.sample(step=ApolypticMetropolis(), **sample_kwargs) + + if keep_warning_stat: + assert "warning" in idata.warmup_sample_stats + assert "warning" in idata.sample_stats + # And end up in the InferenceData + assert "warning" in idata.sample_stats + # NOTE: The stats are squeezed by default but this does not always work. + # This tests flattens so we don't have to be exact in accessing (non-)squeezed items. + # Also see https://github.com/pymc-devs/pymc/issues/6207. + warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten()) + assert any(isinstance(w, SamplerWarning) for w in warn_objs) + assert any("Asteroid" in w.message for w in warn_objs) + else: + assert "warning" not in idata.warmup_sample_stats + assert "warning" not in idata.sample_stats + assert "warning_dim_0" not in idata.warmup_sample_stats + assert "warning_dim_0" not in idata.sample_stats + + def test_init_nuts(caplog): with pm.Model() as model: a = pm.Normal("a") diff --git a/pymc/tests/test_util.py b/pymc/tests/test_util.py index 570c070b78..4f5d72648c 100644 --- a/pymc/tests/test_util.py +++ b/pymc/tests/test_util.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import arviz import numpy as np import pytest import xarray @@ -24,6 +25,7 @@ from pymc.util import ( UNSET, dataset_to_point_list, + drop_warning_stat, hash_key, hashable, locally_cachedmethod, @@ -164,3 +166,35 @@ def test_dataset_to_point_list(): ds[3] = xarray.DataArray([1, 2, 3]) with pytest.raises(ValueError, match="must be str"): dataset_to_point_list(ds, sample_dims=["chain", "draw"]) + + +def test_drop_warning_stat(): + idata = arviz.from_dict( + sample_stats={ + "a": np.ones((2, 5, 4)), + "warning": np.ones((2, 5, 3), dtype=object), + }, + warmup_sample_stats={ + "a": np.ones((2, 5, 4)), + "warning": np.ones((2, 5, 3), dtype=object), + }, + attrs=dict(version="0.1.2"), + coords={ + "adim": [0, 1, None, 3], + "warning_dim_0": list("ABC"), + }, + dims={"a": ["adim"], "warning": ["warning_dim_0"]}, + save_warmup=True, + ) + + new = drop_warning_stat(idata) + + assert new is not idata + assert new.attrs.get("version") == "0.1.2" + + for gname in ["sample_stats", "warmup_sample_stats"]: + ss = new.get(gname) + assert isinstance(ss, xarray.Dataset), gname + assert "a" in ss + assert "warning" not in ss + assert "warning_dim_0" not in ss diff --git a/pymc/util.py b/pymc/util.py index bb126fc62d..ff62512edb 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -14,8 +14,9 @@ import functools -from typing import Any, Dict, List, Tuple, cast +from typing import Any, Dict, List, Tuple, Union, cast +import arviz import cloudpickle import numpy as np import xarray @@ -252,6 +253,39 @@ def dataset_to_point_list( return cast(List[Dict[str, np.ndarray]], points), stacked_dims +def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData: + """Returns a new ``InferenceData`` object with the "warning" stat removed from sample stats groups. + + This function should be applied to an ``InferenceData`` object obtained with + ``pm.sample(keep_warning_stat=True)`` before trying to ``.to_netcdf()`` or ``.to_zarr()`` it. + """ + nidata = arviz.InferenceData(attrs=idata.attrs) + for gname, group in idata.items(): + if "sample_stat" in gname: + group = group.drop_vars(names=["warning", "warning_dim_0"], errors="ignore") + nidata.add_groups({gname: group}, coords=group.coords, dims=group.dims) + return nidata + + +def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tuple[int, int]: + """Extract and return number of chains and samples in xarray or arviz traces.""" + dataset: xarray.Dataset + if isinstance(data, xarray.Dataset): + dataset = data + elif isinstance(data, arviz.InferenceData): + dataset = data["posterior"] + else: + raise ValueError( + "Argument must be xarray Dataset or arviz InferenceData. Got %s", + data.__class__, + ) + + coords = dataset.coords + nchains = coords["chain"].sizes["chain"] + nsamples = coords["draw"].sizes["draw"] + return nchains, nsamples + + def hashable(a=None) -> int: """ Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function. diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 84b9f558ac..adcf4a153c 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -56,6 +56,7 @@ pymc/smc/sample_smc.py pymc/smc/smc.py pymc/stats/__init__.py +pymc/stats/convergence.py pymc/step_methods/__init__.py pymc/step_methods/compound.py pymc/step_methods/hmc/__init__.py