Skip to content

Get plot data for prepostfit experiments #438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
70110c0
export plot data from plot utils
lpoug Oct 16, 2024
fdc867f
add function to prepostfit class + fixes in plot_utils
lpoug Oct 22, 2024
deace7f
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
d7680f6
utility function to retrieve hdi and clean get_plot_data_bayesian
lpoug Nov 6, 2024
1734486
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
b79e6ee
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
3f813ac
removed unused library
lpoug Nov 25, 2024
bcf9e4f
export plot data from plot utils
lpoug Oct 16, 2024
8cf55ba
add function to prepostfit class + fixes in plot_utils
lpoug Oct 22, 2024
8284a86
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
c975987
utility function to retrieve hdi and clean get_plot_data_bayesian
lpoug Nov 6, 2024
adb52b2
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
6075f20
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
d0c2109
removed unused library
lpoug Nov 25, 2024
0b7fa36
export plot data from plot utils
lpoug Oct 16, 2024
2493b17
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
1eab610
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
b49a646
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
5521a07
removed unused library
lpoug Nov 25, 2024
e2263ea
fix diverging branch
lpoug Feb 7, 2025
bb7305a
export plot data from plot utils
lpoug Oct 16, 2024
7ac0432
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
f045c96
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
6a92c4f
removed unused library
lpoug Nov 25, 2024
aea64a0
Merge branch 'plot-data' of github.com:lpoug/CausalPy into plot-data
lpoug Feb 7, 2025
9b6fcba
Merge branch 'main' into plot-data
drbenvincent Mar 3, 2025
6dca884
ran pre-commit checks
drbenvincent Mar 3, 2025
6a6face
renamed get_plot_data_bayesian to _get_plot_data_bayesian, get_plot_d…
lpoug Apr 7, 2025
97f0d79
added tests to get_plot_data for pymc and skl experiments
lpoug Apr 7, 2025
0edca77
updated uml diagrams
lpoug Apr 7, 2025
44d3870
added dynamic naming for hdi columns in _get_plot_data_bayesian, upda…
lpoug Apr 8, 2025
79c3ed1
Merge branch 'main' into pr/438
drbenvincent Apr 16, 2025
ae35d81
run pre-commit checks
drbenvincent Apr 16, 2025
5eaaec4
add comment to tests
drbenvincent Apr 16, 2025
d27aa89
make get_plot_data methods public, links to functions in docstrings, …
drbenvincent Apr 16, 2025
9af3bfb
replace NotImplementedError with pass
drbenvincent Apr 17, 2025
3fd642e
revert previous change
drbenvincent Apr 17, 2025
da6c91d
add tests to detect NotImplementedError exception
drbenvincent Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from abc import abstractmethod

import pandas as pd
from sklearn.base import RegressorMixin

from causalpy.pymc_models import PyMCModel
Expand Down Expand Up @@ -78,3 +79,26 @@
def ols_plot(self, *args, **kwargs):
"""Abstract method for plotting the model."""
raise NotImplementedError("ols_plot method not yet implemented")

def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
"""Recover the data of a PrePostFit experiment along with the prediction and causal impact information.

Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
depending on the model type.
"""
if isinstance(self.model, PyMCModel):
return self.get_plot_data_bayesian(*args, **kwargs)
elif isinstance(self.model, RegressorMixin):
return self.get_plot_data_ols(*args, **kwargs)

Check warning on line 92 in causalpy/experiments/base.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/base.py#L89-L92

Added lines #L89 - L92 were not covered by tests
else:
raise ValueError("Unsupported model type")

Check warning on line 94 in causalpy/experiments/base.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/base.py#L94

Added line #L94 was not covered by tests

@abstractmethod
def get_plot_data_bayesian(self, *args, **kwargs):
"""Abstract method for recovering plot data."""
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")

Check warning on line 99 in causalpy/experiments/base.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/base.py#L99

Added line #L99 was not covered by tests

@abstractmethod
def get_plot_data_ols(self, *args, **kwargs):
"""Abstract method for recovering plot data."""
raise NotImplementedError("get_plot_data_ols method not yet implemented")

Check warning on line 104 in causalpy/experiments/base.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/base.py#L104

Added line #L104 was not covered by tests
56 changes: 55 additions & 1 deletion causalpy/experiments/prepostfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sklearn.base import RegressorMixin

from causalpy.custom_exceptions import BadIndexException
from causalpy.plot_utils import plot_xY
from causalpy.plot_utils import get_hdi_to_df, plot_xY
from causalpy.pymc_models import PyMCModel
from causalpy.utils import round_num

Expand Down Expand Up @@ -303,6 +303,60 @@

return (fig, ax)

def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
"""
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
"""
if isinstance(self.model, PyMCModel):
pre_data = self.datapre.copy()
post_data = self.datapost.copy()

Check warning on line 312 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L310-L312

Added lines #L310 - L312 were not covered by tests

pre_data["prediction"] = (

Check warning on line 314 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L314

Added line #L314 was not covered by tests
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
.mean("sample")
.values
)
post_data["prediction"] = (

Check warning on line 319 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L319

Added line #L319 was not covered by tests
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
.mean("sample")
.values
)
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(

Check warning on line 324 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L324

Added line #L324 was not covered by tests
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(

Check warning on line 327 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L327

Added line #L327 was not covered by tests
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(post_data.index)

pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(

Check warning on line 333 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L331-L333

Added lines #L331 - L333 were not covered by tests
self.pre_impact, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(

Check warning on line 336 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L336

Added line #L336 was not covered by tests
self.post_impact, hdi_prob=hdi_prob
).set_index(post_data.index)

self.plot_data = pd.concat([pre_data, post_data])

Check warning on line 340 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L340

Added line #L340 was not covered by tests

return self.plot_data

Check warning on line 342 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L342

Added line #L342 was not covered by tests
else:
raise ValueError("Unsupported model type")

Check warning on line 344 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L344

Added line #L344 was not covered by tests

def get_plot_data_ols(self) -> pd.DataFrame:
"""
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
"""
pre_data = self.datapre.copy()
post_data = self.datapost.copy()
pre_data["prediction"] = self.pre_pred
post_data["prediction"] = self.post_pred
pre_data["impact"] = self.pre_impact
post_data["impact"] = self.post_impact
self.plot_data = pd.concat([pre_data, post_data])

Check warning on line 356 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L350-L356

Added lines #L350 - L356 were not covered by tests

return self.plot_data

Check warning on line 358 in causalpy/experiments/prepostfit.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/prepostfit.py#L358

Added line #L358 was not covered by tests


class InterruptedTimeSeries(PrePostFit):
"""
Expand Down
21 changes: 21 additions & 0 deletions causalpy/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,24 @@
filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
)[-1]
return (h_line, h_patch)


def get_hdi_to_df(
x: xr.DataArray,
hdi_prob: float = 0.94,
) -> pd.DataFrame:
"""
Utility function to calculate and recover HDI intervals.

:param x:
Xarray data array
:param hdi_prob:
The size of the HDI, default is 0.94
"""
hdi = (

Check warning on line 96 in causalpy/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

causalpy/plot_utils.py#L96

Added line #L96 was not covered by tests
az.hdi(x, hdi_prob=hdi_prob)
.to_dataframe()
.unstack(level="hdi")
.droplevel(0, axis=1)
)
return hdi

Check warning on line 102 in causalpy/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

causalpy/plot_utils.py#L102

Added line #L102 was not covered by tests
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.