Skip to content

Let plot_posterior_predictive_glm work with inferencedata too #4234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
cbab50c
adapt for both
MarcoGorelli Nov 19, 2020
2eeb5ae
add multitrace test
MarcoGorelli Nov 19, 2020
22991c9
add inferencedata test, cover all lines
MarcoGorelli Nov 19, 2020
3eb3410
update release notes
MarcoGorelli Nov 19, 2020
5c37ef7
sort imports (it's not checked yet?)
MarcoGorelli Nov 19, 2020
58d0839
fixup test
MarcoGorelli Nov 19, 2020
c145b0a
fixup PR number
MarcoGorelli Nov 19, 2020
e28f2b8
:label: add type annotations, correct docstring
MarcoGorelli Nov 20, 2020
08c11fc
Merge branch 'extend-plot_posterior_predictive_glm' of github.com:mar…
MarcoGorelli Nov 20, 2020
987e170
sort
MarcoGorelli Nov 20, 2020
f93da72
sort
MarcoGorelli Nov 20, 2020
40ac14b
add copyright note
MarcoGorelli Nov 20, 2020
9540abb
use todict
MarcoGorelli Nov 20, 2020
a0fa02b
:art:
MarcoGorelli Nov 20, 2020
258018a
don't cover import error (matplotlib is a dev requirement)
MarcoGorelli Nov 22, 2020
9147b34
don't cover import typechecking
MarcoGorelli Nov 22, 2020
ddcebeb
remove optional mpl import
MarcoGorelli Nov 26, 2020
fa7687f
remove optional mpl import
MarcoGorelli Nov 26, 2020
2ceaa13
remove optional mpl import
MarcoGorelli Nov 26, 2020
9cf211a
Merge remote-tracking branch 'upstream/master' into extend-plot_poste…
MarcoGorelli Nov 26, 2020
63a5a85
Merge branch 'master' into extend-plot_posterior_predictive_glm
MarcoGorelli Dec 16, 2020
0106d77
Update RELEASE-NOTES.md
MarcoGorelli Dec 16, 2020
3fe8210
use Python3.7+ type hints
MarcoGorelli Dec 17, 2020
9e012fb
Merge remote-tracking branch 'upstream/master' into extend-plot_poste…
MarcoGorelli Dec 17, 2020
2a38f67
Merge branch 'extend-plot_posterior_predictive_glm' of github.com:Mar…
MarcoGorelli Dec 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ This is the first release to support Python3.9 and to drop Python3.6.

### New Features
- `OrderedProbit` distribution added (see [#4232](https://github.com/pymc-devs/pymc3/pull/4232)).
- `plot_posterior_predictive_glm` now works with `arviz.InferenceData` as well (see [#4234](https://github.com/pymc-devs/pymc3/pull/4234))

### Maintenance
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
Expand Down
29 changes: 22 additions & 7 deletions pymc3/plots/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

try:
import matplotlib.pyplot as plt
except ImportError: # mpl is optional
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think mpl is optional anymore - arviz is a required dependency, and mpl is a required dependency of arviz

pass
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from pymc3.backends.base import MultiTrace

if TYPE_CHECKING:
from arviz.data.inference_data import InferenceData

def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwargs):

def plot_posterior_predictive_glm(
trace: Union[InferenceData, MultiTrace],
eval: Optional[np.ndarray] = None,
lm: Optional[Callable] = None,
samples: int = 30,
**kwargs: Any
) -> None:
"""Plot posterior predictive of a linear model.
:Arguments:
trace: <array>
Array of posterior samples with columns
trace: InferenceData or MultiTrace
Output of pm.sample()
eval: <array>
Array over which to evaluate lm
lm: function <default: linear function>
Expand All @@ -47,6 +59,9 @@ def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwarg
if "c" not in kwargs and "color" not in kwargs:
kwargs["c"] = "k"

if not isinstance(trace, MultiTrace):
trace = trace.posterior.to_dataframe().to_dict(orient="records")

for rand_loc in np.random.randint(0, len(trace), samples):
rand_sample = trace[rand_loc]
plt.plot(eval, lm(eval, rand_sample), **kwargs)
Expand Down
68 changes: 68 additions & 0 deletions pymc3/tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import matplotlib.pyplot as plt
import numpy as np
import pytest

from arviz import from_pymc3

import pymc3 as pm

from pymc3.backends.ndarray import point_list_to_multitrace
from pymc3.plots import plot_posterior_predictive_glm


@pytest.mark.parametrize("inferencedata", [True, False])
def test_plot_posterior_predictive_glm_defaults(inferencedata):
with pm.Model() as model:
pm.Normal("x")
pm.Normal("Intercept")
trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model)
if inferencedata:
trace = from_pymc3(trace, model=model)
_, ax = plt.subplots()
plot_posterior_predictive_glm(trace, samples=1)
lines = ax.get_lines()
expected_xvalues = np.linspace(0, 1, 100)
expected_yvalues = np.linspace(1, 2, 100)
for line in lines:
x_axis, y_axis = line.get_data()
np.testing.assert_array_equal(x_axis, expected_xvalues)
np.testing.assert_array_equal(y_axis, expected_yvalues)
assert line.get_lw() == 0.2
assert line.get_c() == "k"


@pytest.mark.parametrize("inferencedata", [True, False])
def test_plot_posterior_predictive_glm_non_defaults(inferencedata):
with pm.Model() as model:
pm.Normal("x")
pm.Normal("Intercept")
trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model)
if inferencedata:
trace = from_pymc3(trace, model=model)
_, ax = plt.subplots()
plot_posterior_predictive_glm(
trace, samples=1, lm=lambda x, _: x, eval=np.linspace(0, 1, 10), lw=0.3, c="b"
)
lines = ax.get_lines()
expected_xvalues = np.linspace(0, 1, 10)
expected_yvalues = np.linspace(0, 1, 10)
for line in lines:
x_axis, y_axis = line.get_data()
np.testing.assert_array_equal(x_axis, expected_xvalues)
np.testing.assert_array_equal(y_axis, expected_yvalues)
assert line.get_lw() == 0.3
assert line.get_c() == "b"