Skip to content

Preregister shapes of sampler stats #6517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion pymc/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,18 @@
from __future__ import annotations

from functools import partial
from typing import Any, Callable, Dict, Generic, List, NamedTuple, TypeVar
from typing import (
Any,
Callable,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)

import numpy as np

Expand All @@ -33,6 +44,8 @@
PointType: TypeAlias = Dict[str, np.ndarray]
StatsDict: TypeAlias = Dict[str, Any]
StatsType: TypeAlias = List[StatsDict]
StatDtype: TypeAlias = Union[type, np.dtype]
StatShape: TypeAlias = Optional[Sequence[Optional[int]]]


# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
Expand Down
88 changes: 86 additions & 2 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
@author: johnsalvatier
"""

import warnings

from abc import ABC, abstractmethod
from enum import IntEnum, unique
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Mapping, Sequence, Tuple, Union

import numpy as np

from pytensor.graph.basic import Variable

from pymc.blocking import PointType, StatsDict, StatsType
from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
from pymc.model import modelcontext

__all__ = ("Competence", "CompoundStep")
Expand All @@ -48,9 +50,61 @@ class Competence(IntEnum):
IDEAL = 3


def infer_warn_stats_info(
    stats_dtypes: List[Dict[str, StatDtype]],
    sds: Dict[str, Tuple[StatDtype, StatShape]],
    stepname: str,
) -> Tuple[List[Dict[str, StatDtype]], Dict[str, Tuple[StatDtype, StatShape]]]:
    """Helper function to get `stats_dtypes` and `stats_dtypes_shapes` from either of them."""
    # Work on copies so the caller's containers are never mutated.
    dtypes_list = [entry.copy() for entry in stats_dtypes]
    dtypes_shapes = sds.copy()

    # Giving both attributes at once is ambiguous, so it is forbidden.
    if dtypes_list and dtypes_shapes:
        raise TypeError(
            "Only one of `stats_dtypes_shapes` or `stats_dtypes` must be specified."
            f" `{stepname}.stats_dtypes` should be removed."
        )

    if dtypes_shapes:
        # New-style attribute given: derive the legacy list from it.
        dtypes_list.append({sname: dtype for sname, (dtype, _) in dtypes_shapes.items()})
    elif dtypes_list:
        # Legacy attribute given: warn about the deprecation and derive the
        # new-style dict (shapes are unknown in this case, hence `None`).
        warnings.warn(
            f"`{stepname}.stats_dtypes` is deprecated."
            " Please update it to specify `stats_dtypes_shapes` instead.",
            DeprecationWarning,
        )
        if len(dtypes_list) > 1:
            raise TypeError(
                f"`{stepname}.stats_dtypes` must be a list containing at most one dict."
            )
        for entry in dtypes_list:
            for sname, dtype in entry.items():
                dtypes_shapes[sname] = (dtype, None)
    return dtypes_list, dtypes_shapes


class BlockedStep(ABC):
stats_dtypes: List[Dict[str, type]] = []
"""A list containing <=1 dictionary that maps stat names to dtypes.

This attribute is deprecated.
Use `stats_dtypes_shapes` instead.
"""

stats_dtypes_shapes: Dict[str, Tuple[StatDtype, StatShape]] = {}
"""Maps stat names to dtypes and shapes.

Shapes are interpreted in the following ways:
- `[]` is a scalar.
- `[3,]` is a length-3 vector.
- `[4, None]` is a matrix with 4 rows and a dynamic number of columns.
- `None` is a sparse stat (i.e. not always present) or a NumPy array with varying `ndim`.
"""

vars: List[Variable] = []
"""Variables that the step method is assigned to."""

def __new__(cls, *args, **kwargs):
blocked = kwargs.get("blocked")
Expand All @@ -77,12 +131,21 @@ def __new__(cls, *args, **kwargs):
if len(vars) == 0:
raise ValueError("No free random variables to sample.")

# Auto-fill stats metadata attributes from whichever was given.
stats_dtypes, stats_dtypes_shapes = infer_warn_stats_info(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assigning stats_dtypes and stats_dtypes_shapes here in the constructor means that they could fall out of sync later, if only one were to get modified. Is this a case we should consider? I could imagine this happening if a step method wanted to determine stat shape at initialization, for example perhaps for a stat shape that varies with the number of variables passed to the step method.

I'm not sure if that is a compelling case or not. But if it is — perhaps these two attributes would be better exposed via @property with getters and setters to ensure they stay in sync?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we should get the flexibility to have samplers initialize stats on instantiation instead of specifying them as class attributes.

I also considered properties, but this wouldn't have worked nicely with the assignment of these fields in the class definition.

I don't think that sync is a problem because we can remove the old attribute.

I will think about how a refactor to definition at initialization will look like.

cls.stats_dtypes,
cls.stats_dtypes_shapes,
cls.__name__,
)

if not blocked and len(vars) > 1:
# In this case we create a separate sampler for each var
# and append them to a CompoundStep
steps = []
for var in vars:
step = super().__new__(cls)
step.stats_dtypes = stats_dtypes
step.stats_dtypes_shapes = stats_dtypes_shapes
# If we don't return the instance we have to manually
# call __init__
step.__init__([var], *args, **kwargs)
Expand All @@ -93,6 +156,8 @@ def __new__(cls, *args, **kwargs):
return CompoundStep(steps)
else:
step = super().__new__(cls)
step.stats_dtypes = stats_dtypes
step.stats_dtypes_shapes = stats_dtypes_shapes
# Hack for creating the class correctly when unpickling.
step.__newargs = (vars,) + args, kwargs
return step
Expand Down Expand Up @@ -126,6 +191,20 @@ def stop_tuning(self):
self.tune = False


def get_stats_dtypes_shapes_from_steps(
    steps: Iterable[BlockedStep],
) -> Dict[str, Tuple[StatDtype, StatShape]]:
    """Combines stats dtype shape dictionaries from multiple step methods.

    In the resulting stats dict, each sampler stat is prefixed by `sampler_#__`.
    """
    # Prefixing with the sampler index keeps equally-named stats from
    # different step methods distinct in the combined dict.
    return {
        f"sampler_{i}__{sname}": dtype_and_shape
        for i, step in enumerate(steps)
        for sname, dtype_and_shape in step.stats_dtypes_shapes.items()
    }


class CompoundStep:
"""Step method composed of a list of several other step
methods applied in sequence."""
Expand All @@ -135,6 +214,7 @@ def __init__(self, methods):
self.stats_dtypes = []
for method in self.methods:
self.stats_dtypes.extend(method.stats_dtypes)
self.stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(methods)
self.name = (
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
)
Expand Down Expand Up @@ -187,6 +267,10 @@ def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
for s, names_dtypes in enumerate(sampler_stats_dtypes)
]

@property
def n_samplers(self) -> int:
    """Number of samplers whose stats are being combined (one per stat group)."""
    return len(self._stat_groups)

def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
"""Combine stats dicts of multiple samplers into one dict."""
stats_dict = {}
Expand Down
40 changes: 19 additions & 21 deletions pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,25 @@ class HamiltonianMC(BaseHMC):

name = "hmc"
default_blocked = True
stats_dtypes = [
{
"step_size": np.float64,
"n_steps": np.int64,
"tune": bool,
"step_size_bar": np.float64,
"accept": np.float64,
"diverging": bool,
"energy_error": np.float64,
"energy": np.float64,
"path_length": np.float64,
"accepted": bool,
"model_logp": np.float64,
"process_time_diff": np.float64,
"perf_counter_diff": np.float64,
"perf_counter_start": np.float64,
"largest_eigval": np.float64,
"smallest_eigval": np.float64,
"warning": SamplerWarning,
}
]
stats_dtypes_shapes = {
"step_size": (np.float64, []),
"n_steps": (np.int64, []),
"tune": (bool, []),
"step_size_bar": (np.float64, []),
"accept": (np.float64, []),
"diverging": (bool, []),
"energy_error": (np.float64, []),
"energy": (np.float64, []),
"path_length": (np.float64, []),
"accepted": (bool, []),
"model_logp": (np.float64, []),
"process_time_diff": (np.float64, []),
"perf_counter_diff": (np.float64, []),
"perf_counter_start": (np.float64, []),
"largest_eigval": (np.float64, []),
"smallest_eigval": (np.float64, []),
"warning": (SamplerWarning, None),
}

def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
"""
Expand Down
44 changes: 21 additions & 23 deletions pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,29 +97,27 @@ class NUTS(BaseHMC):
name = "nuts"

default_blocked = True
stats_dtypes = [
{
"depth": np.int64,
"step_size": np.float64,
"tune": bool,
"mean_tree_accept": np.float64,
"step_size_bar": np.float64,
"tree_size": np.float64,
"diverging": bool,
"energy_error": np.float64,
"energy": np.float64,
"max_energy_error": np.float64,
"model_logp": np.float64,
"process_time_diff": np.float64,
"perf_counter_diff": np.float64,
"perf_counter_start": np.float64,
"largest_eigval": np.float64,
"smallest_eigval": np.float64,
"index_in_trajectory": np.int64,
"reached_max_treedepth": bool,
"warning": SamplerWarning,
}
]
stats_dtypes_shapes = {
"depth": (np.int64, []),
"step_size": (np.float64, []),
"tune": (bool, []),
"mean_tree_accept": (np.float64, []),
"step_size_bar": (np.float64, []),
"tree_size": (np.float64, []),
"diverging": (bool, []),
"energy_error": (np.float64, []),
"energy": (np.float64, []),
"max_energy_error": (np.float64, []),
"model_logp": (np.float64, []),
"process_time_diff": (np.float64, []),
"perf_counter_diff": (np.float64, []),
"perf_counter_start": (np.float64, []),
"largest_eigval": (np.float64, []),
"smallest_eigval": (np.float64, []),
"index_in_trajectory": (np.int64, []),
"reached_max_treedepth": (bool, []),
"warning": (SamplerWarning, None),
}

def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs):
r"""Set up the No-U-Turn sampler.
Expand Down
58 changes: 25 additions & 33 deletions pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,12 @@ class Metropolis(ArrayStepShared):
name = "metropolis"

default_blocked = False
stats_dtypes = [
{
"accept": np.float64,
"accepted": np.float64,
"tune": bool,
"scaling": np.float64,
}
]
stats_dtypes_shapes = {
"accept": (np.float64, []),
"accepted": (np.float64, []),
"tune": (bool, []),
"scaling": (np.float64, []),
}

def __init__(
self,
Expand Down Expand Up @@ -363,13 +361,11 @@ class BinaryMetropolis(ArrayStep):

name = "binary_metropolis"

stats_dtypes = [
{
"accept": np.float64,
"tune": bool,
"p_jump": np.float64,
}
]
stats_dtypes_shapes = {
"accept": (np.float64, []),
"tune": (bool, []),
"p_jump": (np.float64, []),
}

def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
model = pm.modelcontext(model)
Expand Down Expand Up @@ -726,15 +722,13 @@ class DEMetropolis(PopulationArrayStepShared):
name = "DEMetropolis"

default_blocked = True
stats_dtypes = [
{
"accept": np.float64,
"accepted": bool,
"tune": bool,
"scaling": np.float64,
"lambda": np.float64,
}
]
stats_dtypes_shapes = {
"accept": (np.float64, []),
"accepted": (bool, []),
"tune": (bool, []),
"scaling": (np.float64, []),
"lambda": (np.float64, []),
}

def __init__(
self,
Expand Down Expand Up @@ -871,15 +865,13 @@ class DEMetropolisZ(ArrayStepShared):
name = "DEMetropolisZ"

default_blocked = True
stats_dtypes = [
{
"accept": np.float64,
"accepted": bool,
"tune": bool,
"scaling": np.float64,
"lambda": np.float64,
}
]
stats_dtypes_shapes = {
"accept": (np.float64, []),
"accepted": (bool, []),
"tune": (bool, []),
"scaling": (np.float64, []),
"lambda": (np.float64, []),
}

def __init__(
self,
Expand Down
10 changes: 4 additions & 6 deletions pymc/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,10 @@ class Slice(ArrayStep):

name = "slice"
default_blocked = False
stats_dtypes = [
{
"nstep_out": int,
"nstep_in": int,
}
]
stats_dtypes_shapes = {
"nstep_out": (int, []),
"nstep_in": (int, []),
}

def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs):
self.model = modelcontext(model)
Expand Down
10 changes: 4 additions & 6 deletions pymc/tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,12 +633,10 @@ def test_step_args():
class ApocalypticMetropolis(pm.Metropolis):
"""A stepper that warns in every iteration."""

stats_dtypes = [
{
**pm.Metropolis.stats_dtypes[0],
"warning": SamplerWarning,
}
]
stats_dtypes_shapes = {
**pm.Metropolis.stats_dtypes_shapes,
"warning": (SamplerWarning, None),
}

def astep(self, q0):
draw, stats = super().astep(q0)
Expand Down
Loading