-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Preregister shapes of sampler stats #6517
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Preregister shapes of sampler stats #6517
Conversation
168e19c
to
6d2bf43
Compare
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6517 +/- ##
===========================================
- Coverage 94.73% 23.69% -71.05%
===========================================
Files 147 147
Lines 27864 27913 +49
===========================================
- Hits 26398 6613 -19785
- Misses 1466 21300 +19834
|
@covertg want to take a shot a reviewing here? |
Sure, I'll do my best! First time for pymc so obviously please take it with some grains of salt, but hope this is helpful. I don't have the experience with NUTS to confirm the shapes and dtypes so I won't comment there. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Main comment is regarding whether we really want to fix stats_dtypes
and stats_dtypes_shapes
in the constructor __new__
, or whether it would be better to allow the step method to modify them later.
@@ -77,12 +131,21 @@ def __new__(cls, *args, **kwargs): | |||
if len(vars) == 0: | |||
raise ValueError("No free random variables to sample.") | |||
|
|||
# Auto-fill stats metadata attributes from whichever was given. | |||
stats_dtypes, stats_dtypes_shapes = infer_warn_stats_info( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assigning stats_dtypes
and stats_dtypes_shapes
here in the constructor means that they could fall out of sync later, if only one were to get modified. Is this a case we should consider? I could imagine this happening if a step method wanted to determine stat shape at initialization, for example perhaps for a stat shape that varies with the number of variables passed to the step method.
I'm not sure if that is a compelling case or not. But if it is — perhap these two attributes would be better exposed via @property
with getters and setters to ensure they stay in sync?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that we should get the flexibility to have samplers initialize stats on instantiation instead of specifying them as class attributes.
I also considered properties, but this wouldn't have worked nicely with the assignment of these fields in the class definition.
I don't think that sync is a problem because we can remove the old attribute.
I will think about how a refactor to definition at initialization will look like.
result = {} | ||
for s, step in enumerate(steps): | ||
for sname, (dtype, shape) in step.stats_dtypes_shapes.items(): | ||
result[f"sampler_{s}__{sname}"] = (dtype, shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently these dictionary keys are not passed to arviz when creating InferenceData objects. But if/when they are, we'll probably want a way for the user to map back from <step name>
to <list of variables that were sampled by that stepper>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed. ArviZ should add such a field, but even before that gets done we can add this to mcbackend.RunMeta
6d2bf43
to
cb90744
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, except for a nitpick
pymc/step_methods/compound.py
Outdated
Shapes are interpreted in the following ways: | ||
- `[]` is a scalar. | ||
- `[3,]` is a length-3 vector. | ||
- `[4, -1]` is a matrix with 4 rows and a dynamic number of columns. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would vote to use None
at the PyMC level as that maps directly to the convention used in PyTensor. When we use these to create a McBackend trace then we can map None
to -1
.
Other backends can use None
to specify unknown shape if they want.
- `[4, -1]` is a matrix with 4 rows and a dynamic number of columns. | |
- `[4, None]` is a matrix with 4 rows and a dynamic number of columns. |
cb90744
to
b14897e
Compare
As described in #6503 this adds a
stats_dtypes_shape
attribute toBlockedStep
to replaceBlockedStep.stats_dtypes
.It is implemented in a backwards-compatible manner with a deprecation warning for samplers that are not yet updated.
I also specified all stat shapes I was confident about. If someone with NUTS stat experiments could comment the remaining ones that'd be great!
Related issues
BlockedStep.stats_dtypes_shapes
to step signatures #6503Checklist
New features
stats_dtypes_shapes
class attribute. Thestats_dtypes
attribute is being deprecated.