Skip to content

Commit 04f2612

Browse files
Restore plots and stats wrappers
See pymc-devs#4528
1 parent 03448f7 commit 04f2612

File tree

5 files changed

+196
-9
lines changed

5 files changed

+196
-9
lines changed

Diff for: docs/source/api/plots.rst

+12-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,16 @@ Plots are delegated to the
88
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_.
99
library, a general purpose library for
1010
"exploratory analysis of Bayesian models."
11-
Refer to its documentation to use the plotting functions directly.
11+
For plots, ``pymc3.<function>`` are now aliases
12+
for ArviZ functions. Thus, the links below will redirect you to
13+
ArviZ docs:
1214

13-
.. automodule:: pymc3.plots.posteriorplot
14-
:members:
15+
- :func:`pymc3.traceplot <arviz:arviz.plot_trace>`
16+
- :func:`pymc3.plot_posterior <arviz:arviz.plot_posterior>`
17+
- :func:`pymc3.forestplot <arviz:arviz.plot_forest>`
18+
- :func:`pymc3.compareplot <arviz:arviz.plot_compare>`
19+
- :func:`pymc3.autocorrplot <arviz:arviz.plot_autocorr>`
20+
- :func:`pymc3.energyplot <arviz:arviz.plot_energy>`
21+
- :func:`pymc3.kdeplot <arviz:arviz.plot_kde>`
22+
- :func:`pymc3.densityplot <arviz:arviz.plot_density>`
23+
- :func:`pymc3.pairplot <arviz:arviz.plot_pair>`

Diff for: docs/source/api/stats.rst

+18-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,21 @@ Statistics and diagnostics are delegated to the
55
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_.
66
library, a general purpose library for
77
"exploratory analysis of Bayesian models."
8-
Refer to its documentation to use the diagnostics functions directly.
8+
For statistics and diagnostics, ``pymc3.<function>`` are now aliases
9+
for ArviZ functions. Thus, the links below will redirect you to
10+
ArviZ docs:
11+
12+
.. currentmodule:: pymc3.stats
13+
14+
15+
- :func:`pymc3.bfmi <arviz:arviz.bfmi>`
16+
- :func:`pymc3.compare <arviz:arviz.compare>`
17+
- :func:`pymc3.ess <arviz:arviz.ess>`
18+
- :data:`pymc3.geweke <arviz:arviz.geweke>`
19+
- :func:`pymc3.hpd <arviz:arviz.hpd>`
20+
- :func:`pymc3.loo <arviz:arviz.loo>`
21+
- :func:`pymc3.mcse <arviz:arviz.mcse>`
22+
- :func:`pymc3.r2_score <arviz:arviz.r2_score>`
23+
- :func:`pymc3.rhat <arviz:arviz.rhat>`
24+
- :func:`pymc3.summary <arviz:arviz.summary>`
25+
- :func:`pymc3.waic <arviz:arviz.waic>`

Diff for: pymc3/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __set_compiler_flags():
6161
from pymc3.plots import *
6262
from pymc3.sampling import *
6363
from pymc3.smc import *
64+
from pymc3.stats import *
6465
from pymc3.step_methods import *
6566
from pymc3.tests import test
6667
from pymc3.theanof import *

Diff for: pymc3/plots/__init__.py

+96-5
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,108 @@
1414

1515
"""PyMC3 Plotting.
1616
17-
Plots are delegated to the `ArviZ <https://arviz-devs.github.io/arviz/>`_ library, a general purpose library for
18-
exploratory analysis of Bayesian models. For more details, see https://arviz-devs.github.io/arviz/.
19-
20-
Only `plot_posterior_predictive_glm` is kept in the PyMC code base for now, but it will move to ArviZ once the latter adds features for regression plots.
17+
Plots are delegated to the ArviZ library, a general purpose library for
18+
"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/
19+
for details on plots.
2120
"""
2221
import functools
2322
import sys
2423
import warnings
2524

2625
import arviz as az
2726

27+
28+
def map_args(func):
29+
swaps = [("varnames", "var_names")]
30+
31+
@functools.wraps(func)
32+
def wrapped(*args, **kwargs):
33+
for (old, new) in swaps:
34+
if old in kwargs and new not in kwargs:
35+
warnings.warn(
36+
f"Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8"
37+
)
38+
kwargs[new] = kwargs.pop(old)
39+
return func(*args, **kwargs)
40+
41+
return wrapped
42+
43+
44+
# pymc3 custom plots: override these names for custom behavior
45+
autocorrplot = map_args(az.plot_autocorr)
46+
forestplot = map_args(az.plot_forest)
47+
kdeplot = map_args(az.plot_kde)
48+
plot_posterior = map_args(az.plot_posterior)
49+
energyplot = map_args(az.plot_energy)
50+
densityplot = map_args(az.plot_density)
51+
pairplot = map_args(az.plot_pair)
52+
53+
# Use compact traceplot by default
54+
@map_args
55+
@functools.wraps(az.plot_trace)
56+
def traceplot(*args, **kwargs):
57+
try:
58+
kwargs.setdefault("compact", True)
59+
return az.plot_trace(*args, **kwargs)
60+
except TypeError:
61+
kwargs.pop("compact")
62+
return az.plot_trace(*args, **kwargs)
63+
64+
65+
# addition arg mapping for compare plot
66+
@functools.wraps(az.plot_compare)
67+
def compareplot(*args, **kwargs):
68+
if "comp_df" in kwargs:
69+
comp_df = kwargs["comp_df"].copy()
70+
else:
71+
args = list(args)
72+
comp_df = args[0].copy()
73+
if "WAIC" in comp_df.columns:
74+
comp_df = comp_df.rename(
75+
index=str,
76+
columns={
77+
"WAIC": "waic",
78+
"pWAIC": "p_waic",
79+
"dWAIC": "d_waic",
80+
"SE": "se",
81+
"dSE": "dse",
82+
"var_warn": "warning",
83+
},
84+
)
85+
elif "LOO" in comp_df.columns:
86+
comp_df = comp_df.rename(
87+
index=str,
88+
columns={
89+
"LOO": "loo",
90+
"pLOO": "p_loo",
91+
"dLOO": "d_loo",
92+
"SE": "se",
93+
"dSE": "dse",
94+
"shape_warn": "warning",
95+
},
96+
)
97+
if "comp_df" in kwargs:
98+
kwargs["comp_df"] = comp_df
99+
else:
100+
args[0] = comp_df
101+
return az.plot_compare(*args, **kwargs)
102+
103+
28104
from pymc3.plots.posteriorplot import plot_posterior_predictive_glm
29105

30-
__all__ = ["plot_posterior_predictive_glm"]
106+
# Access to arviz plots: base plots provided by arviz
107+
for plot in az.plots.__all__:
108+
setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot)))
109+
110+
__all__ = tuple(az.plots.__all__) + (
111+
"autocorrplot",
112+
"compareplot",
113+
"forestplot",
114+
"kdeplot",
115+
"plot_posterior",
116+
"traceplot",
117+
"energyplot",
118+
"densityplot",
119+
"pairplot",
120+
"plot_posterior_predictive_glm",
121+
)

Diff for: pymc3/stats/__init__.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Statistical utility functions for PyMC3
16+
17+
Diagnostics and auxiliary statistical functions are delegated to the ArviZ library, a general
18+
purpose library for "exploratory analysis of Bayesian models." See
19+
https://arviz-devs.github.io/arviz/ for details.
20+
"""
21+
import functools
22+
import warnings
23+
24+
import arviz as az
25+
26+
27+
def map_args(func):
28+
swaps = [("varnames", "var_names")]
29+
30+
@functools.wraps(func)
31+
def wrapped(*args, **kwargs):
32+
for (old, new) in swaps:
33+
if old in kwargs and new not in kwargs:
34+
warnings.warn(
35+
"Keyword argument `{old}` renamed to `{new}`, and will be removed in "
36+
"pymc3 3.9".format(old=old, new=new)
37+
)
38+
kwargs[new] = kwargs.pop(old)
39+
return func(*args, **kwargs)
40+
41+
return wrapped
42+
43+
44+
bfmi = map_args(az.bfmi)
45+
compare = map_args(az.compare)
46+
ess = map_args(az.ess)
47+
geweke = map_args(az.geweke)
48+
hpd = map_args(az.hpd)
49+
loo = map_args(az.loo)
50+
mcse = map_args(az.mcse)
51+
r2_score = map_args(az.r2_score)
52+
rhat = map_args(az.rhat)
53+
summary = map_args(az.summary)
54+
waic = map_args(az.waic)
55+
56+
57+
__all__ = [
58+
"bfmi",
59+
"compare",
60+
"ess",
61+
"geweke",
62+
"hpd",
63+
"loo",
64+
"mcse",
65+
"r2_score",
66+
"rhat",
67+
"summary",
68+
"waic",
69+
]

0 commit comments

Comments
 (0)