diff --git a/pymc/blocking.py b/pymc/blocking.py index b034d8624e..7ad228c86c 100644 --- a/pymc/blocking.py +++ b/pymc/blocking.py @@ -20,7 +20,18 @@ from __future__ import annotations from functools import partial -from typing import Any, Callable, Dict, Generic, List, NamedTuple, TypeVar +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + NamedTuple, + Optional, + Sequence, + TypeVar, + Union, +) import numpy as np @@ -33,6 +44,8 @@ PointType: TypeAlias = Dict[str, np.ndarray] StatsDict: TypeAlias = Dict[str, Any] StatsType: TypeAlias = List[StatsDict] +StatDtype: TypeAlias = Union[type, np.dtype] +StatShape: TypeAlias = Optional[Sequence[Optional[int]]] # `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 34627dc3f7..4b9147f526 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -18,15 +18,17 @@ @author: johnsalvatier """ +import warnings + from abc import ABC, abstractmethod from enum import IntEnum, unique -from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, Mapping, Sequence, Tuple, Union import numpy as np from pytensor.graph.basic import Variable -from pymc.blocking import PointType, StatsDict, StatsType +from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType from pymc.model import modelcontext __all__ = ("Competence", "CompoundStep") @@ -48,9 +50,61 @@ class Competence(IntEnum): IDEAL = 3 +def infer_warn_stats_info( + stats_dtypes: List[Dict[str, StatDtype]], + sds: Dict[str, Tuple[StatDtype, StatShape]], + stepname: str, +) -> Tuple[List[Dict[str, StatDtype]], Dict[str, Tuple[StatDtype, StatShape]]]: + """Helper function to get `stats_dtypes` and `stats_dtypes_shapes` from either of them.""" + # Avoid side-effects on the original lists/dicts + stats_dtypes = [d.copy() for d in stats_dtypes] + sds = sds.copy() + # Disallow specification of both attributes + if stats_dtypes and sds: + raise TypeError( + "Only one of `stats_dtypes_shapes` or `stats_dtypes` must be specified." + f" `{stepname}.stats_dtypes` should be removed." + ) + + # Infer one from the other + if not sds and stats_dtypes: + warnings.warn( + f"`{stepname}.stats_dtypes` is deprecated." + " Please update it to specify `stats_dtypes_shapes` instead.", + DeprecationWarning, + ) + if len(stats_dtypes) > 1: + raise TypeError( + f"`{stepname}.stats_dtypes` must be a list containing at most one dict." + ) + for sd in stats_dtypes: + for sname, dtype in sd.items(): + sds[sname] = (dtype, None) + elif sds: + stats_dtypes.append({sname: dtype for sname, (dtype, _) in sds.items()}) + return stats_dtypes, sds + + class BlockedStep(ABC): stats_dtypes: List[Dict[str, type]] = [] + """A list containing <=1 dictionary that maps stat names to dtypes. + + This attribute is deprecated. + Use `stats_dtypes_shapes` instead. + """ + + stats_dtypes_shapes: Dict[str, Tuple[StatDtype, StatShape]] = {} + """Maps stat names to dtypes and shapes. + + Shapes are interpreted in the following ways: + - `[]` is a scalar. + - `[3,]` is a length-3 vector. + - `[4, None]` is a matrix with 4 rows and a dynamic number of columns. + - `None` is a sparse stat (i.e. not always present) or a NumPy array with varying `ndim`. + """ + vars: List[Variable] = [] + """Variables that the step method is assigned to.""" def __new__(cls, *args, **kwargs): blocked = kwargs.get("blocked") @@ -77,12 +131,21 @@ def __new__(cls, *args, **kwargs): if len(vars) == 0: raise ValueError("No free random variables to sample.") + # Auto-fill stats metadata attributes from whichever was given. + stats_dtypes, stats_dtypes_shapes = infer_warn_stats_info( + cls.stats_dtypes, + cls.stats_dtypes_shapes, + cls.__name__, + ) + if not blocked and len(vars) > 1: # In this case we create a separate sampler for each var # and append them to a CompoundStep steps = [] for var in vars: step = super().__new__(cls) + step.stats_dtypes = stats_dtypes + step.stats_dtypes_shapes = stats_dtypes_shapes # If we don't return the instance we have to manually # call __init__ step.__init__([var], *args, **kwargs) @@ -93,6 +156,8 @@ def __new__(cls, *args, **kwargs): return CompoundStep(steps) else: step = super().__new__(cls) + step.stats_dtypes = stats_dtypes + step.stats_dtypes_shapes = stats_dtypes_shapes # Hack for creating the class correctly when unpickling. step.__newargs = (vars,) + args, kwargs return step @@ -126,6 +191,20 @@ def stop_tuning(self): self.tune = False +def get_stats_dtypes_shapes_from_steps( + steps: Iterable[BlockedStep], +) -> Dict[str, Tuple[StatDtype, StatShape]]: + """Combines stats dtype shape dictionaries from multiple step methods. + + In the resulting stats dict, each sampler stat is prefixed by `sampler_#__`. + """ + result = {} + for s, step in enumerate(steps): + for sname, (dtype, shape) in step.stats_dtypes_shapes.items(): + result[f"sampler_{s}__{sname}"] = (dtype, shape) + return result + + class CompoundStep: """Step method composed of a list of several other step methods applied in sequence.""" @@ -135,6 +214,7 @@ def __init__(self, methods): self.stats_dtypes = [] for method in self.methods: self.stats_dtypes.extend(method.stats_dtypes) + self.stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(methods) self.name = ( f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]" ) @@ -187,6 +267,10 @@ def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None: for s, names_dtypes in enumerate(sampler_stats_dtypes) ] + @property + def n_samplers(self) -> int: + return len(self._stat_groups) + def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict: """Combine stats dicts of multiple samplers into one dict.""" stats_dict = {} diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 9c531b6950..2a92a0c332 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -39,27 +39,25 @@ class HamiltonianMC(BaseHMC): name = "hmc" default_blocked = True - stats_dtypes = [ - { - "step_size": np.float64, - "n_steps": np.int64, - "tune": bool, - "step_size_bar": np.float64, - "accept": np.float64, - "diverging": bool, - "energy_error": np.float64, - "energy": np.float64, - "path_length": np.float64, - "accepted": bool, - "model_logp": np.float64, - "process_time_diff": np.float64, - "perf_counter_diff": np.float64, - "perf_counter_start": np.float64, - "largest_eigval": np.float64, - "smallest_eigval": np.float64, - "warning": SamplerWarning, - } - ] + stats_dtypes_shapes = { + "step_size": (np.float64, []), + "n_steps": (np.int64, []), + "tune": (bool, []), + "step_size_bar": (np.float64, []), + "accept": (np.float64, []), + "diverging": (bool, []), + "energy_error": (np.float64, []), + "energy": (np.float64, []), + "path_length": (np.float64, []), + "accepted": (bool, []), + "model_logp": (np.float64, []), + "process_time_diff": (np.float64, []), + "perf_counter_diff": (np.float64, []), + "perf_counter_start": (np.float64, []), + "largest_eigval": (np.float64, []), + "smallest_eigval": (np.float64, []), + "warning": (SamplerWarning, None), + } def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs): """ diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 68fd0c479d..9d377205b3 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -97,29 +97,27 @@ class NUTS(BaseHMC): name = "nuts" default_blocked = True - stats_dtypes = [ - { - "depth": np.int64, - "step_size": np.float64, - "tune": bool, - "mean_tree_accept": np.float64, - "step_size_bar": np.float64, - "tree_size": np.float64, - "diverging": bool, - "energy_error": np.float64, - "energy": np.float64, - "max_energy_error": np.float64, - "model_logp": np.float64, - "process_time_diff": np.float64, - "perf_counter_diff": np.float64, - "perf_counter_start": np.float64, - "largest_eigval": np.float64, - "smallest_eigval": np.float64, - "index_in_trajectory": np.int64, - "reached_max_treedepth": bool, - "warning": SamplerWarning, - } - ] + stats_dtypes_shapes = { + "depth": (np.int64, []), + "step_size": (np.float64, []), + "tune": (bool, []), + "mean_tree_accept": (np.float64, []), + "step_size_bar": (np.float64, []), + "tree_size": (np.float64, []), + "diverging": (bool, []), + "energy_error": (np.float64, []), + "energy": (np.float64, []), + "max_energy_error": (np.float64, []), + "model_logp": (np.float64, []), + "process_time_diff": (np.float64, []), + "perf_counter_diff": (np.float64, []), + "perf_counter_start": (np.float64, []), + "largest_eigval": (np.float64, []), + "smallest_eigval": (np.float64, []), + "index_in_trajectory": (np.int64, []), + "reached_max_treedepth": (bool, []), + "warning": (SamplerWarning, None), + } def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs): r"""Set up the No-U-Turn sampler. diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index de53edb9f0..0ee10a24d6 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -117,14 +117,12 @@ class Metropolis(ArrayStepShared): name = "metropolis" default_blocked = False - stats_dtypes = [ - { - "accept": np.float64, - "accepted": np.float64, - "tune": bool, - "scaling": np.float64, - } - ] + stats_dtypes_shapes = { + "accept": (np.float64, []), + "accepted": (np.float64, []), + "tune": (bool, []), + "scaling": (np.float64, []), + } def __init__( self, @@ -363,13 +361,11 @@ class BinaryMetropolis(ArrayStep): name = "binary_metropolis" - stats_dtypes = [ - { - "accept": np.float64, - "tune": bool, - "p_jump": np.float64, - } - ] + stats_dtypes_shapes = { + "accept": (np.float64, []), + "tune": (bool, []), + "p_jump": (np.float64, []), + } def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None): model = pm.modelcontext(model) @@ -726,15 +722,13 @@ class DEMetropolis(PopulationArrayStepShared): name = "DEMetropolis" default_blocked = True - stats_dtypes = [ - { - "accept": np.float64, - "accepted": bool, - "tune": bool, - "scaling": np.float64, - "lambda": np.float64, - } - ] + stats_dtypes_shapes = { + "accept": (np.float64, []), + "accepted": (bool, []), + "tune": (bool, []), + "scaling": (np.float64, []), + "lambda": (np.float64, []), + } def __init__( self, @@ -871,15 +865,13 @@ class DEMetropolisZ(ArrayStepShared): name = "DEMetropolisZ" default_blocked = True - stats_dtypes = [ - { - "accept": np.float64, - "accepted": bool, - "tune": bool, - "scaling": np.float64, - "lambda": np.float64, - } - ] + stats_dtypes_shapes = { + "accept": (np.float64, []), + "accepted": (bool, []), + "tune": (bool, []), + "scaling": (np.float64, []), + "lambda": (np.float64, []), + } def __init__( self, diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 75fc409ab4..8d844187fb 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -50,12 +50,10 @@ class Slice(ArrayStep): name = "slice" default_blocked = False - stats_dtypes = [ - { - "nstep_out": int, - "nstep_in": int, - } - ] + stats_dtypes_shapes = { + "nstep_out": (int, []), + "nstep_in": (int, []), + } def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs): self.model = modelcontext(model) diff --git a/pymc/tests/sampling/test_mcmc.py b/pymc/tests/sampling/test_mcmc.py index 85686d235e..f6ebad3450 100644 --- a/pymc/tests/sampling/test_mcmc.py +++ b/pymc/tests/sampling/test_mcmc.py @@ -633,12 +633,10 @@ def test_step_args(): class ApocalypticMetropolis(pm.Metropolis): """A stepper that warns in every iteration.""" - stats_dtypes = [ - { - **pm.Metropolis.stats_dtypes[0], - "warning": SamplerWarning, - } - ] + stats_dtypes_shapes = { + **pm.Metropolis.stats_dtypes_shapes, + "warning": (SamplerWarning, None), + } def astep(self, q0): draw, stats = super().astep(q0) diff --git a/pymc/tests/step_methods/test_compound.py b/pymc/tests/step_methods/test_compound.py index 3a3a9b809b..a87cd914d0 100644 --- a/pymc/tests/step_methods/test_compound.py +++ b/pymc/tests/step_methods/test_compound.py @@ -25,7 +25,12 @@ Metropolis, Slice, ) -from pymc.step_methods.compound import StatsBijection, flatten_steps +from pymc.step_methods.compound import ( + StatsBijection, + flatten_steps, + get_stats_dtypes_shapes_from_steps, + infer_warn_stats_info, +) from pymc.tests.helpers import StepMethodTester, fast_unstable_sampling_mode from pymc.tests.models import simple_2model_continuous @@ -94,6 +99,52 @@ def test_compound_step(self): assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(step.vars) +class TestStatsMetadata: + def test_infer_warn_stats_info(self): + """ + Until `BlockedStep.stats_dtypes` is removed, the new `stats_dtypes_shapes` + attributed is inferred from `stats_dtypes`, or vice versa. + """ + # Infer new + with pytest.warns(DeprecationWarning, match="to specify"): + old, new = infer_warn_stats_info([{"a": int, "b": object}], {}, "bla") + assert isinstance(old, list) + assert len(old) == 1 + assert old[0] == {"a": int, "b": object} + assert isinstance(new, dict) + assert new["a"] == (int, None) + assert new["b"] == (object, None) + + # Infer old + old, new = infer_warn_stats_info([], {"a": (int, []), "b": (float, [2])}, "bla") + assert isinstance(old, list) + assert len(old) == 1 + assert old[0] == {"a": int, "b": float} + assert isinstance(new, dict) + assert new["a"] == (int, []) + assert new["b"] == (float, [2]) + + # Disallow specifying both (single source of truth problem) + with pytest.raises(TypeError, match="Only one of"): + infer_warn_stats_info([{"a": float}], {"b": (int, [])}, "bla") + + def test_stats_from_steps(self): + with pm.Model(): + s1 = pm.NUTS(pm.Normal("n")) + s2 = pm.Metropolis(pm.Bernoulli("b", 0.5)) + cs = pm.CompoundStep([s1, s2]) + # Make sure that sampler initialization does not modify the + # class-level default values of the attributes. + assert pm.NUTS.stats_dtypes == [] + assert pm.Metropolis.stats_dtypes == [] + + sds = get_stats_dtypes_shapes_from_steps([s1, s2]) + assert "sampler_0__step_size" in sds + assert "sampler_1__accepted" in sds + assert len(cs.stats_dtypes) == 2 + assert cs.stats_dtypes_shapes == sds + + class TestStatsBijection: def test_flatten_steps(self): with pm.Model(): @@ -116,6 +167,7 @@ def test_stats_bijection(self): {"a": float, "c": int}, ] bij = StatsBijection(step_stats_dtypes) + assert bij.n_samplers == 2 stats_l = [ dict(a=1.5, b=3), dict(a=2.5, c=4),