FAI-882: Add kwargs to explainers #113

Merged · 2 commits · Nov 22, 2022
57 changes: 29 additions & 28 deletions src/trustyai/explainers/lime.py
@@ -194,7 +194,6 @@ def _get_bokeh_plot_dict(self):
return plot_dict


# pylint: disable=too-many-arguments
class LimeExplainer:
"""*"Which features were most important to the results?"*

@@ -203,47 +202,49 @@ class LimeExplainer:
feature that describe how strongly said feature contributed to the model's output.
"""

def __init__(
self,
perturbations=1,
seed=0,
samples=10,
penalise_sparse_balance=True,
track_counterfactuals=False,
normalise_weights=False,
use_wlr_model=True,
**kwargs,
):
def __init__(self, samples=10, **kwargs):
"""Initialize the :class:`LimeExplainer`.

Parameters
----------
perturbations: int
The starting number of feature perturbations within the explanation process.
seed: int
The random seed to be used.
samples: int
Number of samples to be generated for the local linear model training.
penalise_sparse_balance : bool
Whether to penalise features that are likely to produce linearly inseparable outputs.
This can improve the efficacy and interpretability of the outputted saliencies.
normalise_weights : bool
Whether to normalise the saliencies generated by LIME. If selected, saliencies will be
normalized between 0 and 1.

Keyword Arguments:
* penalise_sparse_balance : bool
(default=True) Whether to penalise features that are likely to produce linearly
inseparable outputs. This can improve the efficacy and interpretability of the
outputted saliencies.
* normalise_weights : bool
(default=False) Whether to normalise the saliencies generated by LIME. If selected,
saliencies will be normalized between 0 and 1.
* use_wlr_model : bool
(default=True) Whether to use a weighted linear regression as the LIME explanatory
model. If ``False``, a multilayer perceptron is used, which generally has a slower
runtime.
* seed: int
(default=0) The random seed to be used.
* perturbations: int
(default=1) The starting number of feature perturbations within the explanation
process.
* track_counterfactuals : bool
(default=False) Whether to keep track of byproduct counterfactuals produced during the LIME run.
"""
self._jrandom = Random()
self._jrandom.setSeed(seed)
self._jrandom.setSeed(kwargs.get("seed", 0))

self._lime_config = (
LimeConfig()
.withNormalizeWeights(normalise_weights)
.withPerturbationContext(PerturbationContext(self._jrandom, perturbations))
.withNormalizeWeights(kwargs.get("normalise_weights", False))
.withPerturbationContext(
PerturbationContext(self._jrandom, kwargs.get("perturbations", 1))
)
.withSamples(samples)
.withEncodingParams(EncodingParams(0.07, 0.3))
.withAdaptiveVariance(True)
.withPenalizeBalanceSparse(penalise_sparse_balance)
.withUseWLRLinearModel(use_wlr_model)
.withTrackCounterfactuals(track_counterfactuals)
.withPenalizeBalanceSparse(kwargs.get("penalise_sparse_balance", True))
.withUseWLRLinearModel(kwargs.get("use_wlr_model", True))
.withTrackCounterfactuals(kwargs.get("track_counterfactuals", False))
)

self._explainer = _LimeExplainer(self._lime_config)
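
For reference, a minimal usage sketch of the refactored LIME constructor. The values are illustrative rather than taken from the PR, the import path assumes LimeExplainer is re-exported from trustyai.explainers as elsewhere in the library, and any JVM initialisation the package requires is omitted.

from trustyai.explainers import LimeExplainer

# `samples` remains an explicit parameter; everything else is forwarded via
# **kwargs and falls back to the kwargs.get(...) defaults shown in the diff above.
explainer = LimeExplainer(
    samples=100,
    seed=42,                       # default 0
    perturbations=2,               # default 1
    normalise_weights=True,        # default False
    penalise_sparse_balance=True,  # default True
    use_wlr_model=True,            # default True
    track_counterfactuals=False,   # default False
)

One consequence of the kwargs.get(...) pattern is that unrecognised keywords are silently ignored rather than raising a TypeError, so a misspelled key (e.g. normalize_weights) simply falls back to its default.
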
43 changes: 21 additions & 22 deletions src/trustyai/explainers/shap.py
@@ -1,6 +1,6 @@
"""Explainers.shap module"""
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
# pylint: disable = unused-argument, consider-using-f-string, invalid-name, too-many-arguments
# pylint: disable = unused-argument, consider-using-f-string, invalid-name
from typing import Dict, Optional, List, Union
import matplotlib.pyplot as plt
import matplotlib as mpl
@@ -434,11 +434,7 @@ class SHAPExplainer:
def __init__(
self,
background: Union[np.ndarray, pd.DataFrame, List[PredictionInput]],
samples=None,
batch_size=20,
seed=0,
link_type: Optional[_ShapConfig.LinkType] = None,
track_counterfactuals=False,
**kwargs,
):
r"""Initialize the :class:`SHAPxplainer`.
@@ -449,23 +445,26 @@ def __init__(
or List[:class:`PredictionInput`]
The set of background datapoints as an array, dataframe of shape
``[n_datapoints, n_features]``, or list of TrustyAI PredictionInputs.
samples: int
The number of samples to use when computing SHAP values. Higher values will increase
explanation accuracy, at the cost of runtime.
batch_size: int
The number of batches passed to the PredictionProvider at once. When using a
:class:`~Model` in the :func:`explain` function, this parameter has no effect. With an
:class:`~ArrowModel`, `batch_sizes` of around
:math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
performance gains.
seed: int
The random seed to be used when generating explanations.
link_type : :obj:`~_ShapConfig.LinkType`
A choice of either ``trustyai.explainers._ShapConfig.LinkType.IDENTITY``
or ``trustyai.explainers._ShapConfig.LinkType.LOGIT``. If the model output is a
probability, choosing the ``LOGIT`` link will rescale explanations into log-odds units.
Otherwise, choose ``IDENTITY``.

Keyword Arguments:
* samples: int
(default=None) The number of samples to use when computing SHAP values. Higher
values will increase explanation accuracy, at the cost of runtime. If ``None``,
``samples`` defaults to ``2048 + 2 * n_features``.
* seed: int
(default=0) The random seed to be used when generating explanations.
* batch_size: int
(default=20) The number of batches passed to the PredictionProvider at once.
When using :class:`~Model` with ``arrow=False``, this parameter has no effect.
If ``arrow=True``, ``batch_size`` values of around
:math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
performance gains.
* track_counterfactuals : bool
(default=False) Whether to keep track of byproduct counterfactuals produced during the SHAP run.
Returns
-------
:class:`~SHAPResults`
@@ -474,7 +473,7 @@ def __init__(
if not link_type:
link_type = _ShapConfig.LinkType.IDENTITY
self._jrandom = Random()
self._jrandom.setSeed(seed)
self._jrandom.setSeed(kwargs.get("seed", 0))
perturbation_context = PerturbationContext(self._jrandom, 0)

if isinstance(background, np.ndarray):
@@ -491,13 +490,13 @@ def __init__(
self._configbuilder = (
_ShapConfig.builder()
.withLink(link_type)
.withBatchSize(batch_size)
.withBatchSize(kwargs.get("batch_size", 20))
.withPC(perturbation_context)
.withBackground(self.background)
.withTrackCounterfactuals(track_counterfactuals)
.withTrackCounterfactuals(kwargs.get("track_counterfactuals", False))
)
if samples is not None:
self._configbuilder.withNSamples(JInt(samples))
if kwargs.get("samples") is not None:
self._configbuilder.withNSamples(JInt(kwargs["samples"]))
self._config = self._configbuilder.build()
self._explainer = _ShapKernelExplainer(self._config)

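
Similarly, a minimal usage sketch of the refactored SHAP constructor. The background data and sample count are illustrative, the import path assumes SHAPExplainer is re-exported from trustyai.explainers, and link_type is left at its default.

import numpy as np
from trustyai.explainers import SHAPExplainer

# Background datapoints of shape [n_datapoints, n_features]; illustrative only.
background = np.random.rand(100, 5)

explainer = SHAPExplainer(
    background,
    samples=4096,                 # default None, i.e. 2048 + 2 * n_features
    batch_size=20,                # default 20
    seed=0,                       # default 0
    track_counterfactuals=False,  # default False
)
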