Skip to content

Get plot data for prepostfit experiments #438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
70110c0
export plot data from plot utils
lpoug Oct 16, 2024
fdc867f
add function to prepostfit class + fixes in plot_utils
lpoug Oct 22, 2024
deace7f
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
d7680f6
utility function to retrieve hdi and clean get_plot_data_bayesian
lpoug Nov 6, 2024
1734486
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
b79e6ee
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
3f813ac
removed unused library
lpoug Nov 25, 2024
bcf9e4f
export plot data from plot utils
lpoug Oct 16, 2024
8cf55ba
add function to prepostfit class + fixes in plot_utils
lpoug Oct 22, 2024
8284a86
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
c975987
utility function to retrieve hdi and clean get_plot_data_bayesian
lpoug Nov 6, 2024
adb52b2
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
6075f20
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
d0c2109
removed unused library
lpoug Nov 25, 2024
0b7fa36
export plot data from plot utils
lpoug Oct 16, 2024
2493b17
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
1eab610
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
b49a646
tested for its and index alignment in recovering hdi
lpoug Nov 14, 2024
5521a07
removed unused library
lpoug Nov 25, 2024
e2263ea
fix diverging branch
lpoug Feb 7, 2025
bb7305a
export plot data from plot utils
lpoug Oct 16, 2024
7ac0432
generic get_plot_data in base.py and update prepostfit code to get hdi
lpoug Nov 5, 2024
f045c96
hdi_prob specification in get_plot_data_bayesian
lpoug Nov 12, 2024
6a92c4f
removed unused library
lpoug Nov 25, 2024
aea64a0
Merge branch 'plot-data' of github.com:lpoug/CausalPy into plot-data
lpoug Feb 7, 2025
9b6fcba
Merge branch 'main' into plot-data
drbenvincent Mar 3, 2025
6dca884
ran pre-commit checks
drbenvincent Mar 3, 2025
6a6face
renamed get_plot_data_bayesian to _get_plot_data_bayesian, get_plot_d…
lpoug Apr 7, 2025
97f0d79
added tests to get_plot_data for pymc and skl experiments
lpoug Apr 7, 2025
0edca77
updated uml diagrams
lpoug Apr 7, 2025
44d3870
added dynamic naming for hdi columns in _get_plot_data_bayesian, upda…
lpoug Apr 8, 2025
79c3ed1
Merge branch 'main' into pr/438
drbenvincent Apr 16, 2025
ae35d81
run pre-commit checks
drbenvincent Apr 16, 2025
5eaaec4
add comment to tests
drbenvincent Apr 16, 2025
d27aa89
make get_plot_data methods public, links to functions in docstrings, …
drbenvincent Apr 16, 2025
9af3bfb
replace NotImplementedError with pass
drbenvincent Apr 17, 2025
3fd642e
revert previous change
drbenvincent Apr 17, 2025
da6c91d
add tests to detect NotImplementedError exception
drbenvincent Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions causalpy/experiments/prepostfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,13 @@ def _get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
"""
if isinstance(self.model, PyMCModel):
hdi_pct = int(round(hdi_prob * 100))

pred_lower_col = f"pred_hdi_lower_{hdi_pct}"
pred_upper_col = f"pred_hdi_upper_{hdi_pct}"
impact_lower_col = f"impact_hdi_lower_{hdi_pct}"
impact_upper_col = f"impact_hdi_upper_{hdi_pct}"

pre_data = self.datapre.copy()
post_data = self.datapost.copy()

Expand All @@ -321,19 +328,19 @@ def _get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
.mean("sample")
.values
)
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(post_data.index)

pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
self.pre_impact, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
self.post_impact, hdi_prob=hdi_prob
).set_index(post_data.index)

Expand Down
74 changes: 62 additions & 12 deletions causalpy/tests/test_integration_pymc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def test_its():
2. causalpy.InterruptedTimeSeries returns correct type
3. the correct number of MCMC chains exists in the posterior inference data
4. the correct number of MCMC draws exists in the posterior inference data
5. the method get_plot_data returns a DataFrame with expected columns
"""
df = (
cp.load_data("its")
Expand All @@ -378,9 +379,21 @@ def test_its():
isinstance(item, plt.Axes) for item in ax
), "ax must be a numpy.ndarray of plt.Axes"
plot_data = result.get_plot_data()
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
assert isinstance(plot_data, pd.DataFrame), (
"The returned object is not a pandas DataFrame"
)
expected_columns = [
"prediction",
"pred_hdi_lower_94",
"pred_hdi_upper_94",
"impact",
"impact_hdi_lower_94",
"impact_hdi_upper_94",
]
assert set(expected_columns).issubset(set(plot_data.columns)), (
f"DataFrame is missing expected columns {expected_columns}"
)


@pytest.mark.integration
def test_its_covid():
Expand All @@ -392,6 +405,7 @@ def test_its_covid():
2. causalpy.InterruptedTimeSeries returns correct type
3. the correct number of MCMC chains exists in the posterior inference data
4. the correct number of MCMC draws exists in the posterior inference data
5. the method get_plot_data returns a DataFrame with expected columns
"""

df = (
Expand All @@ -418,9 +432,20 @@ def test_its_covid():
isinstance(item, plt.Axes) for item in ax
), "ax must be a numpy.ndarray of plt.Axes"
plot_data = result.get_plot_data()
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
assert isinstance(plot_data, pd.DataFrame), (
"The returned object is not a pandas DataFrame"
)
expected_columns = [
"prediction",
"pred_hdi_lower_94",
"pred_hdi_upper_94",
"impact",
"impact_hdi_lower_94",
"impact_hdi_upper_94",
]
assert set(expected_columns).issubset(set(plot_data.columns)), (
f"DataFrame is missing expected columns {expected_columns}"
)


@pytest.mark.integration
Expand All @@ -433,6 +458,7 @@ def test_sc():
2. causalpy.SyntheticControl returns correct type
3. the correct number of MCMC chains exists in the posterior inference data
4. the correct number of MCMC draws exists in the posterior inference data
5. the method get_plot_data returns a DataFrame with expected columns
"""

df = cp.load_data("sc")
Expand Down Expand Up @@ -463,9 +489,21 @@ def test_sc():
isinstance(item, plt.Axes) for item in ax
), "ax must be a numpy.ndarray of plt.Axes"
plot_data = result.get_plot_data()
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
assert isinstance(plot_data, pd.DataFrame), (
"The returned object is not a pandas DataFrame"
)
expected_columns = [
"prediction",
"pred_hdi_lower_94",
"pred_hdi_upper_94",
"impact",
"impact_hdi_lower_94",
"impact_hdi_upper_94",
]
assert set(expected_columns).issubset(set(plot_data.columns)), (
f"DataFrame is missing expected columns {expected_columns}"
)


@pytest.mark.integration
def test_sc_brexit():
Expand All @@ -477,6 +515,7 @@ def test_sc_brexit():
2. causalpy.SyntheticControl returns correct type
3. the correct number of MCMC chains exists in the posterior inference data
4. the correct number of MCMC draws exists in the posterior inference data
5. the method get_plot_data returns a DataFrame with expected columns
"""

df = (
Expand Down Expand Up @@ -512,9 +551,20 @@ def test_sc_brexit():
isinstance(item, plt.Axes) for item in ax
), "ax must be a numpy.ndarray of plt.Axes"
plot_data = result.get_plot_data()
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
assert isinstance(plot_data, pd.DataFrame), (
"The returned object is not a pandas DataFrame"
)
expected_columns = [
"prediction",
"pred_hdi_lower_94",
"pred_hdi_upper_94",
"impact",
"impact_hdi_lower_94",
"impact_hdi_upper_94",
]
assert set(expected_columns).issubset(set(plot_data.columns)), (
f"DataFrame is missing expected columns {expected_columns}"
)


@pytest.mark.integration
Expand Down
22 changes: 16 additions & 6 deletions causalpy/tests/test_integration_skl_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_its():
Loads data and checks:
1. data is a dataframe
2. skl_experiments.InterruptedTimeSeries returns correct type
3. the method get_plot_data returns a DataFrame with expected columns
"""

df = (
Expand All @@ -112,9 +113,13 @@ def test_its():
isinstance(item, plt.Axes) for item in ax
), "ax must be a numpy.ndarray of plt.Axes"
plot_data = result.get_plot_data()
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
expected_columns = ['prediction', 'impact']
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
assert isinstance(plot_data, pd.DataFrame), (
"The returned object is not a pandas DataFrame"
)
expected_columns = ["prediction", "impact"]
assert set(expected_columns).issubset(set(plot_data.columns)), (
f"DataFrame is missing expected columns {expected_columns}"
)


@pytest.mark.integration
Expand All @@ -125,6 +130,7 @@ def test_sc():
Loads data and checks:
1. data is a dataframe
2. skl_experiments.SyntheticControl returns correct type
3. the method get_plot_data returns a DataFrame with expected columns
"""
df = cp.load_data("sc")
treatment_time = 70
Expand Down Expand Up @@ -152,9 +158,13 @@ def test_sc():
isinstance(item, plt.Axes) for item in ax
), "ax must be a numpy.ndarray of plt.Axes"
plot_data = result.get_plot_data()
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
expected_columns = ['prediction', 'impact']
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
assert isinstance(plot_data, pd.DataFrame), (
"The returned object is not a pandas DataFrame"
)
expected_columns = ["prediction", "impact"]
assert set(expected_columns).issubset(set(plot_data.columns)), (
f"DataFrame is missing expected columns {expected_columns}"
)


@pytest.mark.integration
Expand Down