Skip to content

Commit 56125c7

Browse files
authored
FAI-880: Move explainers into separate files (#110)
* Move explainers into separate files * linting * disabled dup-code warning * added dup-code warnings to lime and shap * fixed duplicate code warnings * remembered to include model in commit * fixed import of predunion type accidentally being from java, not trustyai.model * updated linting for new version
1 parent 4016b75 commit 56125c7

File tree

6 files changed

+529
-474
lines changed

6 files changed

+529
-474
lines changed

src/trustyai/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# pylint: disable = import-error, import-outside-toplevel, dangerous-default-value
2-
# pylint: disable = invalid-name, R0801
2+
# pylint: disable = invalid-name, R0801, duplicate-code
33
"""Main TrustyAI Python bindings"""
44
import os
55
import logging

src/trustyai/explainers/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Explainers module"""
2+
# pylint: disable=duplicate-code
3+
from .counterfactuals import CounterfactualResult, CounterfactualExplainer
4+
from .lime import LimeExplainer, LimeResults
5+
from .shap import SHAPExplainer, SHAPResults
+229
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
"""Explainers.countefactual module"""
2+
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
3+
# pylint: disable = unused-argument
4+
from typing import Optional, List, Union
5+
import matplotlib.pyplot as plt
6+
import matplotlib as mpl
7+
import pandas as pd
8+
import numpy as np
9+
import uuid as _uuid
10+
11+
from trustyai import _default_initializer # pylint: disable=unused-import
12+
from trustyai.utils._visualisation import (
13+
ExplanationVisualiser,
14+
DEFAULT_STYLE as ds,
15+
DEFAULT_RC_PARAMS as drcp,
16+
)
17+
18+
from trustyai.model import (
19+
counterfactual_prediction,
20+
Dataset,
21+
PredictionInput,
22+
)
23+
24+
from org.kie.trustyai.explainability.local.counterfactual import (
25+
CounterfactualExplainer as _CounterfactualExplainer,
26+
CounterfactualResult as _CounterfactualResult,
27+
SolverConfigBuilder as _SolverConfigBuilder,
28+
CounterfactualConfig as _CounterfactualConfig,
29+
)
30+
from org.kie.trustyai.explainability.model import (
31+
DataDistribution,
32+
Feature,
33+
Output,
34+
PredictionOutput,
35+
PredictionProvider,
36+
)
37+
from org.optaplanner.core.config.solver.termination import TerminationConfig
38+
from java.lang import Long
39+
40+
SolverConfigBuilder = _SolverConfigBuilder
41+
CounterfactualConfig = _CounterfactualConfig
42+
43+
44+
class CounterfactualResult(ExplanationVisualiser):
    """Wraps the output of a counterfactual search.

    Instances are produced by :class:`~CounterfactualExplainer`; they expose the
    found counterfactual as NumPy/Pandas objects and provide tabular and
    graphical views of the result.
    """

    def __init__(self, result: _CounterfactualResult) -> None:
        """Wrap a Java ``CounterfactualResult``. Called internally by the
        explainer; not intended for direct use."""
        self._result = result

    @property
    def proposed_features_array(self):
        """The counterfactual's proposed feature values, as a NumPy array."""
        proposed = [entity.as_feature() for entity in self._result.entities]
        return Dataset.prediction_object_to_numpy([PredictionInput(proposed)])

    @property
    def proposed_features_dataframe(self):
        """The counterfactual's proposed feature values, as a Pandas DataFrame."""
        proposed = [entity.as_feature() for entity in self._result.entities]
        return Dataset.prediction_object_to_pandas([PredictionInput(proposed)])

    def as_dataframe(self) -> pd.DataFrame:
        """
        Return the counterfactual result as a dataframe

        Returns
        -------
        pandas.DataFrame
            DataFrame containing the results of the counterfactual explanation, containing the
            following columns:

            * ``Features``: The names of each input feature.
            * ``Proposed``: The found values of the features.
            * ``Original``: The original feature values.
            * ``Constrained``: Whether this feature was constrained (held fixed) during the search.
            * ``Difference``: The difference between the proposed and original values.
        """
        cf_entities = self._result.entities
        search_features = self._result.getFeatures()

        frame = pd.DataFrame.from_dict(
            {
                "features": [f"{e.as_feature().getName()}" for e in cf_entities],
                "proposed": [e.as_feature().value.as_obj() for e in cf_entities],
                "original": [
                    f.getValue().getUnderlyingObject() for f in search_features
                ],
                "constrained": [f.is_constrained for f in search_features],
            }
        )
        frame["difference"] = frame.proposed - frame.original
        return frame

    def as_html(self) -> pd.io.formats.style.Styler:
        """
        Return the counterfactual result as a Pandas Styler object.

        Returns
        -------
        pandas.Styler
            Styler containing the results of the counterfactual explanation, in the same
            schema as in :func:`as_dataframe`. Currently, no default styles are applied
            in this particular function, making it equivalent to :code:`self.as_dataframe().style`.
        """
        return self.as_dataframe().style

    def plot(self) -> None:
        """
        Plot the counterfactual result.
        """
        changed = self.as_dataframe().copy()
        # only plot features that the counterfactual actually altered
        changed = changed[changed["difference"] != 0.0]

        def _diff_colour(delta):
            # colour bars by the sign of the proposed-minus-original delta
            if delta == 0.0:
                return ds["neutral_primary_colour"]
            return (
                ds["positive_primary_colour"]
                if delta > 0
                else ds["negative_primary_colour"]
            )

        with mpl.rc_context(drcp):
            bar_colours = changed["difference"].transform(_diff_colour)
            axis = changed[["features", "proposed", "original"]].plot.barh(
                x="features", color={"proposed": bar_colours, "original": "black"}
            )
            axis.set_title("Counterfactual")
            plt.show()
141+
142+
class CounterfactualExplainer:
    """*"How do I get the result I want?"*

    The CounterfactualExplainer class seeks to answer this question by exploring "what-if"
    scenarios. Given some initial input and desired outcome, the counterfactual explainer tries to
    find a set of nearby inputs that produces the desired outcome. Mathematically, if we have a
    model :math:`f`, some input :math:`x` and a desired model output :math:`y'`, the counterfactual
    explainer finds some nearby input :math:`x'` such that :math:`f(x') = y'`.
    """

    def __init__(self, steps=10_000):
        """
        Build a new counterfactual explainer.

        Parameters
        ----------
        steps: int
            The number of search steps to perform during the counterfactual search.
        """
        # the search terminates after `steps` score evaluations
        self._termination_config = TerminationConfig().withScoreCalculationCountLimit(
            Long.valueOf(steps)
        )
        builder = SolverConfigBuilder.builder()
        self._solver_config = builder.withTerminationConfig(
            self._termination_config
        ).build()
        self._cf_config = CounterfactualConfig().withSolverConfig(self._solver_config)
        self._explainer = _CounterfactualExplainer(self._cf_config)

    # pylint: disable=too-many-arguments
    def explain(
        self,
        inputs: Union[np.ndarray, pd.DataFrame, List[Feature], PredictionInput],
        goal: Union[np.ndarray, pd.DataFrame, List[Output], PredictionOutput],
        model: PredictionProvider,
        data_distribution: Optional[DataDistribution] = None,
        uuid: Optional[_uuid.UUID] = None,
        timeout: Optional[float] = None,
    ) -> CounterfactualResult:
        """Request for a counterfactual explanation given a list of features, goals and a
        :class:`~PredictionProvider`

        Parameters
        ----------
        inputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Feature`], or :class:`PredictionInput`
            List of input features, as a:

            * Numpy array of shape ``[1, n_features]``
            * Pandas DataFrame with 1 row and ``n_features`` columns
            * A List of TrustyAI :class:`Feature`, as created by the :func:`~feature` function
            * A TrustyAI :class:`PredictionInput`

        goal : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Output`], or :class:`PredictionOutput`
            The desired model outputs to be searched for in the counterfactual explanation.
            These can take the form of a:

            * Numpy array of shape ``[1, n_outputs]``
            * Pandas DataFrame with 1 row and ``n_outputs`` columns
            * A List of TrustyAI :class:`Output`, as created by the :func:`~output` function
            * A TrustyAI :class:`PredictionOutput`

        model : :obj:`~trustyai.model.PredictionProvider`
            The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model` or
            :class:`~trustyai.model.ArrowModel`.

        data_distribution : Optional[:class:`DataDistribution`]
            The :class:`DataDistribution` to use when sampling the inputs.
        uuid : Optional[:class:`_uuid.UUID`]
            The UUID to use during search.
        timeout : Optional[float]
            The timeout time in seconds of the counterfactual explanation.
        Returns
        -------
        :class:`~CounterfactualResult`
            Object containing the results of the counterfactual explanation.
        """
        prediction = counterfactual_prediction(
            input_features=inputs,
            outputs=goal,
            data_distribution=data_distribution,
            uuid=uuid,
            timeout=timeout,
        )
        # explainAsync returns a Java CompletableFuture; .get() blocks until done
        raw_result = self._explainer.explainAsync(prediction, model).get()
        return CounterfactualResult(raw_result)

0 commit comments

Comments
 (0)