-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Let plot_posterior_predictive_glm work with inferencedata too #4234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
AlexAndorra
merged 25 commits into
pymc-devs:master
from
MarcoGorelli:extend-plot_posterior_predictive_glm
Dec 17, 2020
Merged
Changes from 7 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
cbab50c
adapt for both
MarcoGorelli 2eeb5ae
add multitrace test
MarcoGorelli 22991c9
add inferencedata test, cover all lines
MarcoGorelli 3eb3410
update release notes
MarcoGorelli 5c37ef7
sort imports (it's not checked yet?)
MarcoGorelli 58d0839
fixup test
MarcoGorelli c145b0a
fixup PR number
MarcoGorelli e28f2b8
:label: add type annotations, correct docstring
MarcoGorelli 08c11fc
Merge branch 'extend-plot_posterior_predictive_glm' of github.com:mar…
MarcoGorelli 987e170
sort
MarcoGorelli f93da72
sort
MarcoGorelli 40ac14b
add copyright note
MarcoGorelli 9540abb
use todict
MarcoGorelli a0fa02b
:art:
MarcoGorelli 258018a
don't cover import error (matplotlib is a dev requirement)
MarcoGorelli 9147b34
don't cover import typechecking
MarcoGorelli ddcebeb
remove optional mpl import
MarcoGorelli fa7687f
remove optional mpl import
MarcoGorelli 2ceaa13
remove optional mpl import
MarcoGorelli 9cf211a
Merge remote-tracking branch 'upstream/master' into extend-plot_poste…
MarcoGorelli 63a5a85
Merge branch 'master' into extend-plot_posterior_predictive_glm
MarcoGorelli 0106d77
Update RELEASE-NOTES.md
MarcoGorelli 3fe8210
use Python3.7+ type hints
MarcoGorelli 9e012fb
Merge remote-tracking branch 'upstream/master' into extend-plot_poste…
MarcoGorelli 2a38f67
Merge branch 'extend-plot_posterior_predictive_glm' of github.com:Mar…
MarcoGorelli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import matplotlib.pyplot as plt | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import numpy as np | ||
import pytest | ||
|
||
from arviz import from_pymc3 | ||
|
||
import pymc3 as pm | ||
|
||
from pymc3.backends.ndarray import point_list_to_multitrace | ||
from pymc3.plots import plot_posterior_predictive_glm | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@pytest.mark.parametrize("inferencedata", [True, False]) | ||
def test_plot_posterior_predictive_glm_defaults(inferencedata): | ||
with pm.Model() as model: | ||
pm.Normal("x") | ||
pm.Normal("Intercept") | ||
trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model) | ||
if inferencedata: | ||
trace = from_pymc3(trace, model=model) | ||
_, ax = plt.subplots() | ||
plot_posterior_predictive_glm(trace, samples=1) | ||
lines = ax.get_lines() | ||
expected_xvalues = np.linspace(0, 1, 100) | ||
expected_yvalues = np.linspace(1, 2, 100) | ||
for line in lines: | ||
x_axis, y_axis = line.get_data() | ||
np.testing.assert_array_equal(x_axis, expected_xvalues) | ||
np.testing.assert_array_equal(y_axis, expected_yvalues) | ||
assert line.get_lw() == 0.2 | ||
assert line.get_c() == "k" | ||
|
||
|
||
@pytest.mark.parametrize("inferencedata", [True, False]) | ||
def test_plot_posterior_predictive_glm_non_defaults(inferencedata): | ||
with pm.Model() as model: | ||
pm.Normal("x") | ||
pm.Normal("Intercept") | ||
trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model) | ||
if inferencedata: | ||
trace = from_pymc3(trace, model=model) | ||
_, ax = plt.subplots() | ||
plot_posterior_predictive_glm( | ||
trace, samples=1, lm=lambda x, _: x, eval=np.linspace(0, 1, 10), lw=0.3, c="b" | ||
) | ||
lines = ax.get_lines() | ||
expected_xvalues = np.linspace(0, 1, 10) | ||
expected_yvalues = np.linspace(0, 1, 10) | ||
for line in lines: | ||
x_axis, y_axis = line.get_data() | ||
np.testing.assert_array_equal(x_axis, expected_xvalues) | ||
np.testing.assert_array_equal(y_axis, expected_yvalues) | ||
assert line.get_lw() == 0.3 | ||
assert line.get_c() == "b" |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two functions have a lot of duplicated lines; I think they can be merged into one by checking
if isinstance(trace, MultiTrace)
at the beginning of the function (or just before) and casting theInferenceData
to_array
(I think this is the name of the function but you can check on ArviZ website) instead of to a dataframe.After that, the handling should be the same as you're dealing with numpy arrays in both cases
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we cast to
array
then I think we wouldn't be able to access the different parameters (e.g.'Intercept'
or'x'
), which appear inlm
. Each element here is a dict in the multitrace case:at this point, the only lines they have in common are
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah right, I forgot the whole trace was given here, and not only
trace["y"]
for instance. But then, wouldn'ttrace.posterior.to_dataframe().to_dict()
get the format we want? That way we'd need only one plotting functionThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure - I think this'd be slightly more expensive, but arguably it's worth it for the sake of much simpler code