Skip to content

Commit 4e5edd5

Browse files
authored
Let plot_posterior_predictive_glm work with inferencedata too (#4234)
* adapt for both * add multitrace test * add inferencedata test, cover all lines * update release notes * sort imports (it's not checked yet?) * fixup test * fixup PR number * 🏷️ add type annotations, correct docstring * sort * sort * add copyright note * use todict * 🎨 * don't cover import error (matplotlib is a dev requirement) * don't cover import typechecking * remove optional mpl import * remove optional mpl import * remove optional mpl import * Update RELEASE-NOTES.md * use Python3.7+ type hints
1 parent b7b145d commit 4e5edd5

File tree

3 files changed

+91
-7
lines changed

3 files changed

+91
-7
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ This is the first release to support Python3.9 and to drop Python3.6.
55

66
### New Features
77
- `OrderedProbit` distribution added (see [#4232](https://github.com/pymc-devs/pymc3/pull/4232)).
8+
- `plot_posterior_predictive_glm` now works with `arviz.InferenceData` as well (see [#4234](https://github.com/pymc-devs/pymc3/pull/4234))
89

910
### Maintenance
1011
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)

Diff for: pymc3/plots/posteriorplot.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,30 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
try:
16-
import matplotlib.pyplot as plt
17-
except ImportError: # mpl is optional
18-
pass
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
18+
19+
import matplotlib.pyplot as plt
1920
import numpy as np
2021

22+
from pymc3.backends.base import MultiTrace
23+
24+
if TYPE_CHECKING:
25+
from arviz.data.inference_data import InferenceData
2126

22-
def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwargs):
27+
28+
def plot_posterior_predictive_glm(
29+
trace: Union[InferenceData, MultiTrace],
30+
eval: Optional[np.ndarray] = None,
31+
lm: Optional[Callable] = None,
32+
samples: int = 30,
33+
**kwargs: Any
34+
) -> None:
2335
"""Plot posterior predictive of a linear model.
2436
:Arguments:
25-
trace: <array>
26-
Array of posterior samples with columns
37+
trace: InferenceData or MultiTrace
38+
Output of pm.sample()
2739
eval: <array>
2840
Array over which to evaluate lm
2941
lm: function <default: linear function>
@@ -47,6 +59,9 @@ def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwarg
4759
if "c" not in kwargs and "color" not in kwargs:
4860
kwargs["c"] = "k"
4961

62+
if not isinstance(trace, MultiTrace):
63+
trace = trace.posterior.to_dataframe().to_dict(orient="records")
64+
5065
for rand_loc in np.random.randint(0, len(trace), samples):
5166
rand_sample = trace[rand_loc]
5267
plt.plot(eval, lm(eval, rand_sample), **kwargs)

Diff for: pymc3/tests/test_plots.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import matplotlib.pyplot as plt
16+
import numpy as np
17+
import pytest
18+
19+
from arviz import from_pymc3
20+
21+
import pymc3 as pm
22+
23+
from pymc3.backends.ndarray import point_list_to_multitrace
24+
from pymc3.plots import plot_posterior_predictive_glm
25+
26+
27+
@pytest.mark.parametrize("inferencedata", [True, False])
28+
def test_plot_posterior_predictive_glm_defaults(inferencedata):
29+
with pm.Model() as model:
30+
pm.Normal("x")
31+
pm.Normal("Intercept")
32+
trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model)
33+
if inferencedata:
34+
trace = from_pymc3(trace, model=model)
35+
_, ax = plt.subplots()
36+
plot_posterior_predictive_glm(trace, samples=1)
37+
lines = ax.get_lines()
38+
expected_xvalues = np.linspace(0, 1, 100)
39+
expected_yvalues = np.linspace(1, 2, 100)
40+
for line in lines:
41+
x_axis, y_axis = line.get_data()
42+
np.testing.assert_array_equal(x_axis, expected_xvalues)
43+
np.testing.assert_array_equal(y_axis, expected_yvalues)
44+
assert line.get_lw() == 0.2
45+
assert line.get_c() == "k"
46+
47+
48+
@pytest.mark.parametrize("inferencedata", [True, False])
49+
def test_plot_posterior_predictive_glm_non_defaults(inferencedata):
50+
with pm.Model() as model:
51+
pm.Normal("x")
52+
pm.Normal("Intercept")
53+
trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model)
54+
if inferencedata:
55+
trace = from_pymc3(trace, model=model)
56+
_, ax = plt.subplots()
57+
plot_posterior_predictive_glm(
58+
trace, samples=1, lm=lambda x, _: x, eval=np.linspace(0, 1, 10), lw=0.3, c="b"
59+
)
60+
lines = ax.get_lines()
61+
expected_xvalues = np.linspace(0, 1, 10)
62+
expected_yvalues = np.linspace(0, 1, 10)
63+
for line in lines:
64+
x_axis, y_axis = line.get_data()
65+
np.testing.assert_array_equal(x_axis, expected_xvalues)
66+
np.testing.assert_array_equal(y_axis, expected_yvalues)
67+
assert line.get_lw() == 0.3
68+
assert line.get_c() == "b"

0 commit comments

Comments
 (0)