diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 6f968d733c..82c5761245 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -5,6 +5,7 @@ This is the first release to support Python3.9 and to drop Python3.6. ### New Features - `OrderedProbit` distribution added (see [#4232](https://github.com/pymc-devs/pymc3/pull/4232)). +- `plot_posterior_predictive_glm` now works with `arviz.InferenceData` as well (see [#4234](https://github.com/pymc-devs/pymc3/pull/4234)) ### Maintenance - Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318) diff --git a/pymc3/plots/posteriorplot.py b/pymc3/plots/posteriorplot.py index fc44e914c8..08ba7f0487 100644 --- a/pymc3/plots/posteriorplot.py +++ b/pymc3/plots/posteriorplot.py @@ -12,18 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import matplotlib.pyplot as plt -except ImportError: # mpl is optional - pass +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +import matplotlib.pyplot as plt import numpy as np +from pymc3.backends.base import MultiTrace + +if TYPE_CHECKING: + from arviz.data.inference_data import InferenceData -def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwargs): + +def plot_posterior_predictive_glm( + trace: Union[InferenceData, MultiTrace], + eval: Optional[np.ndarray] = None, + lm: Optional[Callable] = None, + samples: int = 30, + **kwargs: Any +) -> None: """Plot posterior predictive of a linear model. :Arguments: - trace: - Array of posterior samples with columns + trace: InferenceData or MultiTrace + Output of pm.sample() eval: Array over which to evaluate lm lm: function @@ -47,6 +59,9 @@ def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwarg if "c" not in kwargs and "color" not in kwargs: kwargs["c"] = "k" + if not isinstance(trace, MultiTrace): + trace = trace.posterior.to_dataframe().to_dict(orient="records") + for rand_loc in np.random.randint(0, len(trace), samples): rand_sample = trace[rand_loc] plt.plot(eval, lm(eval, rand_sample), **kwargs) diff --git a/pymc3/tests/test_plots.py b/pymc3/tests/test_plots.py new file mode 100644 index 0000000000..0e2149f469 --- /dev/null +++ b/pymc3/tests/test_plots.py @@ -0,0 +1,68 @@ +# Copyright 2020 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import matplotlib.pyplot as plt +import numpy as np +import pytest + +from arviz import from_pymc3 + +import pymc3 as pm + +from pymc3.backends.ndarray import point_list_to_multitrace +from pymc3.plots import plot_posterior_predictive_glm + + +@pytest.mark.parametrize("inferencedata", [True, False]) +def test_plot_posterior_predictive_glm_defaults(inferencedata): + with pm.Model() as model: + pm.Normal("x") + pm.Normal("Intercept") + trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model) + if inferencedata: + trace = from_pymc3(trace, model=model) + _, ax = plt.subplots() + plot_posterior_predictive_glm(trace, samples=1) + lines = ax.get_lines() + expected_xvalues = np.linspace(0, 1, 100) + expected_yvalues = np.linspace(1, 2, 100) + for line in lines: + x_axis, y_axis = line.get_data() + np.testing.assert_array_equal(x_axis, expected_xvalues) + np.testing.assert_array_equal(y_axis, expected_yvalues) + assert line.get_lw() == 0.2 + assert line.get_c() == "k" + + +@pytest.mark.parametrize("inferencedata", [True, False]) +def test_plot_posterior_predictive_glm_non_defaults(inferencedata): + with pm.Model() as model: + pm.Normal("x") + pm.Normal("Intercept") + trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model) + if inferencedata: + trace = from_pymc3(trace, model=model) + _, ax = plt.subplots() + plot_posterior_predictive_glm( + trace, samples=1, lm=lambda x, _: x, eval=np.linspace(0, 1, 10), lw=0.3, c="b" + ) + lines = ax.get_lines() + expected_xvalues = np.linspace(0, 1, 10) + expected_yvalues = np.linspace(0, 1, 10) + for line in lines: + x_axis, y_axis = line.get_data() + np.testing.assert_array_equal(x_axis, expected_xvalues) + np.testing.assert_array_equal(y_axis, expected_yvalues) + assert line.get_lw() == 0.3 + assert line.get_c() == "b"