diff --git a/pyproject.toml b/pyproject.toml
index 772ea28..1becf9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,6 +58,12 @@ build-backend = "setuptools.build_meta"
 
 [tool.setuptools]
 package-dir = { "" = "src" }
 
+[tool.pytest.ini_options]
+addopts = '-m="not block_plots"'
+markers = [
+    "block_plots: Test plots will block execution of subsequent tests until closed"
+]
+
 [tool.setuptools.packages.find]
 where = ["src"]
diff --git a/src/trustyai/explainers/counterfactuals.py b/src/trustyai/explainers/counterfactuals.py
index bf2e071..424bf8e 100644
--- a/src/trustyai/explainers/counterfactuals.py
+++ b/src/trustyai/explainers/counterfactuals.py
@@ -117,7 +117,7 @@ def as_html(self) -> pd.io.formats.style.Styler:
         """
         return self.as_dataframe().style
 
-    def plot(self) -> None:
+    def plot(self, block=True) -> None:
         """
         Plot the counterfactual result.
         """
@@ -139,7 +139,7 @@ def change_colour(value):
             x="features", color={"proposed": colour, "original": "black"}
         )
         plot.set_title("Counterfactual")
-        plt.show()
+        plt.show(block=block)
 
 
 class CounterfactualExplainer:
diff --git a/src/trustyai/explainers/explanation_results.py b/src/trustyai/explainers/explanation_results.py
index 9ece460..e21451c 100644
--- a/src/trustyai/explainers/explanation_results.py
+++ b/src/trustyai/explainers/explanation_results.py
@@ -29,7 +29,7 @@ def saliency_map(self):
         """Return the Saliencies as a dictionary, keyed by output name"""
 
     @abstractmethod
-    def _matplotlib_plot(self, output_name: str) -> None:
+    def _matplotlib_plot(self, output_name: str, block: bool) -> None:
         """Plot the saliencies of a particular output in matplotlib"""
 
     @abstractmethod
@@ -44,7 +44,7 @@ def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
             for output_name in self.saliency_map().keys()
         }
 
-    def plot(self, output_name=None, render_bokeh=False) -> None:
+    def plot(self, output_name=None, render_bokeh=False, block=True) -> None:
         """
         Plot the found feature saliencies.
 
@@ -55,15 +55,17 @@ def plot(self, output_name=None, render_bokeh=False) -> None:
             be displayed
         render_bokeh : bool
            (default: false) Whether to render as bokeh (true) or matplotlib (false)
+        block: bool
+            (default: true) Whether displaying the plot blocks subsequent code execution
         """
         if output_name is None:
             for output_name_iterator in self.saliency_map().keys():
                 if render_bokeh:
                     show(self._get_bokeh_plot(output_name_iterator))
                 else:
-                    self._matplotlib_plot(output_name_iterator)
+                    self._matplotlib_plot(output_name_iterator, block)
         else:
             if render_bokeh:
                 show(self._get_bokeh_plot(output_name))
             else:
-                self._matplotlib_plot(output_name)
+                self._matplotlib_plot(output_name, block)
diff --git a/src/trustyai/explainers/lime.py b/src/trustyai/explainers/lime.py
index 17b810d..1e24e48 100644
--- a/src/trustyai/explainers/lime.py
+++ b/src/trustyai/explainers/lime.py
@@ -113,7 +113,7 @@ def as_html(self) -> pd.io.formats.style.Styler:
         """
         return self.as_dataframe().style
 
-    def _matplotlib_plot(self, output_name: str) -> None:
+    def _matplotlib_plot(self, output_name: str, block=True) -> None:
         """Plot the LIME saliencies."""
         with mpl.rc_context(drcp):
             dictionary = {}
@@ -139,7 +139,7 @@ def _matplotlib_plot(self, output_name: str) -> None:
             )
             plt.yticks(range(len(dictionary)), list(dictionary.keys()))
             plt.tight_layout()
-            plt.show()
+            plt.show(block=block)
 
     def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot:
         lime_data_source = pd.DataFrame(
diff --git a/src/trustyai/explainers/shap.py b/src/trustyai/explainers/shap.py
index 5c164e1..1372ad4 100644
--- a/src/trustyai/explainers/shap.py
+++ b/src/trustyai/explainers/shap.py
@@ -218,7 +218,7 @@ def _color_feature_values(feature_values, background_vals):
         )
         return df_dict
 
-    def _matplotlib_plot(self, output_name) -> None:
+    def _matplotlib_plot(self, output_name, block=True) -> None:
         """Visualize the SHAP explanation of each output as a set of candlestick plots,
         one per output."""
         with mpl.rc_context(drcp):
@@ -272,7 +272,7 @@ def _matplotlib_plot(self, output_name) -> None:
             plt.ylabel(self.saliency_map()[output_name].getOutput().getName())
             plt.xlabel("Feature SHAP Value")
             plt.title(f"Explanation of {output_name}")
-            plt.show()
+            plt.show(block=block)
 
     def _get_bokeh_plot(self, output_name):
         fnull = self.get_fnull()[output_name]
diff --git a/tests/general/test_counterfactualexplainer.py b/tests/general/test_counterfactualexplainer.py
index 6e0fcdf..3b883e9 100644
--- a/tests/general/test_counterfactualexplainer.py
+++ b/tests/general/test_counterfactualexplainer.py
@@ -1,5 +1,6 @@
 # pylint: disable=import-error, wrong-import-position, wrong-import-order, R0801
 """Test suite for counterfactual explanations"""
+import pytest
 
 from common import *
 
@@ -92,7 +93,7 @@ def test_counterfactual_match_python_model():
                            rel=3)
 
 
-def test_counterfactual_plot():
+def counterfactual_plot(block):
     """Test if there's a valid counterfactual with a Python model"""
     GOAL_VALUE = 1000
     goal = np.array([[GOAL_VALUE]])
@@ -110,7 +111,16 @@ def test_counterfactual_plot():
         goal=goal,
         model=model)
 
-    result.plot()
+    result.plot(block=block)
+
+
+@pytest.mark.block_plots
+def test_counterfactual_plot_blocking():
+    counterfactual_plot(True)
+
+
+def test_counterfactual_plot():
+    counterfactual_plot(False)
 
 
 def test_counterfactual_v2():
diff --git a/tests/general/test_limeexplainer.py b/tests/general/test_limeexplainer.py
index b2d4e30..7cbd80f 100644
--- a/tests/general/test_limeexplainer.py
+++ b/tests/general/test_limeexplainer.py
@@ -91,7 +91,7 @@ def test_normalized_weights():
         assert -3.0 < feature_importance.getScore() < 3.0
 
 
-def test_lime_plots():
+def lime_plots(block):
     """Test normalized weights"""
     lime_explainer = LimeExplainer(normalise_weights=False, perturbations=2, samples=10)
     n_features = 15
@@ -100,10 +100,19 @@ def test_lime_plots():
     outputs = model.predict([features])[0].outputs
 
     explanation = lime_explainer.explain(inputs=features, outputs=outputs, model=model)
-    explanation.plot()
-    explanation.plot(render_bokeh=True)
-    explanation.plot(output_name="sum-but0")
-    explanation.plot(output_name="sum-but0", render_bokeh=True)
+    explanation.plot(block=block)
+    explanation.plot(block=block, render_bokeh=True)
+    explanation.plot(block=block, output_name="sum-but0")
+    explanation.plot(block=block, output_name="sum-but0", render_bokeh=True)
+
+
+@pytest.mark.block_plots
+def test_lime_plots_blocking():
+    lime_plots(True)
+
+
+def test_lime_plots():
+    lime_plots(False)
 
 
 def test_lime_v2():
diff --git a/tests/general/test_shap.py b/tests/general/test_shap.py
index 0dae7a3..96d7214 100644
--- a/tests/general/test_shap.py
+++ b/tests/general/test_shap.py
@@ -9,7 +9,6 @@
 np.random.seed(0)
 
 import pytest
-
 from trustyai.explainers import SHAPExplainer
 from trustyai.model import feature, Model
 from trustyai.utils.data_conversions import numpy_to_prediction_object
@@ -51,7 +50,7 @@ def test_shap_arrow():
         assert answers[i] - 1e-2 <= feature_importance.getScore() <= answers[i] + 1e-2
 
 
-def test_shap_plots():
+def shap_plots(block):
     """Test SHAP plots"""
     np.random.seed(0)
     data = pd.DataFrame(np.random.rand(101, 5))
@@ -64,10 +63,19 @@ def test_shap_plots():
     shap_explainer = SHAPExplainer(background=background)
     explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model)
 
-    explanation.plot()
-    explanation.plot(render_bokeh=True)
-    explanation.plot(output_name='output-0')
-    explanation.plot(output_name='output-0', render_bokeh=True)
+    explanation.plot(block=block)
+    explanation.plot(block=block, render_bokeh=True)
+    explanation.plot(block=block, output_name='output-0')
+    explanation.plot(block=block, output_name='output-0', render_bokeh=True)
+
+
+@pytest.mark.block_plots
+def test_shap_plots_blocking():
+    shap_plots(block=True)
+
+
+def test_shap_plots():
+    shap_plots(block=False)
 
 
 def test_shap_as_df():