-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Preregister shapes of sampler stats #6517
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,15 +18,17 @@ | |
@author: johnsalvatier | ||
""" | ||
|
||
import warnings | ||
|
||
from abc import ABC, abstractmethod | ||
from enum import IntEnum, unique | ||
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union | ||
from typing import Any, Dict, Iterable, List, Mapping, Sequence, Tuple, Union | ||
|
||
import numpy as np | ||
|
||
from pytensor.graph.basic import Variable | ||
|
||
from pymc.blocking import PointType, StatsDict, StatsType | ||
from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType | ||
from pymc.model import modelcontext | ||
|
||
__all__ = ("Competence", "CompoundStep") | ||
|
@@ -48,9 +50,61 @@ class Competence(IntEnum): | |
IDEAL = 3 | ||
|
||
|
||
def infer_warn_stats_info( | ||
stats_dtypes: List[Dict[str, StatDtype]], | ||
sds: Dict[str, Tuple[StatDtype, StatShape]], | ||
stepname: str, | ||
) -> Tuple[List[Dict[str, StatDtype]], Dict[str, Tuple[StatDtype, StatShape]]]: | ||
"""Helper function to get `stats_dtypes` and `stats_dtypes_shapes` from either of them.""" | ||
# Avoid side-effects on the original lists/dicts | ||
stats_dtypes = [d.copy() for d in stats_dtypes] | ||
sds = sds.copy() | ||
# Disallow specification of both attributes | ||
if stats_dtypes and sds: | ||
raise TypeError( | ||
"Only one of `stats_dtypes_shapes` or `stats_dtypes` must be specified." | ||
f" `{stepname}.stats_dtypes` should be removed." | ||
) | ||
|
||
# Infer one from the other | ||
if not sds and stats_dtypes: | ||
warnings.warn( | ||
f"`{stepname}.stats_dtypes` is deprecated." | ||
" Please update it to specify `stats_dtypes_shapes` instead.", | ||
DeprecationWarning, | ||
) | ||
if len(stats_dtypes) > 1: | ||
raise TypeError( | ||
f"`{stepname}.stats_dtypes` must be a list containing at most one dict." | ||
) | ||
for sd in stats_dtypes: | ||
for sname, dtype in sd.items(): | ||
sds[sname] = (dtype, None) | ||
elif sds: | ||
stats_dtypes.append({sname: dtype for sname, (dtype, _) in sds.items()}) | ||
return stats_dtypes, sds | ||
|
||
|
||
class BlockedStep(ABC): | ||
stats_dtypes: List[Dict[str, type]] = [] | ||
"""A list containing <=1 dictionary that maps stat names to dtypes. | ||
|
||
This attribute is deprecated. | ||
Use `stats_dtypes_shapes` instead. | ||
""" | ||
|
||
stats_dtypes_shapes: Dict[str, Tuple[StatDtype, StatShape]] = {} | ||
"""Maps stat names to dtypes and shapes. | ||
|
||
Shapes are interpreted in the following ways: | ||
- `[]` is a scalar. | ||
- `[3,]` is a length-3 vector. | ||
- `[4, None]` is a matrix with 4 rows and a dynamic number of columns. | ||
- `None` is a sparse stat (i.e. not always present) or a NumPy array with varying `ndim`. | ||
""" | ||
|
||
vars: List[Variable] = [] | ||
"""Variables that the step method is assigned to.""" | ||
|
||
def __new__(cls, *args, **kwargs): | ||
blocked = kwargs.get("blocked") | ||
|
@@ -77,12 +131,21 @@ def __new__(cls, *args, **kwargs): | |
if len(vars) == 0: | ||
raise ValueError("No free random variables to sample.") | ||
|
||
# Auto-fill stats metadata attributes from whichever was given. | ||
stats_dtypes, stats_dtypes_shapes = infer_warn_stats_info( | ||
cls.stats_dtypes, | ||
cls.stats_dtypes_shapes, | ||
cls.__name__, | ||
) | ||
|
||
if not blocked and len(vars) > 1: | ||
# In this case we create a separate sampler for each var | ||
# and append them to a CompoundStep | ||
steps = [] | ||
for var in vars: | ||
step = super().__new__(cls) | ||
step.stats_dtypes = stats_dtypes | ||
step.stats_dtypes_shapes = stats_dtypes_shapes | ||
# If we don't return the instance we have to manually | ||
# call __init__ | ||
step.__init__([var], *args, **kwargs) | ||
|
@@ -93,6 +156,8 @@ def __new__(cls, *args, **kwargs): | |
return CompoundStep(steps) | ||
else: | ||
step = super().__new__(cls) | ||
step.stats_dtypes = stats_dtypes | ||
step.stats_dtypes_shapes = stats_dtypes_shapes | ||
# Hack for creating the class correctly when unpickling. | ||
step.__newargs = (vars,) + args, kwargs | ||
return step | ||
|
@@ -126,6 +191,20 @@ def stop_tuning(self): | |
self.tune = False | ||
|
||
|
||
def get_stats_dtypes_shapes_from_steps( | ||
steps: Iterable[BlockedStep], | ||
) -> Dict[str, Tuple[StatDtype, StatShape]]: | ||
"""Combines stats dtype shape dictionaries from multiple step methods. | ||
|
||
In the resulting stats dict, each sampler stat is prefixed by `sampler_#__`. | ||
""" | ||
result = {} | ||
for s, step in enumerate(steps): | ||
for sname, (dtype, shape) in step.stats_dtypes_shapes.items(): | ||
result[f"sampler_{s}__{sname}"] = (dtype, shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently these dictionary keys are not passed to arviz when creating InferenceData objects. But if/when they are, we'll probably want a way for the user to map back from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed. ArviZ should add such a field, but even before that gets done we can add this to |
||
return result | ||
|
||
|
||
class CompoundStep: | ||
"""Step method composed of a list of several other step | ||
methods applied in sequence.""" | ||
|
@@ -135,6 +214,7 @@ def __init__(self, methods): | |
self.stats_dtypes = [] | ||
for method in self.methods: | ||
self.stats_dtypes.extend(method.stats_dtypes) | ||
self.stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(methods) | ||
self.name = ( | ||
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]" | ||
) | ||
|
@@ -187,6 +267,10 @@ def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None: | |
for s, names_dtypes in enumerate(sampler_stats_dtypes) | ||
] | ||
|
||
@property | ||
def n_samplers(self) -> int: | ||
return len(self._stat_groups) | ||
|
||
def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict: | ||
"""Combine stats dicts of multiple samplers into one dict.""" | ||
stats_dict = {} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assigning
stats_dtypes
andstats_dtypes_shapes
here in the constructor means that they could fall out of sync later, if only one were to get modified. Is this a case we should consider? I could imagine this happening if a step method wanted to determine stat shape at initialization, for example perhaps for a stat shape that varies with the number of variables passed to the step method.I'm not sure if that is a compelling case or not. But if it is — perhap these two attributes would be better exposed via
@property
with getters and setters to ensure they stay in sync?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that we should get the flexibility to have samplers initialize stats on instantiation instead of specifying them as class attributes.
I also considered properties, but this wouldn't have worked nicely with the assignment of these fields in the class definition.
I don't think that sync is a problem because we can remove the old attribute.
I will think about how a refactor to definition at initialization will look like.