Skip to content

FAI-893: Make test plots non-blocking by default #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ build-backend = "setuptools.build_meta"
[tool.setuptools]
package-dir = { "" = "src" }

[tool.pytest.ini_options]
addopts = '-m="not block_plots"'
markers = [
"block_plots: Test plots will block execution of subsequent tests until closed"
]

[tool.setuptools.packages.find]
where = ["src"]

Expand Down
4 changes: 2 additions & 2 deletions src/trustyai/explainers/counterfactuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def as_html(self) -> pd.io.formats.style.Styler:
"""
return self.as_dataframe().style

def plot(self) -> None:
def plot(self, block=True) -> None:
"""
Plot the counterfactual result.
"""
Expand All @@ -139,7 +139,7 @@ def change_colour(value):
x="features", color={"proposed": colour, "original": "black"}
)
plot.set_title("Counterfactual")
plt.show()
plt.show(block=block)


class CounterfactualExplainer:
Expand Down
10 changes: 6 additions & 4 deletions src/trustyai/explainers/explanation_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def saliency_map(self):
"""Return the Saliencies as a dictionary, keyed by output name"""

@abstractmethod
def _matplotlib_plot(self, output_name: str) -> None:
def _matplotlib_plot(self, output_name: str, block: bool) -> None:
"""Plot the saliencies of a particular output in matplotlib"""

@abstractmethod
Expand All @@ -44,7 +44,7 @@ def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
for output_name in self.saliency_map().keys()
}

def plot(self, output_name=None, render_bokeh=False) -> None:
def plot(self, output_name=None, render_bokeh=False, block=True) -> None:
"""
Plot the found feature saliencies.

Expand All @@ -55,15 +55,17 @@ def plot(self, output_name=None, render_bokeh=False) -> None:
be displayed
render_bokeh : bool
(default: false) Whether to render as bokeh (true) or matplotlib (false)
block: bool
(default: true) Whether displaying the plot blocks subsequent code execution
"""
if output_name is None:
for output_name_iterator in self.saliency_map().keys():
if render_bokeh:
show(self._get_bokeh_plot(output_name_iterator))
else:
self._matplotlib_plot(output_name_iterator)
self._matplotlib_plot(output_name_iterator, block)
else:
if render_bokeh:
show(self._get_bokeh_plot(output_name))
else:
self._matplotlib_plot(output_name)
self._matplotlib_plot(output_name, block)
4 changes: 2 additions & 2 deletions src/trustyai/explainers/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def as_html(self) -> pd.io.formats.style.Styler:
"""
return self.as_dataframe().style

def _matplotlib_plot(self, output_name: str) -> None:
def _matplotlib_plot(self, output_name: str, block=True) -> None:
"""Plot the LIME saliencies."""
with mpl.rc_context(drcp):
dictionary = {}
Expand All @@ -139,7 +139,7 @@ def _matplotlib_plot(self, output_name: str) -> None:
)
plt.yticks(range(len(dictionary)), list(dictionary.keys()))
plt.tight_layout()
plt.show()
plt.show(block=block)

def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot:
lime_data_source = pd.DataFrame(
Expand Down
4 changes: 2 additions & 2 deletions src/trustyai/explainers/shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _color_feature_values(feature_values, background_vals):
)
return df_dict

def _matplotlib_plot(self, output_name) -> None:
def _matplotlib_plot(self, output_name, block=True) -> None:
"""Visualize the SHAP explanation of each output as a set of candlestick plots,
one per output."""
with mpl.rc_context(drcp):
Expand Down Expand Up @@ -272,7 +272,7 @@ def _matplotlib_plot(self, output_name) -> None:
plt.ylabel(self.saliency_map()[output_name].getOutput().getName())
plt.xlabel("Feature SHAP Value")
plt.title(f"Explanation of {output_name}")
plt.show()
plt.show(block=block)

def _get_bokeh_plot(self, output_name):
fnull = self.get_fnull()[output_name]
Expand Down
14 changes: 12 additions & 2 deletions tests/general/test_counterfactualexplainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=import-error, wrong-import-position, wrong-import-order, R0801
"""Test suite for counterfactual explanations"""
import pytest

from common import *

Expand Down Expand Up @@ -92,7 +93,7 @@ def test_counterfactual_match_python_model():
rel=3)


def test_counterfactual_plot():
def counterfactual_plot(block):
"""Test if there's a valid counterfactual with a Python model"""
GOAL_VALUE = 1000
goal = np.array([[GOAL_VALUE]])
Expand All @@ -110,7 +111,16 @@ def test_counterfactual_plot():
goal=goal,
model=model)

result.plot()
result.plot(block=block)


@pytest.mark.block_plots
def test_counterfactual_plot_blocking():
counterfactual_plot(True)


def test_counterfactual_plot():
counterfactual_plot(False)


def test_counterfactual_v2():
Expand Down
19 changes: 14 additions & 5 deletions tests/general/test_limeexplainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_normalized_weights():
assert -3.0 < feature_importance.getScore() < 3.0


def test_lime_plots():
def lime_plots(block):
"""Test normalized weights"""
lime_explainer = LimeExplainer(normalise_weights=False, perturbations=2, samples=10)
n_features = 15
Expand All @@ -100,10 +100,19 @@ def test_lime_plots():
outputs = model.predict([features])[0].outputs

explanation = lime_explainer.explain(inputs=features, outputs=outputs, model=model)
explanation.plot()
explanation.plot(render_bokeh=True)
explanation.plot(output_name="sum-but0")
explanation.plot(output_name="sum-but0", render_bokeh=True)
explanation.plot(block=block)
explanation.plot(block=block, render_bokeh=True)
explanation.plot(block=block, output_name="sum-but0")
explanation.plot(block=block, output_name="sum-but0", render_bokeh=True)


@pytest.mark.block_plots
def test_lime_plots_blocking():
lime_plots(True)


def test_lime_plots():
lime_plots(False)


def test_lime_v2():
Expand Down
20 changes: 14 additions & 6 deletions tests/general/test_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
np.random.seed(0)

import pytest

from trustyai.explainers import SHAPExplainer
from trustyai.model import feature, Model
from trustyai.utils.data_conversions import numpy_to_prediction_object
Expand Down Expand Up @@ -51,7 +50,7 @@ def test_shap_arrow():
assert answers[i] - 1e-2 <= feature_importance.getScore() <= answers[i] + 1e-2


def test_shap_plots():
def shap_plots(block):
"""Test SHAP plots"""
np.random.seed(0)
data = pd.DataFrame(np.random.rand(101, 5))
Expand All @@ -64,10 +63,19 @@ def test_shap_plots():
shap_explainer = SHAPExplainer(background=background)
explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model)

explanation.plot()
explanation.plot(render_bokeh=True)
explanation.plot(output_name='output-0')
explanation.plot(output_name='output-0', render_bokeh=True)
explanation.plot(block=block)
explanation.plot(block=block, render_bokeh=True)
explanation.plot(block=block, output_name='output-0')
explanation.plot(block=block, output_name='output-0', render_bokeh=True)


@pytest.mark.block_plots
def test_shap_plots_blocking():
shap_plots(block=True)


def test_shap_plots():
shap_plots(block=False)


def test_shap_as_df():
Expand Down