Skip to content

Get plot data for prepostfit experiments #438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
70110c0
export plot data from plot utils
lpoug Oct 16, 2024
fdc867f
add function to prepostfit class + fixes in plot_utils
lpoug Oct 22, 2024
deace7f
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
d7680f6
utility function to retrieve hdi and clean get_plot_data_bayesian
lpoug Nov 6, 2024
1734486
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
b79e6ee
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
3f813ac
removed unused library
lpoug Nov 25, 2024
bcf9e4f
export plot data from plot utils
lpoug Oct 16, 2024
8cf55ba
add function to prepostfit class + fixes in plot_utils
lpoug Oct 22, 2024
8284a86
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
c975987
utility function to retrieve hdi and clean get_plot_data_bayesian
lpoug Nov 6, 2024
adb52b2
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
6075f20
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
d0c2109
removed unused library
lpoug Nov 25, 2024
0b7fa36
export plot data from plot utils
lpoug Oct 16, 2024
2493b17
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
1eab610
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
b49a646
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
5521a07
removed unused library
lpoug Nov 25, 2024
e2263ea
fix diverging branch
lpoug Feb 7, 2025
bb7305a
export plot data from plot utils
lpoug Oct 16, 2024
7ac0432
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
f045c96
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
6a92c4f
removed unused library
lpoug Nov 25, 2024
aea64a0
Merge branch 'plot-data' of github.com:lpoug/CausalPy into plot-data
lpoug Feb 7, 2025
9b6fcba
Merge branch 'main' into plot-data
drbenvincent Mar 3, 2025
6dca884
ran pre-commit checks
drbenvincent Mar 3, 2025
6a6face
renamed get_plot_data_bayesian to _get_plot_data_bayesian, get_plot_d…
lpoug Apr 7, 2025
97f0d79
added tests to get_plot_data for pymc and skl experiments
lpoug Apr 7, 2025
0edca77
updated uml diagrams
lpoug Apr 7, 2025
44d3870
added dynamic naming for hdi columns in _get_plot_data_bayesian, upda…
lpoug Apr 8, 2025
79c3ed1
Merge branch 'main' into pr/438
drbenvincent Apr 16, 2025
ae35d81
run pre-commit checks
drbenvincent Apr 16, 2025
5eaaec4
add comment to tests
drbenvincent Apr 16, 2025
d27aa89
make get_plot_data methods public, links to functions in docstrings, …
drbenvincent Apr 16, 2025
9af3bfb
replace NotImplementedError with pass
drbenvincent Apr 17, 2025
3fd642e
revert previous change
drbenvincent Apr 17, 2025
da6c91d
add tests to detect NotImplementedError exception
drbenvincent Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from abc import abstractmethod

import pandas as pd
from sklearn.base import RegressorMixin

from causalpy.pymc_models import PyMCModel
Expand Down Expand Up @@ -59,22 +60,45 @@
def plot(self, *args, **kwargs) -> tuple:
"""Plot the model.

Internally, this function dispatches to either `bayesian_plot` or `ols_plot`
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
depending on the model type.
"""
if isinstance(self.model, PyMCModel):
return self.bayesian_plot(*args, **kwargs)
return self._bayesian_plot(*args, **kwargs)
elif isinstance(self.model, RegressorMixin):
return self.ols_plot(*args, **kwargs)
return self._ols_plot(*args, **kwargs)
else:
raise ValueError("Unsupported model type")

@abstractmethod
def bayesian_plot(self, *args, **kwargs):
def _bayesian_plot(self, *args, **kwargs):
"""Abstract method for plotting the model."""
raise NotImplementedError("bayesian_plot method not yet implemented")
raise NotImplementedError("_bayesian_plot method not yet implemented")

Check warning on line 76 in causalpy/experiments/base.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/base.py#L76

Added line #L76 was not covered by tests

@abstractmethod
def ols_plot(self, *args, **kwargs):
def _ols_plot(self, *args, **kwargs):
"""Abstract method for plotting the model."""
raise NotImplementedError("ols_plot method not yet implemented")
raise NotImplementedError("_ols_plot method not yet implemented")

Check warning on line 81 in causalpy/experiments/base.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/base.py#L81

Added line #L81 was not covered by tests

def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
"""Recover the data of a PrePostFit experiment along with the prediction and causal impact information.

Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
depending on the model type.
"""
if isinstance(self.model, PyMCModel):
return self.get_plot_data_bayesian(*args, **kwargs)
elif isinstance(self.model, RegressorMixin):
return self.get_plot_data_ols(*args, **kwargs)
else:
raise ValueError("Unsupported model type")

Check warning on line 94 in causalpy/experiments/base.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/base.py#L94

Added line #L94 was not covered by tests

@abstractmethod
def get_plot_data_bayesian(self, *args, **kwargs):
"""Abstract method for recovering plot data."""
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")

@abstractmethod
def get_plot_data_ols(self, *args, **kwargs):
"""Abstract method for recovering plot data."""
raise NotImplementedError("get_plot_data_ols method not yet implemented")
4 changes: 2 additions & 2 deletions causalpy/experiments/diff_in_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _causal_impact_summary_stat(self, round_to=None) -> str:
"""Computes the mean and 94% credible interval bounds for the causal impact."""
return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}"

def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
"""
Plot the results

Expand Down Expand Up @@ -367,7 +367,7 @@ def _plot_causal_impact_arrow(results, ax):
)
return fig, ax

def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
"""Generate plot for difference-in-differences"""
round_to = kwargs.get("round_to")
fig, ax = plt.subplots()
Expand Down
74 changes: 69 additions & 5 deletions causalpy/experiments/prepostfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sklearn.base import RegressorMixin

from causalpy.custom_exceptions import BadIndexException
from causalpy.plot_utils import plot_xY
from causalpy.plot_utils import get_hdi_to_df, plot_xY
from causalpy.pymc_models import PyMCModel
from causalpy.utils import round_num

Expand Down Expand Up @@ -123,7 +123,7 @@
print(f"Formula: {self.formula}")
self.print_coefficients(round_to)

def bayesian_plot(
def _bayesian_plot(
self, round_to=None, **kwargs
) -> tuple[plt.Figure, List[plt.Axes]]:
"""
Expand Down Expand Up @@ -231,7 +231,7 @@

return fig, ax

def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
"""
Plot the results

Expand Down Expand Up @@ -303,6 +303,70 @@

return (fig, ax)

def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
"""
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.

:param hdi_prob:
Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
"""
if isinstance(self.model, PyMCModel):
hdi_pct = int(round(hdi_prob * 100))

pred_lower_col = f"pred_hdi_lower_{hdi_pct}"
pred_upper_col = f"pred_hdi_upper_{hdi_pct}"
impact_lower_col = f"impact_hdi_lower_{hdi_pct}"
impact_upper_col = f"impact_hdi_upper_{hdi_pct}"

pre_data = self.datapre.copy()
post_data = self.datapost.copy()

pre_data["prediction"] = (
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
.mean("sample")
.values
)
post_data["prediction"] = (
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
.mean("sample")
.values
)
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(post_data.index)

pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
self.pre_impact, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
self.post_impact, hdi_prob=hdi_prob
).set_index(post_data.index)

self.plot_data = pd.concat([pre_data, post_data])

return self.plot_data
else:
raise ValueError("Unsupported model type")

Check warning on line 354 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L354

Added line #L354 was not covered by tests

def get_plot_data_ols(self) -> pd.DataFrame:
"""
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
"""
pre_data = self.datapre.copy()
post_data = self.datapost.copy()
pre_data["prediction"] = self.pre_pred
post_data["prediction"] = self.post_pred
pre_data["impact"] = self.pre_impact
post_data["impact"] = self.post_impact
self.plot_data = pd.concat([pre_data, post_data])

return self.plot_data


class InterruptedTimeSeries(PrePostFit):
"""
Expand Down Expand Up @@ -382,7 +446,7 @@
supports_ols = True
supports_bayes = True

def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
def _bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
"""
Plot the results

Expand All @@ -393,7 +457,7 @@
Whether to plot the control units as well. Defaults to False.
"""
# call the super class method
fig, ax = super().bayesian_plot(*args, **kwargs)
fig, ax = super()._bayesian_plot(*args, **kwargs)

# additional plotting functionality for the synthetic control experiment
plot_predictors = kwargs.get("plot_predictors", False)
Expand Down
2 changes: 1 addition & 1 deletion causalpy/experiments/prepostnegd.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def summary(self, round_to=None) -> None:
print(self._causal_impact_summary_stat(round_to))
self.print_coefficients(round_to)

def bayesian_plot(
def _bayesian_plot(
self, round_to=None, **kwargs
) -> tuple[plt.Figure, List[plt.Axes]]:
"""Generate plot for ANOVA-like experiments with non-equivalent group designs."""
Expand Down
4 changes: 2 additions & 2 deletions causalpy/experiments/regression_discontinuity.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def summary(self, round_to=None) -> None:
print("\n")
self.print_coefficients(round_to)

def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
"""Generate plot for regression discontinuity designs."""
fig, ax = plt.subplots()
# Plot raw data
Expand Down Expand Up @@ -267,7 +267,7 @@ def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
)
return (fig, ax)

def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
"""Generate plot for regression discontinuity designs."""
fig, ax = plt.subplots()
# Plot raw data
Expand Down
2 changes: 1 addition & 1 deletion causalpy/experiments/regression_kink.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def summary(self, round_to=None) -> None:
)
self.print_coefficients(round_to)

def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
"""Generate plot for regression kink designs."""
fig, ax = plt.subplots()
# Plot raw data
Expand Down
21 changes: 21 additions & 0 deletions causalpy/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,24 @@ def plot_xY(
filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
)[-1]
return (h_line, h_patch)


def get_hdi_to_df(
x: xr.DataArray,
hdi_prob: float = 0.94,
) -> pd.DataFrame:
"""
Utility function to calculate and recover HDI intervals.

:param x:
Xarray data array
:param hdi_prob:
The size of the HDI, default is 0.94
"""
hdi = (
az.hdi(x, hdi_prob=hdi_prob)
.to_dataframe()
.unstack(level="hdi")
.droplevel(0, axis=1)
)
return hdi
Loading