Skip to content

Add TSLime explainer #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dev = [
"xgboost==1.4.2"
]
extras = [
"aix360 [default,tsice,tssaliency] @ https://github.com/Trusted-AI/AIX360/archive/refs/heads/master.zip"
"aix360 [default,tsice,tslime,tssaliency] @ https://github.com/Trusted-AI/AIX360/archive/refs/heads/master.zip"
]

[project.urls]
Expand Down
95 changes: 95 additions & 0 deletions src/trustyai/explainers/extras/tslime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Wrapper module for TSLIME from aix360.
Original at https://github.com/Trusted-AI/AIX360/
"""

from typing import Callable, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from aix360.algorithms.tslime import TSLimeExplainer as TSLimeExplainerAIX
from aix360.algorithms.tslime.surrogate import LinearSurrogateModel
from pandas.io.formats.style import Styler

from trustyai.explainers.explanation_results import ExplanationResults
from trustyai.utils.extras.timeseries import TSPerturber


class TSSLIMEResults(ExplanationResults):
    """Wraps TSLimeExplainer results. This object is returned by the :class:`~TSLimeExplainer`,
    and provides a variety of methods to visualize and interact with the explanation.
    """

    def __init__(self, explanation):
        """Store the raw explanation.

        :param explanation: dict-like object produced by the explainer. The methods
            of this class read its ``"history_weights"`` and ``"input_data"`` entries.
        """
        self.explanation = explanation

    def as_dataframe(self) -> pd.DataFrame:
        """Returns the weights as a pandas dataframe."""
        return pd.DataFrame(self.explanation["history_weights"])

    def as_html(self) -> Styler:
        """Returns the explanation as an HTML table."""
        return self.as_dataframe().style

    def plot(self):
        """Plot TSLime explanation for the time-series instance. Based on
        https://github.com/Trusted-AI/AIX360/blob/master/examples/tslime/tslime_univariate_demo.ipynb

        The trailing ``history_weights.shape[0]`` entries of ``input_data`` are drawn
        as a line, and the weights, divided by their mean absolute value, are drawn
        as bars at the same index positions. The figure is displayed with
        ``plt.show()``; nothing is returned.
        """
        weights = self.explanation["history_weights"]
        input_data = self.explanation["input_data"]
        relevant_history = weights.shape[0]
        relevant_df = input_data[-relevant_history:]

        plt.figure(layout="constrained")
        plt.plot(relevant_df, label="Input Time Series", marker="o")
        plt.gca().invert_yaxis()

        # Normalize by the mean absolute weight. When every weight is zero that
        # mean is zero as well, and dividing would turn the bars into NaN (with a
        # RuntimeWarning), so plot plain zeros instead.
        scale = np.mean(np.abs(weights))
        if scale == 0:
            normalized_weights = np.zeros_like(weights, dtype=float).flatten()
        else:
            normalized_weights = (weights / scale).flatten()

        plt.bar(
            input_data.index[-relevant_history:],
            normalized_weights,
            0.4,
            label="TSLime Weights (Normalized)",
            color="red",
        )
        plt.axhline(y=0, color="r", linestyle="-", alpha=0.4)
        plt.title("Time Series Lime Explanation Plot")
        plt.legend(bbox_to_anchor=(1.25, 1.0), loc="upper right")
        plt.show()


class TSLimeExplainer(TSLimeExplainerAIX):
    """
    Wrapper for TSLimeExplainer from aix360.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        model: Callable,
        input_length: int,
        n_perturbations: int = 2000,
        relevant_history: Optional[int] = None,
        perturbers: Optional[List[Union[TSPerturber, dict]]] = None,
        local_interpretable_model: Optional[LinearSurrogateModel] = None,
        random_seed: Optional[int] = None,
    ):
        """Initialize the explainer.

        Every argument is forwarded unchanged, by keyword, to the aix360
        ``TSLimeExplainer`` constructor; see the aix360 documentation for their
        exact semantics. Arguments left as ``None`` are passed through as ``None``.

        :param model: callable handed to aix360 as the model to explain.
        :param input_length: forwarded as ``input_length``.
        :param n_perturbations: forwarded as ``n_perturbations`` (default 2000).
        :param relevant_history: forwarded as ``relevant_history``.
        :param perturbers: forwarded as ``perturbers``.
        :param local_interpretable_model: forwarded as ``local_interpretable_model``.
        :param random_seed: forwarded as ``random_seed``.
        """
        super().__init__(
            model=model,
            input_length=input_length,
            n_perturbations=n_perturbations,
            relevant_history=relevant_history,
            perturbers=perturbers,
            local_interpretable_model=local_interpretable_model,
            random_seed=random_seed,
        )

    def explain(self, inputs, outputs=None, **kwargs) -> TSSLIMEResults:
        """
        Explain the model's prediction on ``inputs``.

        :param inputs: instance to explain, passed positionally to the parent's
            ``explain_instance``.
        :param outputs: passed to ``explain_instance`` as ``y``.
        :param kwargs: any further keyword arguments for ``explain_instance``.
        :return: the raw explanation wrapped in a :class:`TSSLIMEResults`.
        """
        _explanation = super().explain_instance(inputs, y=outputs, **kwargs)
        return TSSLIMEResults(_explanation)
2 changes: 2 additions & 0 deletions src/trustyai/utils/extras/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""AIX360 model wrappers"""
from aix360.algorithms.tsutils.model_wrappers import * # pylint: disable=wildcard-import,unused-wildcard-import
File renamed without changes.
101 changes: 101 additions & 0 deletions tests/extras/test_tslime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os
import unittest
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from trustyai.utils.extras.timeseries import tsFrame
from aix360.datasets import SunspotDataset
from trustyai.explainers.extras.tslime import TSLimeExplainer
from trustyai.utils.extras.timeseries import BlockBootstrapPerturber


# Sample forecaster that recasts a time series as a supervised learning dataset.
# Adapted from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/
class RandomForestUniVariateForecaster:
    """Sample forecaster that fits a ``RandomForestRegressor`` on lagged windows.

    The series is turned into a supervised table of ``n_past`` lagged copies
    followed by ``n_future`` forward copies; the regressor learns to map the
    former onto the latter.
    """

    def __init__(self, n_past=4, n_future=1, RFparams=None):
        """Create the forecaster.

        :param n_past: number of lagged steps used as model inputs.
        :param n_future: number of steps used as model targets.
        :param RFparams: keyword arguments for ``RandomForestRegressor``;
            ``None`` means ``{"n_estimators": 250}``.
        """
        # Use a None sentinel: a dict literal in the signature would be a single
        # object shared by every instance created with the default.
        if RFparams is None:
            RFparams = {"n_estimators": 250}
        self.n_past = n_past
        self.n_future = n_future
        self.model = RandomForestRegressor(**RFparams)

    def fit(self, X):
        """Fit the regressor on windows built from ``X`` and return ``self``."""
        train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future)
        # The last n_future columns are the targets, everything before them the inputs.
        trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:]
        self.model = self.model.fit(trainX, trainy)
        return self

    def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True):
        """Build the lagged table for ``data`` and return it as a NumPy array.

        :param data: anything ``pd.DataFrame`` accepts.
        :param n_in: number of lagged copies (t-n_in, ..., t-1).
        :param n_out: number of forward copies (t, t+1, ..., t+n_out-1).
        :param dropnan: drop the rows that contain NaN after shifting.
        """
        df = pd.DataFrame(data)
        # input sequence (t-n, ... t-1)
        lagged = [df.shift(i) for i in range(n_in, 0, -1)]
        # forecast sequence (t, t+1, ... t+n)
        leads = [df.shift(-i) for i in range(n_out)]
        agg = pd.concat(lagged + leads, axis=1)
        if dropnan:
            agg = agg.dropna()
        return agg.values

    def predict(self, X):
        """Predict from the last ``n_past`` rows of ``X``.

        Assumes ``X`` is array-like with a ``flatten`` method (e.g. a NumPy
        array) -- TODO confirm against the callers.
        """
        row = X[-self.n_past:].flatten()
        # The regressor expects a 2-D batch, so wrap the single window in a list.
        y_pred = self.model.predict(np.asarray([row]))
        return y_pred


class TestTSLimeExplainer(unittest.TestCase):
    """Smoke test for the TSLime wrapper on the sunspot dataset."""

    def setUp(self):
        """Load the sunspot series and split it chronologically into train/test."""
        frame, schema = SunspotDataset().load_data()
        series = tsFrame(
            frame,
            timestamp_column=schema["timestamp"],
            columns=schema["targets"],
        )

        # shuffle=False keeps the temporal order; the final 15% becomes the test part.
        splits = train_test_split(
            series, shuffle=False, stratify=None, test_size=0.15, train_size=None
        )
        self.ts_train, self.ts_test = splits

    def test_tslime(self):
        """Explain one test window and check the structure of the explanation."""
        input_length = 24
        forecast_horizon = 4
        relevant_history = 12

        # train the forecaster whose predict function gets explained
        forecaster = RandomForestUniVariateForecaster(
            n_past=input_length, n_future=forecast_horizon
        )
        forecaster.fit(self.ts_train.iloc[-200:])

        # set up the explainer with a single block-bootstrap perturber
        perturber = BlockBootstrapPerturber(
            window_length=min(4, input_length - 1), block_length=2, block_swap=2
        )
        explainer = TSLimeExplainer(
            model=forecaster.predict,
            input_length=input_length,
            relevant_history=relevant_history,
            perturbers=[perturber],
            n_perturbations=10,
            random_seed=22,
        )

        # explain the first window of the test split
        result = explainer.explain(self.ts_test.iloc[:input_length])

        # every expected entry must be present in the raw explanation
        expected_keys = (
            "input_data",
            "history_weights",
            "x_perturbations",
            "y_perturbations",
            "model_prediction",
            "surrogate_prediction",
        )
        for key in expected_keys:
            self.assertIn(key, result.explanation)

        weights = result.explanation["history_weights"]
        self.assertEqual(weights.shape[0], relevant_history)
File renamed without changes.