Let plot_posterior_predictive_glm work with inferencedata too #4234

Merged
Changes from 7 commits
Commits
25 commits
cbab50c
adapt for both
MarcoGorelli Nov 19, 2020
2eeb5ae
add multitrace test
MarcoGorelli Nov 19, 2020
22991c9
add inferencedata test, cover all lines
MarcoGorelli Nov 19, 2020
3eb3410
update release notes
MarcoGorelli Nov 19, 2020
5c37ef7
sort imports (it's not checked yet?)
MarcoGorelli Nov 19, 2020
58d0839
fixup test
MarcoGorelli Nov 19, 2020
c145b0a
fixup PR number
MarcoGorelli Nov 19, 2020
e28f2b8
:label: add type annotations, correct docstring
MarcoGorelli Nov 20, 2020
08c11fc
Merge branch 'extend-plot_posterior_predictive_glm' of github.com:mar…
MarcoGorelli Nov 20, 2020
987e170
sort
MarcoGorelli Nov 20, 2020
f93da72
sort
MarcoGorelli Nov 20, 2020
40ac14b
add copyright note
MarcoGorelli Nov 20, 2020
9540abb
use todict
MarcoGorelli Nov 20, 2020
a0fa02b
:art:
MarcoGorelli Nov 20, 2020
258018a
don't cover import error (matplotlib is a dev requirement)
MarcoGorelli Nov 22, 2020
9147b34
don't cover import typechecking
MarcoGorelli Nov 22, 2020
ddcebeb
remove optional mpl import
MarcoGorelli Nov 26, 2020
fa7687f
remove optional mpl import
MarcoGorelli Nov 26, 2020
2ceaa13
remove optional mpl import
MarcoGorelli Nov 26, 2020
9cf211a
Merge remote-tracking branch 'upstream/master' into extend-plot_poste…
MarcoGorelli Nov 26, 2020
63a5a85
Merge branch 'master' into extend-plot_posterior_predictive_glm
MarcoGorelli Dec 16, 2020
0106d77
Update RELEASE-NOTES.md
MarcoGorelli Dec 16, 2020
3fe8210
use Python3.7+ type hints
MarcoGorelli Dec 17, 2020
9e012fb
Merge remote-tracking branch 'upstream/master' into extend-plot_poste…
MarcoGorelli Dec 17, 2020
2a38f67
Merge branch 'extend-plot_posterior_predictive_glm' of github.com:Mar…
MarcoGorelli Dec 17, 2020
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -24,6 +24,7 @@
 - Add Bayesian Additive Regression Trees (BARTs) [#4183](https://github.com/pymc-devs/pymc3/pull/4183))
 - Added a new `MixtureSameFamily` distribution to handle mixtures of arbitrary dimensions in vectorized form (see [#4185](https://github.com/pymc-devs/pymc3/issues/4185)).
 - Added semantically meaningful `str` representations to PyMC3 objects for console, notebook, and GraphViz use (see [#4076](https://github.com/pymc-devs/pymc3/pull/4076), [#4065](https://github.com/pymc-devs/pymc3/pull/4065), [#4159](https://github.com/pymc-devs/pymc3/pull/4159), and [#4217](https://github.com/pymc-devs/pymc3/pull/4217))
+- `plot_posterior_predictive_glm` now works with `arviz.InferenceData` as well (see [#4234](https://github.com/pymc-devs/pymc3/pull/4234))

17 changes: 16 additions & 1 deletion pymc3/plots/posteriorplot.py
@@ -18,6 +18,8 @@
     pass
 import numpy as np
 
+from pymc3.backends.base import MultiTrace
+
 
 def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwargs):
     """Plot posterior predictive of a linear model.
@@ -47,10 +49,23 @@ def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwarg
     if "c" not in kwargs and "color" not in kwargs:
         kwargs["c"] = "k"
 
+    plotting_fn = _plot_multitrace if isinstance(trace, MultiTrace) else _plot_inferencedata
+    plotting_fn(trace, eval, lm, samples, kwargs)
+    plt.title("Posterior predictive")
+
+
+def _plot_multitrace(trace, eval, lm, samples, kwargs):
     for rand_loc in np.random.randint(0, len(trace), samples):
         rand_sample = trace[rand_loc]
         plt.plot(eval, lm(eval, rand_sample), **kwargs)
         # Make sure to not plot label multiple times
         kwargs.pop("label", None)
 
-    plt.title("Posterior predictive")
+
+def _plot_inferencedata(trace, eval, lm, samples, kwargs):
+    trace_df = trace.posterior.to_dataframe()
+    for rand_loc in np.random.randint(0, len(trace_df), samples):
+        rand_sample = trace_df.iloc[rand_loc]
+        plt.plot(eval, lm(eval, rand_sample), **kwargs)
+        # Make sure to not plot label multiple times
+        kwargs.pop("label", None)
Contributor

These two functions have a lot of duplicated lines; I think they can be merged into one by checking isinstance(trace, MultiTrace) at the beginning of the function (or just before) and converting the InferenceData with to_array (I think this is the name of the function, but you can check on the ArviZ website) instead of to a dataframe.
After that, the handling should be the same, as you'd be dealing with numpy arrays in both cases.

Contributor Author

@MarcoGorelli MarcoGorelli Nov 20, 2020

If we cast to an array, then I think we wouldn't be able to access the individual parameters (e.g. 'Intercept' or 'x'), which lm refers to by name. In the multitrace case, each element here is a dict:

> /home/mgorelli/pymc3-dev/pymc3/plots/posteriorplot.py(61)_plot_multitrace()
-> plt.plot(eval, lm(eval, rand_sample), **kwargs)
(Pdb) type(rand_sample)
<class 'dict'>
(Pdb) rand_sample
{'x': 1.0, 'Intercept': 1.0}

at this point, the only lines they have in common are

        plt.plot(eval, lm(eval, rand_sample), **kwargs)
        # Make sure to not plot label multiple times
        kwargs.pop("label", None)
The others are slightly different. My reason for making two separate helper functions is that I thought it'd be more readable than a single function with many if/then statements - I'll go with whatever you think is best though 😇
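
To make the point about named parameters concrete, here is a minimal sketch. It assumes a linear model of the form `sample["Intercept"] + sample["x"] * x`; the `lm` defined below is a hypothetical stand-in for illustration, since the actual default `lm` is not shown in this diff. A draw stored as a dict can be looked up by name, while the same values in a plain numpy array cannot:

```python
import numpy as np


def lm(x, sample):
    # Hypothetical linear model (an assumption, not taken from the diff):
    # parameters are looked up by name in the sample.
    return sample["Intercept"] + sample["x"] * x


eval_points = np.linspace(0, 1, 5)

# One posterior draw as a dict, matching the pdb output above.
rand_sample = {"x": 1.0, "Intercept": 1.0}
y = lm(eval_points, rand_sample)

# The same draw as a plain array: the parameter names are gone,
# so the name-based lookup inside lm fails.
as_array = np.array([1.0, 1.0])
try:
    lm(eval_points, as_array)
    lookup_failed = False
except (IndexError, TypeError):
    lookup_failed = True
```

With both parameters equal to 1, `y` runs linearly from 1 to 2, which is consistent with the `np.linspace(1, 2, 100)` expectation in the defaults test of this PR.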

Contributor

Ah right, I forgot the whole trace was given here, and not only trace["y"] for instance. But then, wouldn't trace.posterior.to_dataframe().to_dict() get the format we want? That way we'd need only one plotting function

Contributor Author

sure - I think this'd be slightly more expensive, but arguably it's worth it for the sake of much simpler code
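
A rough sketch of what the single code path discussed above could look like, under stated assumptions: a small pandas DataFrame stands in for `trace.posterior.to_dataframe()`, a list of dicts stands in for the multitrace, and `lm`, `to_records`, and `sample_lines` are hypothetical names used only for illustration. Note that a plain `DataFrame.to_dict()` defaults to a column-oriented nested dict, so `orient="records"` is used here to get one dict per row:

```python
import numpy as np
import pandas as pd


def lm(x, sample):
    # Hypothetical linear model that looks parameters up by name.
    return sample["Intercept"] + sample["x"] * x


def to_records(trace):
    # A DataFrame stands in for `trace.posterior.to_dataframe()`;
    # anything else is assumed to already yield one dict per draw.
    if isinstance(trace, pd.DataFrame):
        return trace.to_dict(orient="records")
    return list(trace)


def sample_lines(trace, eval_points, samples, rng):
    # Single code path: pick random draws and evaluate the model on each.
    records = to_records(trace)
    locs = rng.integers(0, len(records), samples)
    return [lm(eval_points, records[loc]) for loc in locs]


posterior_df = pd.DataFrame({"x": [1.0, 2.0], "Intercept": [1.0, 0.5]})
point_list = [{"x": 1.0, "Intercept": 1.0}, {"x": 2.0, "Intercept": 0.5}]
eval_points = np.linspace(0, 1, 5)

records = to_records(posterior_df)
lines_df = sample_lines(posterior_df, eval_points, 3, np.random.default_rng(0))
lines_list = sample_lines(point_list, eval_points, 3, np.random.default_rng(0))
```

The conversion touches every row up front, which is the extra cost mentioned above; in exchange, the sampling loop no longer needs to know which container type it was given.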

54 changes: 54 additions & 0 deletions pymc3/tests/test_plots.py
@@ -0,0 +1,54 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import pytest
+
+from arviz import from_pymc3
+
+import pymc3 as pm
+
+from pymc3.backends.ndarray import point_list_to_multitrace
+from pymc3.plots import plot_posterior_predictive_glm
+
+
+@pytest.mark.parametrize("inferencedata", [True, False])
+def test_plot_posterior_predictive_glm_defaults(inferencedata):
+    with pm.Model() as model:
+        pm.Normal("x")
+        pm.Normal("Intercept")
+    trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model)
+    if inferencedata:
+        trace = from_pymc3(trace, model=model)
+    _, ax = plt.subplots()
+    plot_posterior_predictive_glm(trace, samples=1)
+    lines = ax.get_lines()
+    expected_xvalues = np.linspace(0, 1, 100)
+    expected_yvalues = np.linspace(1, 2, 100)
+    for line in lines:
+        x_axis, y_axis = line.get_data()
+        np.testing.assert_array_equal(x_axis, expected_xvalues)
+        np.testing.assert_array_equal(y_axis, expected_yvalues)
+        assert line.get_lw() == 0.2
+        assert line.get_c() == "k"
+
+
+@pytest.mark.parametrize("inferencedata", [True, False])
+def test_plot_posterior_predictive_glm_non_defaults(inferencedata):
+    with pm.Model() as model:
+        pm.Normal("x")
+        pm.Normal("Intercept")
+    trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model)
+    if inferencedata:
+        trace = from_pymc3(trace, model=model)
+    _, ax = plt.subplots()
+    plot_posterior_predictive_glm(
+        trace, samples=1, lm=lambda x, _: x, eval=np.linspace(0, 1, 10), lw=0.3, c="b"
+    )
+    lines = ax.get_lines()
+    expected_xvalues = np.linspace(0, 1, 10)
+    expected_yvalues = np.linspace(0, 1, 10)
+    for line in lines:
+        x_axis, y_axis = line.get_data()
+        np.testing.assert_array_equal(x_axis, expected_xvalues)
+        np.testing.assert_array_equal(y_axis, expected_yvalues)
+        assert line.get_lw() == 0.3
+        assert line.get_c() == "b"