18
18
counterfactual_prediction ,
19
19
PredictionInput ,
20
20
Model ,
21
+ GoalCriteria ,
21
22
)
22
23
23
-
24
24
from trustyai .utils .data_conversions import (
25
25
prediction_object_to_numpy ,
26
26
prediction_object_to_pandas ,
@@ -184,12 +184,13 @@ def __init__(self, steps=10_000):
184
184
def explain (
185
185
self ,
186
186
inputs : OneInputUnionType ,
187
- goal : OneOutputUnionType ,
188
187
model : Union [PredictionProvider , Model ],
188
+ goal : Optional [OneOutputUnionType ] = None ,
189
189
feature_domains : List [FeatureDomain ] = None ,
190
190
data_distribution : Optional [DataDistribution ] = None ,
191
191
uuid : Optional [_uuid .UUID ] = None ,
192
192
timeout : Optional [float ] = None ,
193
+ criteria : Optional [GoalCriteria ] = None ,
193
194
) -> CounterfactualResult :
194
195
r"""Request for a counterfactual explanation given a list of features, goals and a
195
196
:class:`~PredictionProvider`
@@ -217,7 +218,9 @@ def explain(
217
218
uuid : Optional[:class:`_uuid.UUID`]
218
219
The UUID to use during search.
219
220
timeout : Optional[float]
220
- The timeout time in seconds of the counterfactual explanation.
221
+ The timeout time in seconds of the counterfactual explanation.
222
+ criteria : Optional[:class:`GoalCriteria`]
223
+ An optional custom scoring function, wrapped as a :class:`GoalCriteria`.
221
224
222
225
Returns
223
226
-------
@@ -226,6 +229,9 @@ def explain(
226
229
"""
227
230
feature_names = model .feature_names if isinstance (model , Model ) else None
228
231
output_names = model .output_names if isinstance (model , Model ) else None
232
+ if goal is None and criteria is None :
233
+ raise ValueError ("Either a goal or criteria must be provided." )
234
+
229
235
_prediction = counterfactual_prediction (
230
236
input_features = one_input_convert (
231
237
inputs , feature_names = feature_names , feature_domains = feature_domains
@@ -236,6 +242,7 @@ def explain(
236
242
data_distribution = data_distribution ,
237
243
uuid = uuid ,
238
244
timeout = timeout ,
245
+ criteria = criteria ,
239
246
)
240
247
241
248
with Model .NonArrowTransmission (model ):
0 commit comments