Skip to content

Commit 213ef9e

Browse files
authored
Add Permutation Feature Importance for Binary Classification (#1735)
Adding support for binary classification in Permutation Feature Importance
1 parent 08761e3 commit 213ef9e

File tree

3 files changed

+244
-63
lines changed

3 files changed

+244
-63
lines changed

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

+14
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,20 @@ internal Result(IExceptionContext ectx, IRow overallResult)
878878
F1Score = Fetch(BinaryClassifierEvaluator.F1);
879879
Auprc = Fetch(BinaryClassifierEvaluator.AuPrc);
880880
}
881+
882+
/// <summary>
/// Constructs a <see cref="Result"/> directly from pre-computed metric values,
/// bypassing the row-based constructor. Used by Permutation Feature Importance
/// to materialize per-feature metric deltas as a <see cref="Result"/>.
/// </summary>
[BestFriend]
internal Result(double auc, double accuracy, double positivePrecision, double positiveRecall,
    double negativePrecision, double negativeRecall, double f1Score, double auprc)
{
    Auc = auc;
    Accuracy = accuracy;
    PositivePrecision = positivePrecision;
    PositiveRecall = positiveRecall;
    NegativePrecision = negativePrecision;
    NegativeRecall = negativeRecall;
    F1Score = f1Score;
    Auprc = auprc;
}
881895
}
882896

883897
/// <summary>

src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs

+48
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,53 @@ private static RegressionEvaluator.Result RegressionDelta(
5555
lossFunction: a.LossFn - b.LossFn,
5656
rSquared: a.RSquared - b.RSquared);
5757
}
58+
59+
/// <summary>
/// Permutation Feature Importance is a technique that calculates how much each feature 'matters'
/// to the predictions: the values of a single feature are randomly permuted across the evaluation
/// set and the resulting change in the model's quality metrics is measured. If quality barely
/// changes, the feature is not very important; if quality drops drastically, it was a really
/// important feature.
/// </summary>
/// <param name="ctx">The binary classification context.</param>
/// <param name="model">The model to evaluate.</param>
/// <param name="data">The evaluation data set.</param>
/// <param name="label">Label column name.</param>
/// <param name="features">Feature column names.</param>
/// <param name="useFeatureWeightFilter">Use features weight to pre-filter features.</param>
/// <param name="topExamples">Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used.</param>
/// <returns>Array of per-feature 'contributions' (metric deltas) to the score.</returns>
public static ImmutableArray<BinaryClassifierEvaluator.Result>
    PermutationFeatureImportance(
        this BinaryClassificationContext ctx,
        IPredictionTransformer<IPredictor> model,
        IDataView data,
        string label = DefaultColumnNames.Label,
        string features = DefaultColumnNames.Features,
        bool useFeatureWeightFilter = false,
        int? topExamples = null)
{
    // Each candidate data view (original and per-feature permuted) is scored with the
    // binary classification evaluator; BinaryClassifierDelta turns the baseline/permuted
    // metric pairs into per-feature deltas.
    return PermutationFeatureImportance<BinaryClassifierEvaluator.Result>.GetImportanceMetricsMatrix(
        CatalogUtils.GetEnvironment(ctx),
        model,
        data,
        idv => ctx.Evaluate(idv, label),
        BinaryClassifierDelta,
        features,
        useFeatureWeightFilter,
        topExamples);
}
92+
93+
/// <summary>
/// Computes the per-metric difference (a - b) between two binary classification
/// evaluation results. PFI uses this to report how far each metric moved after
/// permuting a feature.
/// </summary>
private static BinaryClassifierEvaluator.Result BinaryClassifierDelta(
    BinaryClassifierEvaluator.Result a, BinaryClassifierEvaluator.Result b)
{
    // Each metric delta is simply the metric of 'a' minus the metric of 'b'.
    var deltaAuc = a.Auc - b.Auc;
    var deltaAccuracy = a.Accuracy - b.Accuracy;
    var deltaPositivePrecision = a.PositivePrecision - b.PositivePrecision;
    var deltaPositiveRecall = a.PositiveRecall - b.PositiveRecall;
    var deltaNegativePrecision = a.NegativePrecision - b.NegativePrecision;
    var deltaNegativeRecall = a.NegativeRecall - b.NegativeRecall;
    var deltaF1Score = a.F1Score - b.F1Score;
    var deltaAuprc = a.Auprc - b.Auprc;

    return new BinaryClassifierEvaluator.Result(
        auc: deltaAuc,
        accuracy: deltaAccuracy,
        positivePrecision: deltaPositivePrecision,
        positiveRecall: deltaPositiveRecall,
        negativePrecision: deltaNegativePrecision,
        negativeRecall: deltaNegativeRecall,
        f1Score: deltaF1Score,
        auprc: deltaAuprc);
}
58106
}
59107
}

test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs

+182-63
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.ML.Runtime.Data;
6+
using Microsoft.ML.Runtime.Internal.Utilities;
67
using Microsoft.ML.Runtime.RunTests;
78
using System;
89
using System.Collections.Immutable;
@@ -24,7 +25,151 @@ public PermutationFeatureImportanceTests(ITestOutputHelper output) : base(output
2425
/// Also test checks that x2 has the biggest importance.
2526
/// </summary>
2627
[Fact]
public void TestPfiRegressionOnDenseFeatures()
{
    var data = GetDenseDataset();
    var model = ML.Regression.Trainers.OnlineGradientDescent().Fit(data);
    var pfi = ML.Regression.PermutationFeatureImportance(model, data);

    // Feature indices in the PFI result:
    //   0: X1, 1: X2Important, 2: X3, 3: X4Rand
    const int mostImportant = 1;  // X2Important
    const int leastImportant = 3; // X4Rand

    // L1, L2 and RMS: lower is better, so the maximum delta marks the most
    // important feature, the minimum delta the least important one.
    Assert.True(MinDeltaIndex(pfi, m => m.L1) == leastImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.L1) == mostImportant);

    Assert.True(MinDeltaIndex(pfi, m => m.L2) == leastImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.L2) == mostImportant);

    Assert.True(MinDeltaIndex(pfi, m => m.Rms) == leastImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.Rms) == mostImportant);

    // R-squared: higher is better, so the roles of min and max delta are reversed.
    Assert.True(MaxDeltaIndex(pfi, m => m.RSquared) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.RSquared) == mostImportant);

    Done();
}
56+
57+
/// <summary>
/// Features: x1, x2vBuff (sparse vector), x3.
/// y = 10x1 + 10x2vBuff + 30x3 + e.
/// Within x2vBuff the 2nd slot is sparse most of the time.
/// Test verifies that the 2nd slot of x2vBuff has the least importance: L1, L2, RMS and R-squared
/// do not change much when this slot is permuted. Also checks that X3Important has the biggest importance.
/// </summary>
[Fact]
public void TestPfiRegressionOnSparseFeatures()
{
    var data = GetSparseDataset();
    var model = ML.Regression.Trainers.OnlineGradientDescent().Fit(data);
    var results = ML.Regression.PermutationFeatureImportance(model, data);

    // Pfi Indices:
    // X1: 0
    // X2VBuffer-Slot-0: 1
    // X2VBuffer-Slot-1: 2
    // X2VBuffer-Slot-2: 3
    // X2VBuffer-Slot-3: 4
    // X3Important: 5

    // Permuted X2VBuffer-Slot-1 (index 2) should have min impact on SGD metrics, X3Important -- max impact.
    // For the following metrics lower is better, so maximum delta means more important feature, and vice versa
    Assert.True(MinDeltaIndex(results, m => m.L1) == 2);
    Assert.True(MaxDeltaIndex(results, m => m.L1) == 5);

    Assert.True(MinDeltaIndex(results, m => m.L2) == 2);
    Assert.True(MaxDeltaIndex(results, m => m.L2) == 5);

    Assert.True(MinDeltaIndex(results, m => m.Rms) == 2);
    Assert.True(MaxDeltaIndex(results, m => m.Rms) == 5);

    // For the following metrics higher is better, so minimum delta means more important feature, and vice versa
    Assert.True(MaxDeltaIndex(results, m => m.RSquared) == 2);
    Assert.True(MinDeltaIndex(results, m => m.RSquared) == 5);

    // Every sibling PFI test ends with Done(); this one was missing it.
    Done();
}
94+
95+
/// <summary>
/// Checks PFI for binary classification on dense features: the random feature (X4Rand)
/// should move every metric the least, and the dominant feature (X2Important) the most.
/// </summary>
[Fact]
public void TestPfiBinaryClassificationOnDenseFeatures()
{
    var data = GetDenseDataset(TaskType.BinaryClassification);
    var model = ML.BinaryClassification.Trainers.LogisticRegression().Fit(data);
    var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data);

    // Feature indices in the PFI result:
    //   0: X1, 1: X2Important, 2: X3, 3: X4Rand
    const int mostImportant = 1;  // X2Important
    const int leastImportant = 3; // X4Rand

    // All metrics asserted here are higher-is-better, so the minimum delta marks the
    // most important feature and the maximum delta the least important one.
    Assert.True(MaxDeltaIndex(pfi, m => m.Auc) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.Auc) == mostImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.Accuracy) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.Accuracy) == mostImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.PositivePrecision) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.PositivePrecision) == mostImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.PositiveRecall) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.PositiveRecall) == mostImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.NegativePrecision) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.NegativePrecision) == mostImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.NegativeRecall) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.NegativeRecall) == mostImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.F1Score) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.F1Score) == mostImportant);
    Assert.True(MaxDeltaIndex(pfi, m => m.Auprc) == leastImportant);
    Assert.True(MinDeltaIndex(pfi, m => m.Auprc) == mostImportant);

    Done();
}
128+
129+
/// <summary>
/// Features: x1, x2vBuff (sparse vector), x3; binary labels are derived from
/// y = 10x1 + 10x2vBuff + 30x3 + e.
/// Within x2vBuff the 2nd slot is sparse most of the time.
/// Test verifies that the 2nd slot of x2vBuff has the least importance: the binary
/// classification metrics (AUC, accuracy, precision/recall, F1, AUPRC) do not change much
/// when this slot is permuted. Also checks that X3Important has the biggest importance.
/// (The previous comment referenced L1/L2/RMS, which are regression metrics.)
/// </summary>
[Fact]
public void TestPfiBinaryClassificationOnSparseFeatures()
{
    var data = GetSparseDataset(TaskType.BinaryClassification);
    var model = ML.BinaryClassification.Trainers.LogisticRegression().Fit(data);
    var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data);

    // Pfi Indices:
    // X1: 0
    // X2VBuffer-Slot-0: 1
    // X2VBuffer-Slot-1: 2
    // X2VBuffer-Slot-2: 3
    // X2VBuffer-Slot-3: 4
    // X3Important: 5

    // For the following metrics higher is better, so minimum delta means more important feature, and vice versa
    Assert.True(MaxDeltaIndex(pfi, m => m.Auc) == 2);
    Assert.True(MinDeltaIndex(pfi, m => m.Auc) == 5);
    Assert.True(MaxDeltaIndex(pfi, m => m.Accuracy) == 2);
    Assert.True(MinDeltaIndex(pfi, m => m.Accuracy) == 5);
    Assert.True(MaxDeltaIndex(pfi, m => m.PositivePrecision) == 2);
    Assert.True(MinDeltaIndex(pfi, m => m.PositivePrecision) == 5);
    Assert.True(MaxDeltaIndex(pfi, m => m.PositiveRecall) == 2);
    Assert.True(MinDeltaIndex(pfi, m => m.PositiveRecall) == 5);
    Assert.True(MaxDeltaIndex(pfi, m => m.NegativePrecision) == 2);
    Assert.True(MinDeltaIndex(pfi, m => m.NegativePrecision) == 5);
    Assert.True(MaxDeltaIndex(pfi, m => m.NegativeRecall) == 2);
    Assert.True(MinDeltaIndex(pfi, m => m.NegativeRecall) == 5);
    Assert.True(MaxDeltaIndex(pfi, m => m.F1Score) == 2);
    Assert.True(MinDeltaIndex(pfi, m => m.F1Score) == 5);
    Assert.True(MaxDeltaIndex(pfi, m => m.Auprc) == 2);
    Assert.True(MinDeltaIndex(pfi, m => m.Auprc) == 5);

    Done();
}
171+
172+
private IDataView GetDenseDataset(TaskType task = TaskType.Regression)
28173
{
29174
// Setup synthetic dataset.
30175
const int numberOfInstances = 1000;
@@ -50,6 +195,10 @@ public void TestDenseSGD()
50195
yArray[i] = (float)(10 * x1 + 20 * x2Important + 5.5 * x3 + noise);
51196
}
52197

198+
// If binary classification, modify the labels
199+
if (task == TaskType.BinaryClassification)
200+
GetBinaryClassificationScores(yArray);
201+
53202
// Create data view.
54203
var bldr = new ArrayDataViewBuilder(Env);
55204
bldr.AddColumn("X1", NumberType.Float, x1Array);
@@ -62,41 +211,11 @@ public void TestDenseSGD()
62211
var pipeline = ML.Transforms.Concatenate("Features", "X1", "X2Important", "X3", "X4Rand")
63212
.Append(ML.Transforms.Normalize("Features"));
64213
var data = pipeline.Fit(srcDV).Transform(srcDV);
65-
var model = ML.Regression.Trainers.OnlineGradientDescent().Fit(data);
66-
var pfi = ML.Regression.PermutationFeatureImportance(model, data);
67-
68-
// Pfi Indices:
69-
// X1: 0
70-
// X2Important: 1
71-
// X3: 2
72-
// X4Rand: 3
73-
74-
// For the following metrics lower is better, so maximum delta means more important feature, and vice versa
75-
Assert.True(MinDeltaIndex(pfi, m => m.L1) == 3);
76-
Assert.True(MaxDeltaIndex(pfi, m => m.L1) == 1);
77-
78-
Assert.True(MinDeltaIndex(pfi, m => m.L2) == 3);
79-
Assert.True(MaxDeltaIndex(pfi, m => m.L2) == 1);
80-
81-
Assert.True(MinDeltaIndex(pfi, m => m.Rms) == 3);
82-
Assert.True(MaxDeltaIndex(pfi, m => m.Rms) == 1);
83-
84-
// For the following metrics higher is better, so minimum delta means more important feature, and vice versa
85-
Assert.True(MaxDeltaIndex(pfi, m => m.RSquared) == 3);
86-
Assert.True(MinDeltaIndex(pfi, m => m.RSquared) == 1);
87214

88-
Done();
215+
return data;
89216
}
90217

91-
/// <summary>
92-
/// Features: x1, x2vBuff(sparce vector), x3.
93-
/// y = 10x1 + 10x2vBuff + 30x3 + e.
94-
/// Within xBuff feature 2nd slot will be sparse most of the time.
95-
/// Test verifies that 2nd slot of xBuff has the least importance: L1, L2, RMS and Loss-Fn do not change a lot when this slot is permuted.
96-
/// Also test checks that x2 has the biggest importance.
97-
/// </summary>
98-
[Fact]
99-
public void TestSparseSGD()
218+
private IDataView GetSparseDataset(TaskType task = TaskType.Regression)
100219
{
101220
// Setup synthetic dataset.
102221
const int numberOfInstances = 10000;
@@ -137,6 +256,10 @@ public void TestSparseSGD()
137256
yArray[i] = 10 * x1 + vbSum + 20 * x3Important + noise;
138257
}
139258

259+
// If binary classification, modify the labels
260+
if (task == TaskType.BinaryClassification)
261+
GetBinaryClassificationScores(yArray);
262+
140263
// Create data view.
141264
var bldr = new ArrayDataViewBuilder(Env);
142265
bldr.AddColumn("X1", NumberType.Float, x1Array);
@@ -148,47 +271,43 @@ public void TestSparseSGD()
148271
var pipeline = ML.Transforms.Concatenate("Features", "X1", "X2VBuffer", "X3Important")
149272
.Append(ML.Transforms.Normalize("Features"));
150273
var data = pipeline.Fit(srcDV).Transform(srcDV);
151-
var model = ML.Regression.Trainers.OnlineGradientDescent().Fit(data);
152-
var results = ML.Regression.PermutationFeatureImportance(model, data);
153-
154-
// Pfi Indices:
155-
// X1: 0
156-
// X2VBuffer-Slot-0: 1
157-
// X2VBuffer-Slot-1: 2
158-
// X2VBuffer-Slot-2: 3
159-
// X2VBuffer-Slot-3: 4
160-
// X3Important: 5
161274

162-
// Permuted X2VBuffer-Slot-1 lot (f2) should have min impact on SGD metrics, X3Important -- max impact.
163-
// For the following metrics lower is better, so maximum delta means more important feature, and vice versa
164-
Assert.True(MinDeltaIndex(results, m => m.L1) == 2);
165-
Assert.True(MaxDeltaIndex(results, m => m.L1) == 5);
166-
167-
Assert.True(MinDeltaIndex(results, m => m.L2) == 2);
168-
Assert.True(MaxDeltaIndex(results, m => m.L2) == 5);
169-
170-
Assert.True(MinDeltaIndex(results, m => m.Rms) == 2);
171-
Assert.True(MaxDeltaIndex(results, m => m.Rms) == 5);
172-
173-
// For the following metrics higher is better, so minimum delta means more important feature, and vice versa
174-
Assert.True(MaxDeltaIndex(results, m => m.RSquared) == 2);
175-
Assert.True(MinDeltaIndex(results, m => m.RSquared) == 5);
275+
return data;
176276
}
177277

178-
private int MinDeltaIndex(
179-
ImmutableArray<RegressionEvaluator.Result> metricsDelta,
180-
Func<RegressionEvaluator.Result, double> metricSelector)
278+
/// <summary>
/// Index of the result whose selected metric delta is smallest
/// (first occurrence wins on ties, since OrderBy is a stable sort).
/// </summary>
private int MinDeltaIndex<T>(
    ImmutableArray<T> metricsDelta,
    Func<T, double> metricSelector)
    => metricsDelta.IndexOf(metricsDelta.OrderBy(metricSelector).First());
185285

186-
private int MaxDeltaIndex(
187-
ImmutableArray<RegressionEvaluator.Result> metricsDelta,
188-
Func<RegressionEvaluator.Result, double> metricSelector)
286+
/// <summary>
/// Index of the result whose selected metric delta is largest
/// (first occurrence wins on ties, since OrderByDescending is a stable sort).
/// </summary>
private int MaxDeltaIndex<T>(
    ImmutableArray<T> metricsDelta,
    Func<T, double> metricSelector)
    => metricsDelta.IndexOf(metricsDelta.OrderByDescending(metricSelector).First());
293+
294+
/// <summary>
/// Converts regression responses to binary class labels, in place: each score is centered
/// by the mean and mapped to 1 when above the mean, 0 otherwise, yielding roughly balanced
/// classes from the synthetic regression targets.
/// </summary>
/// <param name="rawScores">Regression responses; overwritten with 0/1 class labels.</param>
private void GetBinaryClassificationScores(float[] rawScores)
{
    // Nothing to relabel, and avoids dividing by a zero length below.
    if (rawScores == null || rawScores.Length == 0)
        return;

    // Compute the average so we can center the response.
    float averageScore = 0.0f;
    for (int i = 0; i < rawScores.Length; i++)
        averageScore += rawScores[i];
    averageScore /= rawScores.Length;

    // Sigmoid(x) > 0.5 exactly when x > 0, so compare the centered response to zero
    // directly rather than routing through MathUtils.Sigmoid.
    for (int i = 0; i < rawScores.Length; i++)
        rawScores[i] = rawScores[i] - averageScore > 0 ? 1 : 0;
}
306+
307+
/// <summary>
/// Learning task for which the synthetic datasets are generated: regression targets
/// are used as-is; binary classification thresholds them into 0/1 labels.
/// </summary>
private enum TaskType
{
    Regression,
    BinaryClassification
}
193312
}
194313
}

0 commit comments

Comments
 (0)