|
"""Explainers.counterfactual module"""
| 2 | +# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long, |
| 3 | +# pylint: disable = unused-argument |
| 4 | +from typing import Optional, List, Union |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import matplotlib as mpl |
| 7 | +import pandas as pd |
| 8 | +import numpy as np |
| 9 | +import uuid as _uuid |
| 10 | + |
| 11 | +from trustyai import _default_initializer # pylint: disable=unused-import |
| 12 | +from trustyai.utils._visualisation import ( |
| 13 | + ExplanationVisualiser, |
| 14 | + DEFAULT_STYLE as ds, |
| 15 | + DEFAULT_RC_PARAMS as drcp, |
| 16 | +) |
| 17 | + |
| 18 | +from trustyai.model import ( |
| 19 | + counterfactual_prediction, |
| 20 | + Dataset, |
| 21 | + PredictionInput, |
| 22 | +) |
| 23 | + |
| 24 | +from org.kie.trustyai.explainability.local.counterfactual import ( |
| 25 | + CounterfactualExplainer as _CounterfactualExplainer, |
| 26 | + CounterfactualResult as _CounterfactualResult, |
| 27 | + SolverConfigBuilder as _SolverConfigBuilder, |
| 28 | + CounterfactualConfig as _CounterfactualConfig, |
| 29 | +) |
| 30 | +from org.kie.trustyai.explainability.model import ( |
| 31 | + DataDistribution, |
| 32 | + Feature, |
| 33 | + Output, |
| 34 | + PredictionOutput, |
| 35 | + PredictionProvider, |
| 36 | +) |
| 37 | +from org.optaplanner.core.config.solver.termination import TerminationConfig |
| 38 | +from java.lang import Long |
| 39 | + |
# Public re-exports: expose the underlying Java solver/counterfactual
# configuration builders under their plain names for library users.
SolverConfigBuilder = _SolverConfigBuilder
CounterfactualConfig = _CounterfactualConfig
| 42 | + |
| 43 | + |
class CounterfactualResult(ExplanationVisualiser):
    """Wraps Counterfactual results. This object is returned by the
    :class:`~CounterfactualExplainer`, and provides a variety of methods to visualize and interact
    with the results of the counterfactual explanation.
    """

    def __init__(self, result: _CounterfactualResult) -> None:
        """Constructor method. This is called internally, and shouldn't ever need to be
        used manually."""
        self._result = result

    def _proposed_prediction_input(self) -> PredictionInput:
        """Bundle the counterfactual's proposed entities into a single
        :class:`PredictionInput`; shared by the ``proposed_features_*`` properties."""
        return PredictionInput(
            [entity.as_feature() for entity in self._result.entities]
        )

    @property
    def proposed_features_array(self):
        """Return the proposed feature values found from the counterfactual explanation
        as a Numpy array.
        """
        return Dataset.prediction_object_to_numpy(
            [self._proposed_prediction_input()]
        )

    @property
    def proposed_features_dataframe(self):
        """Return the proposed feature values found from the counterfactual explanation
        as a Pandas DataFrame.
        """
        return Dataset.prediction_object_to_pandas(
            [self._proposed_prediction_input()]
        )

    def as_dataframe(self) -> pd.DataFrame:
        """
        Return the counterfactual result as a dataframe

        Returns
        -------
        pandas.DataFrame
            DataFrame containing the results of the counterfactual explanation, containing the
            following columns:

            * ``Features``: The names of each input feature.
            * ``Proposed``: The found values of the features.
            * ``Original``: The original feature values.
            * ``Constrained``: Whether this feature was constrained (held fixed) during the search.
            * ``Difference``: The difference between the proposed and original values.
        """
        entities = self._result.entities
        features = self._result.getFeatures()

        # Convert each entity to its Feature form once, instead of once per column.
        proposed = [entity.as_feature() for entity in entities]

        data = {
            "features": [f"{feature.getName()}" for feature in proposed],
            "proposed": [feature.value.as_obj() for feature in proposed],
            "original": [
                feature.getValue().getUnderlyingObject() for feature in features
            ],
            "constrained": [feature.is_constrained for feature in features],
        }

        dfr = pd.DataFrame.from_dict(data)
        # NOTE(review): assumes numeric feature values — a non-numeric proposed
        # value would make this subtraction raise. TODO confirm upstream contract.
        dfr["difference"] = dfr.proposed - dfr.original
        return dfr

    def as_html(self) -> pd.io.formats.style.Styler:
        """
        Return the counterfactual result as a Pandas Styler object.

        Returns
        -------
        pandas.Styler
            Styler containing the results of the counterfactual explanation, in the same
            schema as in :func:`as_dataframe`. Currently, no default styles are applied
            in this particular function, making it equivalent to :code:`self.as_dataframe().style`.
        """
        return self.as_dataframe().style

    def plot(self) -> None:
        """
        Plot the counterfactual result as a horizontal bar chart. Only features whose
        value actually changed are shown; proposed values are coloured by the sign of
        the change, originals are drawn in black.
        """
        _df = self.as_dataframe().copy()
        # Only plot features the counterfactual actually changed.
        _df = _df[_df["difference"] != 0.0]

        def change_colour(value):
            # Zero differences are filtered out above; the first branch is kept
            # defensively in case the filter is ever relaxed.
            if value == 0.0:
                colour = ds["neutral_primary_colour"]
            elif value > 0:
                colour = ds["positive_primary_colour"]
            else:
                colour = ds["negative_primary_colour"]
            return colour

        with mpl.rc_context(drcp):
            colour = _df["difference"].transform(change_colour)
            plot = _df[["features", "proposed", "original"]].plot.barh(
                x="features", color={"proposed": colour, "original": "black"}
            )
            plot.set_title("Counterfactual")
            plt.show()
| 140 | + |
| 141 | + |
class CounterfactualExplainer:
    """*"How do I get the result I want?"*

    The CounterfactualExplainer class seeks to answer this question by exploring "what-if"
    scenarios. Given some initial input and desired outcome, the counterfactual explainer tries to
    find a set of nearby inputs that produces the desired outcome. Mathematically, if we have a
    model :math:`f`, some input :math:`x` and a desired model output :math:`y'`, the counterfactual
    explainer finds some nearby input :math:`x'` such that :math:`f(x') = y'`.
    """

    def __init__(self, steps=10_000):
        """
        Build a new counterfactual explainer.

        Parameters
        ----------
        steps: int
            The number of search steps to perform during the counterfactual search.
        """
        # Bound the OptaPlanner search by its score-calculation count.
        termination = TerminationConfig().withScoreCalculationCountLimit(
            Long.valueOf(steps)
        )
        solver_config = (
            SolverConfigBuilder.builder().withTerminationConfig(termination).build()
        )
        self._termination_config = termination
        self._solver_config = solver_config
        self._cf_config = CounterfactualConfig().withSolverConfig(solver_config)
        self._explainer = _CounterfactualExplainer(self._cf_config)

    # pylint: disable=too-many-arguments
    def explain(
        self,
        inputs: Union[np.ndarray, pd.DataFrame, List[Feature], PredictionInput],
        goal: Union[np.ndarray, pd.DataFrame, List[Output], PredictionOutput],
        model: PredictionProvider,
        data_distribution: Optional[DataDistribution] = None,
        uuid: Optional[_uuid.UUID] = None,
        timeout: Optional[float] = None,
    ) -> CounterfactualResult:
        """Request for a counterfactual explanation given a list of features, goals and a
        :class:`~PredictionProvider`

        Parameters
        ----------
        inputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Feature`], or :class:`PredictionInput`
            List of input features, as a:

            * Numpy array of shape ``[1, n_features]``
            * Pandas DataFrame with 1 row and ``n_features`` columns
            * A List of TrustyAI :class:`Feature`, as created by the :func:`~feature` function
            * A TrustyAI :class:`PredictionInput`

        goal : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Output`], or :class:`PredictionOutput`
            The desired model outputs to be searched for in the counterfactual explanation.
            These can take the form of a:

            * Numpy array of shape ``[1, n_outputs]``
            * Pandas DataFrame with 1 row and ``n_outputs`` columns
            * A List of TrustyAI :class:`Output`, as created by the :func:`~output` function
            * A TrustyAI :class:`PredictionOutput`

        model : :obj:`~trustyai.model.PredictionProvider`
            The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model` or
            :class:`~trustyai.model.ArrowModel`.

        data_distribution : Optional[:class:`DataDistribution`]
            The :class:`DataDistribution` to use when sampling the inputs.
        uuid : Optional[:class:`_uuid.UUID`]
            The UUID to use during search.
        timeout : Optional[float]
            The timeout time in seconds of the counterfactual explanation.

        Returns
        -------
        :class:`~CounterfactualResult`
            Object containing the results of the counterfactual explanation.
        """
        # Bundle the inputs and goal into a single counterfactual prediction object.
        prediction = counterfactual_prediction(
            input_features=inputs,
            outputs=goal,
            data_distribution=data_distribution,
            uuid=uuid,
            timeout=timeout,
        )
        # Run the (blocking) search and wrap the raw Java result.
        raw_result = self._explainer.explainAsync(prediction, model).get()
        return CounterfactualResult(raw_result)
0 commit comments