
Commit f25bd4b

Adding a binary classification PFI Example (#1793)
* Adding a binary classification PFI Example, breaking the PFI examples into different files in a subfolder, and correcting XMLDocs links.
1 parent c54086b commit f25bd4b

4 files changed: +177 additions, −75 deletions

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Trainers.HalLearners;
using System;
using System.Linq;

namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
{
    public class PfiHelper
    {
        public static IDataView GetHousingRegressionIDataView(MLContext mlContext, out string labelName, out string[] featureNames, bool binaryPrediction = false)
        {
            // Download the dataset from github.com/dotnet/machinelearning.
            // This will create a housing.txt file in the filesystem.
            // You can open this file to see the data.
            string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset();

            // Read the data as an IDataView.
            // First, we define the reader: specify the data columns and where to find them in the text file.
            // The data file is composed of rows of data, with each row having 11 numerical columns
            // separated by whitespace.
            var reader = mlContext.Data.CreateTextReader(
                columns: new[]
                {
                    // Read the first column (indexed by 0) in the data file as an R4 (float)
                    new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
                    new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
                    new TextLoader.Column("PercentResidental", DataKind.R4, 2),
                    new TextLoader.Column("PercentNonRetail", DataKind.R4, 3),
                    new TextLoader.Column("CharlesRiver", DataKind.R4, 4),
                    new TextLoader.Column("NitricOxides", DataKind.R4, 5),
                    new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6),
                    new TextLoader.Column("PercentPre40s", DataKind.R4, 7),
                    new TextLoader.Column("EmploymentDistance", DataKind.R4, 8),
                    new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
                    new TextLoader.Column("TaxRate", DataKind.R4, 10),
                    new TextLoader.Column("TeacherRatio", DataKind.R4, 11),
                },
                hasHeader: true
            );

            // Read the data
            var data = reader.Read(dataFile);
            var labelColumn = "MedianHomeValue";

            if (binaryPrediction)
            {
                labelColumn = nameof(BinaryOutputRow.AboveAverage);
                data = mlContext.Transforms.CustomMappingTransformer(GreaterThanAverage, null).Transform(data);
                data = mlContext.Transforms.DropColumns("MedianHomeValue").Fit(data).Transform(data);
            }

            labelName = labelColumn;
            featureNames = data.Schema.AsEnumerable()
                .Select(column => column.Name) // Get the column names
                .Where(name => name != labelColumn) // Drop the Label
                .ToArray();

            return data;
        }

        // Define a class for all the input columns that we intend to consume.
        private class ContinuousInputRow
        {
            public float MedianHomeValue { get; set; }
        }

        // Define a class for all output columns that we intend to produce.
        private class BinaryOutputRow
        {
            public bool AboveAverage { get; set; }
        }

        // Define an Action to apply a custom mapping from one object to the other
        private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
            => output.AboveAverage = input.MedianHomeValue > 22.6;

        public static float[] GetLinearModelWeights(OlsLinearRegressionModelParameters linearModel)
        {
            return linearModel.Weights.ToArray();
        }

        public static float[] GetLinearModelWeights(LinearBinaryModelParameters linearModel)
        {
            return linearModel.Weights.ToArray();
        }
    }
}
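
For orientation, here is a minimal sketch of how the two example files below consume this helper. It is not part of the commit; the wrapper class name is hypothetical, and it assumes the same ML.NET sample APIs shown above.

namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
{
    // Hypothetical caller, for illustration only; mirrors how PfiRegressionExample
    // and PfiBinaryClassificationExample use PfiHelper in the files below.
    internal static class PfiHelperUsageSketch
    {
        public static void Run()
        {
            var mlContext = new MLContext();

            // Regression form: the label remains the continuous "MedianHomeValue" column.
            var regressionData = PfiHelper.GetHousingRegressionIDataView(
                mlContext, out string regressionLabel, out string[] regressionFeatures);

            // Binary form: "MedianHomeValue" is mapped to the boolean "AboveAverage" label
            // by the custom mapping and then dropped from the feature columns.
            var binaryData = PfiHelper.GetHousingRegressionIDataView(
                mlContext, out string binaryLabel, out string[] binaryFeatures, binaryPrediction: true);
        }
    }
}
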
@@ -1,82 +1,39 @@
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Learners;
-using Microsoft.ML.Trainers.HalLearners;
-using System;
+using System;
 using System.Linq;
 
-namespace Microsoft.ML.Samples.Dynamic
+namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
 {
-    public class PFI_RegressionExample
+    public class PfiRegressionExample
     {
-        public static void PFI_Regression()
+        public static void RunExample()
         {
-            // Download the dataset from github.com/dotnet/machinelearning.
-            // This will create a housing.txt file in the filesystem.
-            // You can open this file to see the data.
-            string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset();
-
             // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
             // as a catalog of available operations and as the source of randomness.
             var mlContext = new MLContext();
 
-            // Step 1: Read the data as an IDataView.
-            // First, we define the reader: specify the data columns and where to find them in the text file.
-            // The data file is composed of rows of data, with each row having 11 numerical columns
-            // separated by whitespace.
-            var reader = mlContext.Data.CreateTextReader(
-                columns: new[]
-                {
-                    // Read the first column (indexed by 0) in the data file as an R4 (float)
-                    new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
-                    new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
-                    new TextLoader.Column("PercentResidental", DataKind.R4, 2),
-                    new TextLoader.Column("PercentNonRetail", DataKind.R4, 3),
-                    new TextLoader.Column("CharlesRiver", DataKind.R4, 4),
-                    new TextLoader.Column("NitricOxides", DataKind.R4, 5),
-                    new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6),
-                    new TextLoader.Column("PercentPre40s", DataKind.R4, 7),
-                    new TextLoader.Column("EmploymentDistance", DataKind.R4, 8),
-                    new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
-                    new TextLoader.Column("TaxRate", DataKind.R4, 10),
-                    new TextLoader.Column("TeacherRatio", DataKind.R4, 11)
-                },
-                hasHeader: true
-            );
-
-            // Read the data
-            var data = reader.Read(dataFile);
+            // Step 1: Read the data
+            var data = PfiHelper.GetHousingRegressionIDataView(mlContext, out string labelName, out string[] featureNames);
 
             // Step 2: Pipeline
             // Concatenate the features to create a Feature vector.
             // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0.
-            // Then append a linear regression trainer, setting the "MedianHomeValue" column as the label of the dataset,
-            // the "Features" column produced by concatenation as the features of the dataset.
-            var labelName = "MedianHomeValue";
-            var pipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental",
-                "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s",
-                "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
+            // Then append a linear regression trainer.
+            var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
                 .Append(mlContext.Transforms.Normalize("Features"))
                 .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares(
                     labelColumn: labelName, featureColumn: "Features"));
-
             var model = pipeline.Fit(data);
+
             // Extract the model from the pipeline
             var linearPredictor = model.LastTransformer;
-            var weights = GetLinearModelWeights(linearPredictor.Model);
+            var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model);
 
-            // Compute the permutation metrics using the properly-featurized data.
+            // Compute the permutation metrics using the properly normalized data.
             var transformedData = model.Transform(data);
             var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(
                 linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3);
 
             // Now let's look at which features are most important to the model overall
-            // First, we have to prepare the data:
-            // Get the feature names as an IEnumerable
-            var featureNames = data.Schema
-                .Select(column => column.Name) // Get the column names
-                .Where(name => name != labelName) // Drop the Label
-                .ToArray();
-
             // Get the feature indices sorted by their impact on R-Squared
             var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared })
                 .OrderByDescending(feature => Math.Abs(feature.RSquared.Mean))
@@ -116,10 +73,5 @@ public static void PFI_Regression()
                 Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i].Mean:G4}\t{1.96 * rSquared[i].StandardError:G4}");
             }
         }
-
-        private static float[] GetLinearModelWeights(OlsLinearRegressionModelParameters linearModel)
-        {
-            return linearModel.Weights.ToArray();
-        }
     }
 }
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
using Microsoft.ML.Runtime.Learners;
using System;
using System.Linq;

namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
{
    public class PfiBinaryClassificationExample
    {
        public static void RunExample()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext(seed: 999123);

            // Step 1: Read the data
            var data = PfiHelper.GetHousingRegressionIDataView(mlContext,
                out string labelName, out string[] featureNames, binaryPrediction: true);

            // Step 2: Pipeline
            // Concatenate the features to create a Feature vector.
            // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0.
            // Then append a logistic regression trainer.
            var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
                .Append(mlContext.Transforms.Normalize("Features"))
                .Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
                    labelColumn: labelName, featureColumn: "Features"));
            var model = pipeline.Fit(data);

            // Extract the model from the pipeline.
            var linearPredictor = model.LastTransformer;
            // Linear models for binary classification are wrapped by a calibrator as a generic predictor.
            // To access it directly, we must extract it out and cast it to the proper class.
            var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model.SubPredictor as LinearBinaryModelParameters);

            // Compute the permutation metrics using the properly normalized data.
            var transformedData = model.Transform(data);
            var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance(
                linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3);

            // Now let's look at which features are most important to the model overall.
            // Get the feature indices sorted by their impact on AUC.
            var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.Auc })
                .OrderByDescending(feature => Math.Abs(feature.Auc.Mean))
                .Select(feature => feature.index);

            // Print out the permutation results, with the model weights, in order of their impact:
            // Expected console output (for 100 permutations):
            //    Feature             Model Weight    Change in AUC    95% Confidence in the Mean Change in AUC
            //    PercentPre40s       -1.96           -0.06316         0.002377
            //    RoomsPerDwelling     3.71           -0.04385         0.001245
            //    EmploymentDistance  -1.31           -0.02139         0.0006867
            //    TeacherRatio        -2.46           -0.0203          0.0009566
            //    PercentNonRetail    -1.58           -0.01846         0.001586
            //    CharlesRiver         0.66           -0.008605        0.0005136
            //    PercentResidental    0.60            0.002483        0.0004818
            //    TaxRate             -0.95           -0.00221         0.0007394
            //    NitricOxides        -0.32            0.00101         0.0001428
            //    CrimesPerCapita     -0.04           -3.029E-05       1.678E-05
            //    HighwayDistance      0.00            0                0
            // Let's look at these results.
            // First, if you look at the weights of the model, they generally correlate with the results of PFI,
            // but there are some significant misorderings. See the discussion in the Regression example for an
            // explanation of why this happens and how to interpret it.
            // Second, the logistic regression learner uses L1 regularization by default. Here, it causes the
            // "HighwayDistance" feature to be zeroed out of the model. PFI assigns zero importance to this variable, as expected.
            // Third, some features show an *increase* in AUC. This means that the model actually improved
            // when these features were shuffled. This is a sign to investigate these features further.
            Console.WriteLine("Feature\tModel Weight\tChange in AUC\t95% Confidence in the Mean Change in AUC");
            var auc = permutationMetrics.Select(x => x.Auc).ToArray(); // Fetch AUC as an array
            foreach (int i in sortedIndices)
            {
                Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{auc[i].Mean:G4}\t{1.96 * auc[i].StandardError:G4}");
            }
        }
    }
}
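
As a usage note, both new samples expose parameterless RunExample entry points, so a samples driver can invoke them in one pass. A minimal sketch follows; the runner class below is hypothetical and not part of this commit.

namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
{
    // Hypothetical driver, for illustration only.
    internal static class PfiExamplesRunner
    {
        public static void Main()
        {
            // Regression PFI: ranks features by their impact on R-Squared.
            PfiRegressionExample.RunExample();

            // Binary classification PFI: ranks features by their impact on AUC.
            PfiBinaryClassificationExample.RunExample();
        }
    }
}
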

src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs

Lines changed: 2 additions & 16 deletions
@@ -43,7 +43,7 @@ public static class PermutationFeatureImportanceExtensions
         /// <example>
         /// <format type="text/markdown">
         /// <![CDATA[
-        /// [!code-csharp[PFI](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs)]
+        /// [!code-csharp[PFI](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs)]
         /// ]]>
         /// </format>
         /// </example>
@@ -120,7 +120,7 @@ private static RegressionMetrics RegressionDelta(
         /// <example>
         /// <format type="text/markdown">
         /// <![CDATA[
-        /// [!code-csharp[PFI](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs)]
+        /// [!code-csharp[PFI](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs)]
         /// ]]>
         /// </format>
         /// </example>
@@ -198,13 +198,6 @@ private static BinaryClassificationMetrics BinaryClassifierDelta(
         /// example of working with these results to analyze the feature importance of a model.
         /// </para>
         /// </remarks>
-        /// <example>
-        /// <format type="text/markdown">
-        /// <![CDATA[
-        /// [!code-csharp[PFI](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs)]
-        /// ]]>
-        /// </format>
-        /// </example>
         /// <param name="ctx">The clustering context.</param>
         /// <param name="model">The model to evaluate.</param>
         /// <param name="data">The evaluation data set.</param>
@@ -284,13 +277,6 @@ private static MultiClassClassifierMetrics MulticlassClassificationDelta(
         /// example of working with these results to analyze the feature importance of a model.
         /// </para>
         /// </remarks>
-        /// <example>
-        /// <format type="text/markdown">
-        /// <![CDATA[
-        /// [!code-csharp[PFI](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs)]
-        /// ]]>
-        /// </format>
-        /// </example>
         /// <param name="ctx">The clustering context.</param>
         /// <param name="model">The model to evaluate.</param>
         /// <param name="data">The evaluation data set.</param>
