-
-
Notifications
You must be signed in to change notification settings - Fork 59
Fix group selection in sample_posterior_predictive
when predictions=True
is passed in kwargs
#426
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
34839ad
7cdd900
a0501fe
079e131
fb6b1e9
ce1b2d5
4ea5fbc
7f84f03
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -530,6 +530,7 @@ def predict( | |
self, | ||
X_pred: np.ndarray | pd.DataFrame | pd.Series, | ||
extend_idata: bool = True, | ||
predictions: bool = False, | ||
**kwargs, | ||
) -> np.ndarray: | ||
""" | ||
|
@@ -559,7 +560,7 @@ def predict( | |
""" | ||
|
||
posterior_predictive_samples = self.sample_posterior_predictive( | ||
X_pred, extend_idata, combined=False, **kwargs | ||
X_pred, extend_idata, predictions, combined=False, **kwargs | ||
) | ||
|
||
if self.output_var not in posterior_predictive_samples: | ||
|
@@ -624,7 +625,7 @@ def sample_prior_predictive( | |
|
||
return prior_predictive_samples | ||
|
||
def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): | ||
def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combined, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Provide default There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The other arguments do not have defaults. The Would you be able to explain why we would want |
||
""" | ||
Sample from the model's posterior predictive distribution. | ||
|
||
|
@@ -646,12 +647,15 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): | |
self._data_setter(X_pred) | ||
|
||
with self.model: # sample with new input data | ||
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs) | ||
post_pred = pm.sample_posterior_predictive(self.idata, predictions=predictions, **kwargs) | ||
if extend_idata: | ||
self.idata.extend(post_pred, join="right") | ||
|
||
# Determine the correct group | ||
butterman0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
group_name = "predictions" if predictions else "posterior_predictive" | ||
|
||
posterior_predictive_samples = az.extract( | ||
post_pred, "posterior_predictive", combined=combined | ||
post_pred, group_name, combined=combined | ||
) | ||
|
||
return posterior_predictive_samples | ||
|
@@ -700,6 +704,7 @@ def predict_posterior( | |
X_pred: np.ndarray | pd.DataFrame | pd.Series, | ||
extend_idata: bool = True, | ||
combined: bool = True, | ||
predictions: bool = False, | ||
**kwargs, | ||
) -> xr.DataArray: | ||
""" | ||
|
@@ -723,7 +728,7 @@ def predict_posterior( | |
|
||
X_pred = self._validate_data(X_pred) | ||
posterior_predictive_samples = self.sample_posterior_predictive( | ||
X_pred, extend_idata, combined, **kwargs | ||
X_pred, extend_idata, predictions, combined, **kwargs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pass by keyword argument There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was aiming to keep it in the same format as current implementation. i.e. x_pred, extend_idata and combined do not use keyword arguments.. Similar question to the one above - should these all be changed to use keyword arguments? Why would we treat There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @ricardoV94, let me know what you think and I can adjust. |
||
) | ||
|
||
if self.output_var not in posterior_predictive_samples: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pass by keyword to be on the safe side