Skip to content

Commit 7e9e17d

Browse files
Specify stats_dtypes_shapes for all samplers
1 parent 99f17c7 commit 7e9e17d

File tree

6 files changed

+75
-91
lines changed

6 files changed

+75
-91
lines changed

pymc/step_methods/hmc/hmc.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,25 @@ class HamiltonianMC(BaseHMC):
3939

4040
name = "hmc"
4141
default_blocked = True
42-
stats_dtypes = [
43-
{
44-
"step_size": np.float64,
45-
"n_steps": np.int64,
46-
"tune": bool,
47-
"step_size_bar": np.float64,
48-
"accept": np.float64,
49-
"diverging": bool,
50-
"energy_error": np.float64,
51-
"energy": np.float64,
52-
"path_length": np.float64,
53-
"accepted": bool,
54-
"model_logp": np.float64,
55-
"process_time_diff": np.float64,
56-
"perf_counter_diff": np.float64,
57-
"perf_counter_start": np.float64,
58-
"largest_eigval": np.float64,
59-
"smallest_eigval": np.float64,
60-
"warning": SamplerWarning,
61-
}
62-
]
42+
stats_dtypes_shapes = {
43+
"step_size": (np.float64, []),
44+
"n_steps": (np.int64, []),
45+
"tune": (bool, []),
46+
"step_size_bar": (np.float64, []),
47+
"accept": (np.float64, []),
48+
"diverging": (bool, []),
49+
"energy_error": (np.float64, []),
50+
"energy": (np.float64, []),
51+
"path_length": (np.float64, []),
52+
"accepted": (bool, []),
53+
"model_logp": (np.float64, []),
54+
"process_time_diff": (np.float64, []),
55+
"perf_counter_diff": (np.float64, []),
56+
"perf_counter_start": (np.float64, []),
57+
"largest_eigval": (np.float64, []),
58+
"smallest_eigval": (np.float64, []),
59+
"warning": (SamplerWarning, None),
60+
}
6361

6462
def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
6563
"""

pymc/step_methods/hmc/nuts.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -97,29 +97,27 @@ class NUTS(BaseHMC):
9797
name = "nuts"
9898

9999
default_blocked = True
100-
stats_dtypes = [
101-
{
102-
"depth": np.int64,
103-
"step_size": np.float64,
104-
"tune": bool,
105-
"mean_tree_accept": np.float64,
106-
"step_size_bar": np.float64,
107-
"tree_size": np.float64,
108-
"diverging": bool,
109-
"energy_error": np.float64,
110-
"energy": np.float64,
111-
"max_energy_error": np.float64,
112-
"model_logp": np.float64,
113-
"process_time_diff": np.float64,
114-
"perf_counter_diff": np.float64,
115-
"perf_counter_start": np.float64,
116-
"largest_eigval": np.float64,
117-
"smallest_eigval": np.float64,
118-
"index_in_trajectory": np.int64,
119-
"reached_max_treedepth": bool,
120-
"warning": SamplerWarning,
121-
}
122-
]
100+
stats_dtypes_shapes = {
101+
"depth": (np.int64, []),
102+
"step_size": (np.float64, []),
103+
"tune": (bool, []),
104+
"mean_tree_accept": (np.float64, []),
105+
"step_size_bar": (np.float64, []),
106+
"tree_size": (np.float64, []),
107+
"diverging": (bool, []),
108+
"energy_error": (np.float64, []),
109+
"energy": (np.float64, []),
110+
"max_energy_error": (np.float64, []),
111+
"model_logp": (np.float64, []),
112+
"process_time_diff": (np.float64, []),
113+
"perf_counter_diff": (np.float64, []),
114+
"perf_counter_start": (np.float64, []),
115+
"largest_eigval": (np.float64, []),
116+
"smallest_eigval": (np.float64, []),
117+
"index_in_trajectory": (np.int64, []),
118+
"reached_max_treedepth": (bool, []),
119+
"warning": (SamplerWarning, None),
120+
}
123121

124122
def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs):
125123
r"""Set up the No-U-Turn sampler.

pymc/step_methods/metropolis.py

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,12 @@ class Metropolis(ArrayStepShared):
117117
name = "metropolis"
118118

119119
default_blocked = False
120-
stats_dtypes = [
121-
{
122-
"accept": np.float64,
123-
"accepted": np.float64,
124-
"tune": bool,
125-
"scaling": np.float64,
126-
}
127-
]
120+
stats_dtypes_shapes = {
121+
"accept": (np.float64, []),
122+
"accepted": (np.float64, []),
123+
"tune": (bool, []),
124+
"scaling": (np.float64, []),
125+
}
128126

129127
def __init__(
130128
self,
@@ -363,13 +361,11 @@ class BinaryMetropolis(ArrayStep):
363361

364362
name = "binary_metropolis"
365363

366-
stats_dtypes = [
367-
{
368-
"accept": np.float64,
369-
"tune": bool,
370-
"p_jump": np.float64,
371-
}
372-
]
364+
stats_dtypes_shapes = {
365+
"accept": (np.float64, []),
366+
"tune": (bool, []),
367+
"p_jump": (np.float64, []),
368+
}
373369

374370
def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
375371
model = pm.modelcontext(model)
@@ -726,15 +722,13 @@ class DEMetropolis(PopulationArrayStepShared):
726722
name = "DEMetropolis"
727723

728724
default_blocked = True
729-
stats_dtypes = [
730-
{
731-
"accept": np.float64,
732-
"accepted": bool,
733-
"tune": bool,
734-
"scaling": np.float64,
735-
"lambda": np.float64,
736-
}
737-
]
725+
stats_dtypes_shapes = {
726+
"accept": (np.float64, []),
727+
"accepted": (bool, []),
728+
"tune": (bool, []),
729+
"scaling": (np.float64, []),
730+
"lambda": (np.float64, []),
731+
}
738732

739733
def __init__(
740734
self,
@@ -871,15 +865,13 @@ class DEMetropolisZ(ArrayStepShared):
871865
name = "DEMetropolisZ"
872866

873867
default_blocked = True
874-
stats_dtypes = [
875-
{
876-
"accept": np.float64,
877-
"accepted": bool,
878-
"tune": bool,
879-
"scaling": np.float64,
880-
"lambda": np.float64,
881-
}
882-
]
868+
stats_dtypes_shapes = {
869+
"accept": (np.float64, []),
870+
"accepted": (bool, []),
871+
"tune": (bool, []),
872+
"scaling": (np.float64, []),
873+
"lambda": (np.float64, []),
874+
}
883875

884876
def __init__(
885877
self,

pymc/step_methods/slicer.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,10 @@ class Slice(ArrayStep):
5050

5151
name = "slice"
5252
default_blocked = False
53-
stats_dtypes = [
54-
{
55-
"nstep_out": int,
56-
"nstep_in": int,
57-
}
58-
]
53+
stats_dtypes_shapes = {
54+
"nstep_out": (int, []),
55+
"nstep_in": (int, []),
56+
}
5957

6058
def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs):
6159
self.model = modelcontext(model)

pymc/tests/sampling/test_mcmc.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -633,12 +633,10 @@ def test_step_args():
633633
class ApocalypticMetropolis(pm.Metropolis):
634634
"""A stepper that warns in every iteration."""
635635

636-
stats_dtypes = [
637-
{
638-
**pm.Metropolis.stats_dtypes[0],
639-
"warning": SamplerWarning,
640-
}
641-
]
636+
stats_dtypes_shapes = {
637+
**pm.Metropolis.stats_dtypes_shapes,
638+
"warning": (SamplerWarning, None),
639+
}
642640

643641
def astep(self, q0):
644642
draw, stats = super().astep(q0)

pymc/tests/step_methods/test_compound.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def test_stats_from_steps(self):
135135
cs = pm.CompoundStep([s1, s2])
136136
# Make sure that sampler initialization does not modify the
137137
# class-level default values of the attributes.
138-
assert pm.NUTS.stats_dtypes_shapes == {}
139-
assert pm.Metropolis.stats_dtypes_shapes == {}
138+
assert pm.NUTS.stats_dtypes == []
139+
assert pm.Metropolis.stats_dtypes == []
140140

141141
sds = get_stats_dtypes_shapes_from_steps([s1, s2])
142142
assert "sampler_0__step_size" in sds

0 commit comments

Comments
 (0)