diff --git a/causalpy/experiments/base.py b/causalpy/experiments/base.py index 78628aaf..431e824d 100644 --- a/causalpy/experiments/base.py +++ b/causalpy/experiments/base.py @@ -17,6 +17,7 @@ from abc import abstractmethod +import pandas as pd from sklearn.base import RegressorMixin from causalpy.pymc_models import PyMCModel @@ -59,22 +60,45 @@ def print_coefficients(self, round_to=None): def plot(self, *args, **kwargs) -> tuple: """Plot the model. - Internally, this function dispatches to either `bayesian_plot` or `ols_plot` + Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot` depending on the model type. """ if isinstance(self.model, PyMCModel): - return self.bayesian_plot(*args, **kwargs) + return self._bayesian_plot(*args, **kwargs) elif isinstance(self.model, RegressorMixin): - return self.ols_plot(*args, **kwargs) + return self._ols_plot(*args, **kwargs) else: raise ValueError("Unsupported model type") @abstractmethod - def bayesian_plot(self, *args, **kwargs): + def _bayesian_plot(self, *args, **kwargs): """Abstract method for plotting the model.""" - raise NotImplementedError("bayesian_plot method not yet implemented") + raise NotImplementedError("_bayesian_plot method not yet implemented") @abstractmethod - def ols_plot(self, *args, **kwargs): + def _ols_plot(self, *args, **kwargs): """Abstract method for plotting the model.""" - raise NotImplementedError("ols_plot method not yet implemented") + raise NotImplementedError("_ols_plot method not yet implemented") + + def get_plot_data(self, *args, **kwargs) -> pd.DataFrame: + """Recover the data of a PrePostFit experiment along with the prediction and causal impact information. + + Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols` + depending on the model type. 
+ """ + if isinstance(self.model, PyMCModel): + return self.get_plot_data_bayesian(*args, **kwargs) + elif isinstance(self.model, RegressorMixin): + return self.get_plot_data_ols(*args, **kwargs) + else: + raise ValueError("Unsupported model type") + + @abstractmethod + def get_plot_data_bayesian(self, *args, **kwargs): + """Abstract method for recovering plot data.""" + raise NotImplementedError("get_plot_data_bayesian method not yet implemented") + + @abstractmethod + def get_plot_data_ols(self, *args, **kwargs): + """Abstract method for recovering plot data.""" + raise NotImplementedError("get_plot_data_ols method not yet implemented") diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index ec704e14..37204052 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -229,7 +229,7 @@ def _causal_impact_summary_stat(self, round_to=None) -> str: """Computes the mean and 94% credible interval bounds for the causal impact.""" return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}" - def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: + def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: """ Plot the results @@ -367,7 +367,7 @@ def _plot_causal_impact_arrow(results, ax): ) return fig, ax - def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: + def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: """Generate plot for difference-in-differences""" round_to = kwargs.get("round_to") fig, ax = plt.subplots() diff --git a/causalpy/experiments/prepostfit.py b/causalpy/experiments/prepostfit.py index 47138e5c..adaa0b84 100644 --- a/causalpy/experiments/prepostfit.py +++ b/causalpy/experiments/prepostfit.py @@ -25,7 +25,7 @@ from sklearn.base import RegressorMixin from causalpy.custom_exceptions import BadIndexException -from causalpy.plot_utils import plot_xY +from causalpy.plot_utils import get_hdi_to_df, plot_xY from causalpy.pymc_models import PyMCModel from causalpy.utils import round_num @@ -123,7 +123,7 @@ def summary(self, round_to=None) -> None: print(f"Formula: {self.formula}") self.print_coefficients(round_to) - def bayesian_plot( + def _bayesian_plot( self, round_to=None, **kwargs ) -> tuple[plt.Figure, List[plt.Axes]]: """ @@ -231,7 +231,7 @@ def bayesian_plot( return fig, ax - def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]: + def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]: """ Plot the results @@ -303,6 +303,70 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]] return (fig, ax) + def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame: + """ + Recover the data of a PrePostFit experiment along with the prediction and causal impact information. + + :param hdi_prob: + Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function. 
+ """ + if isinstance(self.model, PyMCModel): + hdi_pct = int(round(hdi_prob * 100)) + + pred_lower_col = f"pred_hdi_lower_{hdi_pct}" + pred_upper_col = f"pred_hdi_upper_{hdi_pct}" + impact_lower_col = f"impact_hdi_lower_{hdi_pct}" + impact_upper_col = f"impact_hdi_upper_{hdi_pct}" + + pre_data = self.datapre.copy() + post_data = self.datapost.copy() + + pre_data["prediction"] = ( + az.extract(self.pre_pred, group="posterior_predictive", var_names="mu") + .mean("sample") + .values + ) + post_data["prediction"] = ( + az.extract(self.post_pred, group="posterior_predictive", var_names="mu") + .mean("sample") + .values + ) + pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df( + self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob + ).set_index(pre_data.index) + post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df( + self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob + ).set_index(post_data.index) + + pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values + post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values + pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df( + self.pre_impact, hdi_prob=hdi_prob + ).set_index(pre_data.index) + post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df( + self.post_impact, hdi_prob=hdi_prob + ).set_index(post_data.index) + + self.plot_data = pd.concat([pre_data, post_data]) + + return self.plot_data + else: + raise ValueError("Unsupported model type") + + def get_plot_data_ols(self) -> pd.DataFrame: + """ + Recover the data of a PrePostFit experiment along with the prediction and causal impact information. + """ + pre_data = self.datapre.copy() + post_data = self.datapost.copy() + pre_data["prediction"] = self.pre_pred + post_data["prediction"] = self.post_pred + pre_data["impact"] = self.pre_impact + post_data["impact"] = self.post_impact + self.plot_data = pd.concat([pre_data, post_data]) + + return self.plot_data + class InterruptedTimeSeries(PrePostFit): """ @@ -382,7 +446,7 @@ class SyntheticControl(PrePostFit): supports_ols = True supports_bayes = True - def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]: + def _bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]: """ Plot the results @@ -393,7 +457,7 @@ def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]: Whether to plot the control units as well. Defaults to False. 
""" # call the super class method - fig, ax = super().bayesian_plot(*args, **kwargs) + fig, ax = super()._bayesian_plot(*args, **kwargs) # additional plotting functionality for the synthetic control experiment plot_predictors = kwargs.get("plot_predictors", False) diff --git a/causalpy/experiments/prepostnegd.py b/causalpy/experiments/prepostnegd.py index f1854b92..c33d89dc 100644 --- a/causalpy/experiments/prepostnegd.py +++ b/causalpy/experiments/prepostnegd.py @@ -200,7 +200,7 @@ def summary(self, round_to=None) -> None: print(self._causal_impact_summary_stat(round_to)) self.print_coefficients(round_to) - def bayesian_plot( + def _bayesian_plot( self, round_to=None, **kwargs ) -> tuple[plt.Figure, List[plt.Axes]]: """Generate plot for ANOVA-like experiments with non-equivalent group designs.""" diff --git a/causalpy/experiments/regression_discontinuity.py b/causalpy/experiments/regression_discontinuity.py index 1afc5c1a..da4f98aa 100644 --- a/causalpy/experiments/regression_discontinuity.py +++ b/causalpy/experiments/regression_discontinuity.py @@ -218,7 +218,7 @@ def summary(self, round_to=None) -> None: print("\n") self.print_coefficients(round_to) - def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: + def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: """Generate plot for regression discontinuity designs.""" fig, ax = plt.subplots() # Plot raw data @@ -267,7 +267,7 @@ def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: ) return (fig, ax) - def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: + def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: """Generate plot for regression discontinuity designs.""" fig, ax = plt.subplots() # Plot raw data diff --git a/causalpy/experiments/regression_kink.py b/causalpy/experiments/regression_kink.py index 2eb35079..95ad3fcc 100644 --- a/causalpy/experiments/regression_kink.py +++ b/causalpy/experiments/regression_kink.py @@ -189,7 +189,7 @@ def summary(self, round_to=None) -> None: ) self.print_coefficients(round_to) - def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: + def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: """Generate plot for regression kink designs.""" fig, ax = plt.subplots() # Plot raw data diff --git a/causalpy/plot_utils.py b/causalpy/plot_utils.py index 6d41df4d..5ad596ce 100644 --- a/causalpy/plot_utils.py +++ b/causalpy/plot_utils.py @@ -79,3 +79,24 @@ def plot_xY( filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children()) )[-1] return (h_line, h_patch) + + +def get_hdi_to_df( + x: xr.DataArray, + hdi_prob: float = 0.94, +) -> pd.DataFrame: + """ + Utility function to calculate and recover HDI intervals. 
+ + :param x: + Xarray data array + :param hdi_prob: + The size of the HDI, default is 0.94 + """ + hdi = ( + az.hdi(x, hdi_prob=hdi_prob) + .to_dataframe() + .unstack(level="hdi") + .droplevel(0, axis=1) + ) + return hdi diff --git a/causalpy/tests/test_integration_pymc_examples.py b/causalpy/tests/test_integration_pymc_examples.py index 016ad567..87599bb7 100644 --- a/causalpy/tests/test_integration_pymc_examples.py +++ b/causalpy/tests/test_integration_pymc_examples.py @@ -49,6 +49,8 @@ def test_did(): fig, ax = result.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) + with pytest.raises(NotImplementedError): + result.get_plot_data() # TODO: set up fixture for the banks dataset @@ -192,6 +194,8 @@ def test_rd(): fig, ax = result.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) + with pytest.raises(NotImplementedError): + result.get_plot_data() @pytest.mark.integration @@ -311,6 +315,8 @@ def test_rkink(): fig, ax = result.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) + with pytest.raises(NotImplementedError): + result.get_plot_data() @pytest.mark.integration @@ -353,6 +359,7 @@ def test_its(): 2. causalpy.InterruptedTimeSeries returns correct type 3. the correct number of MCMC chains exists in the posterior inference data 4. the correct number of MCMC draws exists in the posterior inference data + 5. the method get_plot_data returns a DataFrame with expected columns """ df = ( cp.load_data("its") @@ -377,6 +384,22 @@ def test_its(): assert isinstance(ax, np.ndarray) and all( isinstance(item, plt.Axes) for item in ax ), "ax must be a numpy.ndarray of plt.Axes" + # Test get_plot_data with default parameters + plot_data = result.get_plot_data() + assert isinstance(plot_data, pd.DataFrame), ( + "The returned object is not a pandas DataFrame" + ) + expected_columns = [ + "prediction", + "pred_hdi_lower_94", + "pred_hdi_upper_94", + "impact", + "impact_hdi_lower_94", + "impact_hdi_upper_94", + ] + assert set(expected_columns).issubset(set(plot_data.columns)), ( + f"DataFrame is missing expected columns {expected_columns}" + ) @pytest.mark.integration @@ -389,6 +412,7 @@ def test_its_covid(): 2. causalpy.InterruptedtimeSeries returns correct type 3. the correct number of MCMC chains exists in the posterior inference data 4. the correct number of MCMC draws exists in the posterior inference data + 5. the method get_plot_data returns a DataFrame with expected columns """ df = ( @@ -414,6 +438,22 @@ def test_its_covid(): assert isinstance(ax, np.ndarray) and all( isinstance(item, plt.Axes) for item in ax ), "ax must be a numpy.ndarray of plt.Axes" + # Test get_plot_data with default parameters + plot_data = result.get_plot_data() + assert isinstance(plot_data, pd.DataFrame), ( + "The returned object is not a pandas DataFrame" + ) + expected_columns = [ + "prediction", + "pred_hdi_lower_94", + "pred_hdi_upper_94", + "impact", + "impact_hdi_lower_94", + "impact_hdi_upper_94", + ] + assert set(expected_columns).issubset(set(plot_data.columns)), ( + f"DataFrame is missing expected columns {expected_columns}" + ) @pytest.mark.integration @@ -426,6 +466,7 @@ def test_sc(): 2. causalpy.SyntheticControl returns correct type 3. the correct number of MCMC chains exists in the posterior inference data 4. the correct number of MCMC draws exists in the posterior inference data + 5. 
the method get_plot_data returns a DataFrame with expected columns """ df = cp.load_data("sc") @@ -455,6 +496,22 @@ def test_sc(): assert isinstance(ax, np.ndarray) and all( isinstance(item, plt.Axes) for item in ax ), "ax must be a numpy.ndarray of plt.Axes" + # Test get_plot_data with default parameters + plot_data = result.get_plot_data() + assert isinstance(plot_data, pd.DataFrame), ( + "The returned object is not a pandas DataFrame" + ) + expected_columns = [ + "prediction", + "pred_hdi_lower_94", + "pred_hdi_upper_94", + "impact", + "impact_hdi_lower_94", + "impact_hdi_upper_94", + ] + assert set(expected_columns).issubset(set(plot_data.columns)), ( + f"DataFrame is missing expected columns {expected_columns}" + ) @pytest.mark.integration @@ -467,6 +524,7 @@ def test_sc_brexit(): 2. causalpy.SyntheticControl returns correct type 3. the correct number of MCMC chains exists in the posterior inference data 4. the correct number of MCMC draws exists in the posterior inference data + 5. the method get_plot_data returns a DataFrame with expected columns """ df = ( @@ -501,6 +559,22 @@ def test_sc_brexit(): assert isinstance(ax, np.ndarray) and all( isinstance(item, plt.Axes) for item in ax ), "ax must be a numpy.ndarray of plt.Axes" + # Test get_plot_data with default parameters + plot_data = result.get_plot_data() + assert isinstance(plot_data, pd.DataFrame), ( + "The returned object is not a pandas DataFrame" + ) + expected_columns = [ + "prediction", + "pred_hdi_lower_94", + "pred_hdi_upper_94", + "impact", + "impact_hdi_lower_94", + "impact_hdi_upper_94", + ] + assert set(expected_columns).issubset(set(plot_data.columns)), ( + f"DataFrame is missing expected columns {expected_columns}" + ) @pytest.mark.integration @@ -596,6 +670,8 @@ def test_iv_reg(): assert isinstance(result, cp.InstrumentalVariable) assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"] assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"] + with pytest.raises(NotImplementedError): + result.get_plot_data() @pytest.mark.integration @@ -646,6 +722,8 @@ def test_inverse_prop(): assert isinstance(fig, plt.Figure) assert isinstance(axs, list) assert all(isinstance(ax, plt.Axes) for ax in axs) + with pytest.raises(NotImplementedError): + result.get_plot_data() # DEPRECATION WARNING TESTS ============================================================ diff --git a/causalpy/tests/test_integration_skl_examples.py b/causalpy/tests/test_integration_skl_examples.py index 944274ed..775ff4c6 100644 --- a/causalpy/tests/test_integration_skl_examples.py +++ b/causalpy/tests/test_integration_skl_examples.py @@ -47,6 +47,8 @@ def test_did(): fig, ax = result.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) + with pytest.raises(NotImplementedError): + result.get_plot_data() @pytest.mark.integration @@ -78,6 +80,8 @@ def test_rd_drinking(): fig, ax = result.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) + with pytest.raises(NotImplementedError): + result.get_plot_data() @pytest.mark.integration @@ -88,6 +92,7 @@ def test_its(): Loads data and checks: 1. data is a dataframe 2. skl_experiements.InterruptedTimeSeries returns correct type + 3. 
the method get_plot_data returns a DataFrame with expected columns """ df = ( @@ -111,6 +116,15 @@ def test_its(): assert isinstance(ax, np.ndarray) and all( isinstance(item, plt.Axes) for item in ax ), "ax must be a numpy.ndarray of plt.Axes" + # Test get_plot_data with default parameters + plot_data = result.get_plot_data() + assert isinstance(plot_data, pd.DataFrame), ( + "The returned object is not a pandas DataFrame" + ) + expected_columns = ["prediction", "impact"] + assert set(expected_columns).issubset(set(plot_data.columns)), ( + f"DataFrame is missing expected columns {expected_columns}" + ) @pytest.mark.integration @@ -121,6 +135,7 @@ def test_sc(): Loads data and checks: 1. data is a dataframe 2. skl_experiements.SyntheticControl returns correct type + 3. the method get_plot_data returns a DataFrame with expected columns """ df = cp.load_data("sc") treatment_time = 70 @@ -147,6 +162,15 @@ def test_sc(): assert isinstance(ax, np.ndarray) and all( isinstance(item, plt.Axes) for item in ax ), "ax must be a numpy.ndarray of plt.Axes" + # Test get_plot_data with default parameters + plot_data = result.get_plot_data() + assert isinstance(plot_data, pd.DataFrame), ( + "The returned object is not a pandas DataFrame" + ) + expected_columns = ["prediction", "impact"] + assert set(expected_columns).issubset(set(plot_data.columns)), ( + f"DataFrame is missing expected columns {expected_columns}" + ) @pytest.mark.integration diff --git a/docs/source/_static/classes.png b/docs/source/_static/classes.png index 8eaf78ce..012109d8 100644 Binary files a/docs/source/_static/classes.png and b/docs/source/_static/classes.png differ diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 2d6395ba..c698e26b 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 90.1% + interrogate: 90.4% @@ -12,8 +12,8 @@ interrogate interrogate - 90.1% - 90.1% + 90.4% + 90.4% diff --git a/docs/source/_static/packages.png b/docs/source/_static/packages.png index 0805c70d..a285c3b3 100644 Binary files a/docs/source/_static/packages.png and b/docs/source/_static/packages.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index 70b683aa..0ce796ab 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -109,6 +109,7 @@ # -- intersphinx config ------------------------------------------------------- intersphinx_mapping = { + "arviz": ("https://python.arviz.org/en/stable/", None), "examples": ("https://www.pymc.io/projects/examples/en/latest/", None), "mpl": ("https://matplotlib.org/stable", None), "numpy": ("https://numpy.org/doc/stable/", None),
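
A minimal usage sketch of the `get_plot_data` API introduced in this diff, assuming the bundled `its` example dataset used by the integration tests above; the formula, treatment date, and sampler settings are illustrative assumptions taken from the existing interrupted-time-series examples, not requirements of the new API.

import causalpy as cp
import pandas as pd

# Load the bundled interrupted-time-series example data (the same dataset the
# test_its integration test uses). Formula and treatment date are assumptions
# mirroring the existing docs/tests, shown here only for illustration.
df = (
    cp.load_data("its")
    .assign(date=lambda x: pd.to_datetime(x["date"]))
    .set_index("date")
)
treatment_time = pd.to_datetime("2017-01-01")

result = cp.InterruptedTimeSeries(
    df,
    treatment_time,
    formula="y ~ 1 + t + C(month)",
    model=cp.pymc_models.LinearRegression(
        sample_kwargs={"chains": 2, "draws": 500, "random_seed": 42}
    ),
)

# Dispatches to get_plot_data_bayesian() because the model is a PyMCModel; with
# an sklearn RegressorMixin model it would dispatch to get_plot_data_ols(),
# whose output contains only the "prediction" and "impact" columns.
plot_data = result.get_plot_data(hdi_prob=0.94)
print(
    plot_data[
        [
            "prediction",
            "pred_hdi_lower_94",
            "pred_hdi_upper_94",
            "impact",
            "impact_hdi_lower_94",
            "impact_hdi_upper_94",
        ]
    ].head()
)

The `pred_hdi_*` and `impact_hdi_*` columns come from the new `get_hdi_to_df` helper in `plot_utils`, which unstacks the lower/higher bounds returned by `arviz.hdi` into two DataFrame columns; passing a different `hdi_prob` (e.g. 0.9) changes the column suffix accordingly (`pred_hdi_lower_90`, and so on).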