Skip to content

Commit 148e419

Browse files
authored
FAI-926: Implement custom counterfactual goal criteria (#140)
* Add initial GoalCriteria * Refactor wrapped functions abstract class * Refactor output casting * Add support for numpy CF criteria * Add test for missing goal and criteria * Fix formatting * Fix feature domain import
1 parent 31ecded commit 148e419

File tree

4 files changed

+293
-114
lines changed

4 files changed

+293
-114
lines changed

src/trustyai/explainers/counterfactuals.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
counterfactual_prediction,
1919
PredictionInput,
2020
Model,
21+
GoalCriteria,
2122
)
2223

23-
2424
from trustyai.utils.data_conversions import (
2525
prediction_object_to_numpy,
2626
prediction_object_to_pandas,
@@ -184,12 +184,13 @@ def __init__(self, steps=10_000):
184184
def explain(
185185
self,
186186
inputs: OneInputUnionType,
187-
goal: OneOutputUnionType,
188187
model: Union[PredictionProvider, Model],
188+
goal: Optional[OneOutputUnionType] = None,
189189
feature_domains: List[FeatureDomain] = None,
190190
data_distribution: Optional[DataDistribution] = None,
191191
uuid: Optional[_uuid.UUID] = None,
192192
timeout: Optional[float] = None,
193+
criteria: Optional[GoalCriteria] = None,
193194
) -> CounterfactualResult:
194195
r"""Request for a counterfactual explanation given a list of features, goals and a
195196
:class:`~PredictionProvider`
@@ -217,7 +218,9 @@ def explain(
217218
uuid : Optional[:class:`_uuid.UUID`]
218219
The UUID to use during search.
219220
timeout : Optional[float]
220-
The timeout time in seconds of the counterfactual explanation.
221+
The timeout time in seconds of the counterfactual explanation.
222+
criteria : Optional[:class:`GoalCriteria`]
223+
An optional custom scoring function, wrapped as a :class:`GoalCriteria`.
221224
222225
Returns
223226
-------
@@ -226,6 +229,9 @@ def explain(
226229
"""
227230
feature_names = model.feature_names if isinstance(model, Model) else None
228231
output_names = model.output_names if isinstance(model, Model) else None
232+
if goal is None and criteria is None:
233+
raise ValueError("Either a goal or criteria must be provided.")
234+
229235
_prediction = counterfactual_prediction(
230236
input_features=one_input_convert(
231237
inputs, feature_names=feature_names, feature_domains=feature_domains
@@ -236,6 +242,7 @@ def explain(
236242
data_distribution=data_distribution,
237243
uuid=uuid,
238244
timeout=timeout,
245+
criteria=criteria,
239246
)
240247

241248
with Model.NonArrowTransmission(model):

0 commit comments

Comments
 (0)