
Commit 8655730

Rogan Carr authored and committed
Adding a binary classification PFI Example
1 parent 41d3196 commit 8655730

File tree

1 file changed: +152 -46 lines changed


docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs

Lines changed: 152 additions & 46 deletions
@@ -6,55 +6,22 @@
 
 namespace Microsoft.ML.Samples.Dynamic
 {
-    public class PFI_RegressionExample
+    public class PermutationFeatureImportance_Examples
     {
         public static void PFI_Regression()
         {
-            // Download the dataset from github.com/dotnet/machinelearning.
-            // This will create a housing.txt file in the filesystem.
-            // You can open this file to see the data.
-            string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset();
-
             // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
             // as a catalog of available operations and as the source of randomness.
             var mlContext = new MLContext();
 
-            // Step 1: Read the data as an IDataView.
-            // First, we define the reader: specify the data columns and where to find them in the text file.
-            // The data file is composed of rows of data, with each row having 11 numerical columns
-            // separated by whitespace.
-            var reader = mlContext.Data.CreateTextReader(
-                columns: new[]
-                    {
-                        // Read the first column (indexed by 0) in the data file as an R4 (float)
-                        new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
-                        new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
-                        new TextLoader.Column("PercentResidental", DataKind.R4, 2),
-                        new TextLoader.Column("PercentNonRetail", DataKind.R4, 3),
-                        new TextLoader.Column("CharlesRiver", DataKind.R4, 4),
-                        new TextLoader.Column("NitricOxides", DataKind.R4, 5),
-                        new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6),
-                        new TextLoader.Column("PercentPre40s", DataKind.R4, 7),
-                        new TextLoader.Column("EmploymentDistance", DataKind.R4, 8),
-                        new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
-                        new TextLoader.Column("TaxRate", DataKind.R4, 10),
-                        new TextLoader.Column("TeacherRatio", DataKind.R4, 11)
-                    },
-                hasHeader: true
-            );
-
-            // Read the data
-            var data = reader.Read(dataFile);
+            // Step 1: Read the data
+            var data = GetHousingRegressionIDataView(mlContext, out string labelName, out string[] featureNames);
 
             // Step 2: Pipeline
             // Concatenate the features to create a Feature vector.
             // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0.
-            // Then append a linear regression trainer, setting the "MedianHomeValue" column as the label of the dataset,
-            // the "Features" column produced by concatenation as the features of the dataset.
-            var labelName = "MedianHomeValue";
-            var pipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental",
-                "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s",
-                "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
+            // Then append a linear regression trainer.
+            var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
                 .Append(mlContext.Transforms.Normalize("Features"))
                 .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares(
                     labelColumn: labelName, featureColumn: "Features"));
@@ -64,19 +31,12 @@ public static void PFI_Regression()
             var linearPredictor = model.LastTransformer;
             var weights = GetLinearModelWeights(linearPredictor.Model);
 
-            // Compute the permutation metrics using the properly-featurized data.
+            // Compute the permutation metrics using the properly normalized data.
             var transformedData = model.Transform(data);
             var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(
                 linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3);
 
             // Now let's look at which features are most important to the model overall
-            // First, we have to prepare the data:
-            // Get the feature names as an IEnumerable
-            var featureNames = data.Schema
-                .Select(column => column.Name) // Get the column names
-                .Where(name => name != labelName) // Drop the Label
-                .ToArray();
-
             // Get the feature indices sorted by their impact on R-Squared
             var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared })
                 .OrderByDescending(feature => Math.Abs(feature.RSquared.Mean))
@@ -116,10 +76,156 @@ public static void PFI_Regression()
                 Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i].Mean:G4}\t{1.96 * rSquared[i].StandardError:G4}");
             }
         }
+        public static void PFI_BinaryClassification()
+        {
+            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+            // as a catalog of available operations and as the source of randomness.
+            var mlContext = new MLContext();
+
+            // Step 1: Read the data
+            var data = GetHousingRegressionIDataView(mlContext, out string labelName, out string[] featureNames, binaryPrediction: true);
+
+            // Step 2: Pipeline
+            // Concatenate the features to create a Feature vector.
+            // Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0.
+            // Then append a logistic regression trainer.
+            var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
+                .Append(mlContext.Transforms.Normalize("Features"))
+                .Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
+                    labelColumn: labelName, featureColumn: "Features"));
+            var model = pipeline.Fit(data);
+
+            // Extract the model from the pipeline.
+            var linearPredictor = model.LastTransformer;
+            // Linear models for binary classification are wrapped by a calibrator as a generic predictor.
+            // To access it directly, we must extract it out and cast it to the proper class.
+            var weights = GetLinearModelWeights(linearPredictor.Model.SubPredictor as LinearBinaryModelParameters);
+
+            // Compute the permutation metrics using the properly normalized data.
+            var transformedData = model.Transform(data);
+            var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance(
+                linearPredictor, transformedData, label: labelName, features: "Features");
+
+            // Now let's look at which features are most important to the model overall.
+            // Get the feature indices sorted by their impact on AUC.
+            var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.Auc })
+                .OrderByDescending(feature => Math.Abs(feature.Auc.Mean))
+                .Select(feature => feature.index);
+
+            // Print out the permutation results, with the model weights, in order of their impact:
+            // Expected console output:
+            //    Feature              Model Weight    Change in AUC
+            //    PercentPre40s            -1.96          -0.04582
+            //    RoomsPerDwelling          3.71          -0.04516
+            //    EmploymentDistance       -1.31          -0.02375
+            //    TeacherRatio             -2.46          -0.01476
+            //    CharlesRiver              0.66          -0.008683
+            //    PercentNonRetail         -1.58          -0.007314
+            //    PercentResidental         0.60           0.003979
+            //    TaxRate                  -0.95           0.002739
+            //    NitricOxides             -0.32           0.001917
+            //    CrimesPerCapita          -0.04          -3.222E-05
+            //    HighwayDistance           0.00           0
+            //
+            // Let's look at these results.
+            // First, if you look at the weights of the model, they generally correlate with the results of PFI,
+            // but there are some significant misorderings. See the discussion in the Regression example for an
+            // explanation of why this happens and how to interpret it.
+            // Second, the logistic regression learner uses L1 regularization by default. Here, it causes the
+            // "HighwayDistance" feature to be zeroed out of the model. PFI assigns zero importance to this variable, as expected.
+            // Third, some features showed an *increase* in AUC. This means that the model actually improved
+            // when these features were shuffled. This is expected when the effects are small (here on the order of 10^-3)
+            // and is due to the random nature of permutations. To reduce computational costs, PFI performs a single
+            // permutation per feature, which means that the change in AUC comes from just one sample of the data.
+            // If each feature were permuted many times and the average computed, the resulting average change in AUC
+            // would be small and negative for these features, or zero if the features truly were meaningless.
+            // To observe this behavior yourself, try adding a second call to PFI and compare the results, or
+            // rerun the script with a different seed set in the MLContext(), like so:
+            // `var mlContext = new MLContext(seed: 12345);`
+            Console.WriteLine("Feature\tModel Weight\tChange in AUC\t95% Confidence in the Mean Change in AUC");
+            var auc = permutationMetrics.Select(x => x.Auc).ToArray(); // Fetch AUC as an array
+            foreach (int i in sortedIndices)
+            {
+                Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{auc[i].Mean:G4}\t{1.96 * auc[i].StandardError:G4}");
+            }
+            // DON'T CHECK IN UNTIL TEXT IS COMPLETE
+        }
+
+        private static IDataView GetHousingRegressionIDataView(MLContext mlContext, out string labelName, out string[] featureNames, bool binaryPrediction = false)
+        {
+            // Download the dataset from github.com/dotnet/machinelearning.
+            // This will create a housing.txt file in the filesystem.
+            // You can open this file to see the data.
+            string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset();
+
+            // Read the data as an IDataView.
+            // First, we define the reader: specify the data columns and where to find them in the text file.
+            // The data file is composed of rows of data, with each row having 11 numerical columns
+            // separated by whitespace.
+            var reader = mlContext.Data.CreateTextReader(
+                columns: new[]
+                    {
+                        // Read the first column (indexed by 0) in the data file as an R4 (float)
+                        new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
+                        new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
+                        new TextLoader.Column("PercentResidental", DataKind.R4, 2),
+                        new TextLoader.Column("PercentNonRetail", DataKind.R4, 3),
+                        new TextLoader.Column("CharlesRiver", DataKind.R4, 4),
+                        new TextLoader.Column("NitricOxides", DataKind.R4, 5),
+                        new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6),
+                        new TextLoader.Column("PercentPre40s", DataKind.R4, 7),
+                        new TextLoader.Column("EmploymentDistance", DataKind.R4, 8),
+                        new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
+                        new TextLoader.Column("TaxRate", DataKind.R4, 10),
+                        new TextLoader.Column("TeacherRatio", DataKind.R4, 11),
+                    },
+                hasHeader: true
+            );
+
+            // Read the data
+            var data = reader.Read(dataFile);
+            var labelColumn = "MedianHomeValue";
+
+            if (binaryPrediction)
+            {
+                labelColumn = nameof(BinaryOutputRow.AboveAverage);
+                data = mlContext.Transforms.CustomMappingTransformer(GreaterThanAverage, null).Transform(data);
+                data = mlContext.Transforms.DropColumns("MedianHomeValue").Fit(data).Transform(data);
+            }
+
+            labelName = labelColumn;
+            featureNames = data.Schema.AsEnumerable()
+                .Select(column => column.Name) // Get the column names
+                .Where(name => name != labelColumn) // Drop the Label
+                .ToArray();
+
+            return data;
+        }
+
+        // Define a class for all the input columns that we intend to consume.
+        private class ContinuousInputRow
+        {
+            public float MedianHomeValue { get; set; }
+        }
+
+        // Define a class for all output columns that we intend to produce.
+        private class BinaryOutputRow
+        {
+            public bool AboveAverage { get; set; }
+        }
+
+        // Define an Action to apply a custom mapping from one object to the other
+        private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
+            => output.AboveAverage = input.MedianHomeValue > 22.6;
 
         private static float[] GetLinearModelWeights(OlsLinearRegressionModelParameters linearModel)
         {
             return linearModel.Weights.ToArray();
         }
+
+        private static float[] GetLinearModelWeights(LinearBinaryModelParameters linearModel)
+        {
+            return linearModel.Weights.ToArray();
+        }
     }
 }
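
For context, a minimal driver like the one below could invoke both samples added or modified by this commit. The PfiSampleRunner class and its Main entry point are illustrative assumptions and are not part of this commit; only the two PermutationFeatureImportance_Examples methods come from the diff above.

namespace Microsoft.ML.Samples.Dynamic
{
    // Hypothetical driver, not part of this commit: runs both PFI examples in turn.
    internal static class PfiSampleRunner
    {
        public static void Main()
        {
            // Regression PFI: prints the change in R-Squared per permuted feature.
            PermutationFeatureImportance_Examples.PFI_Regression();

            // Binary classification PFI: prints the change in AUC per permuted feature.
            PermutationFeatureImportance_Examples.PFI_BinaryClassification();
        }
    }
}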

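The comments in PFI_BinaryClassification suggest that the occasional positive change in AUC should wash out if the permutation is repeated under different seeds. Below is a rough sketch of that experiment, assuming it were added to the same PermutationFeatureImportance_Examples class; the method name and seed values are made up for illustration, and only API calls already shown in the diff are used.

        // Illustrative sketch, not part of this commit: average the per-feature change in AUC
        // over several differently seeded runs to smooth out single-permutation noise.
        private static void PFI_BinaryClassification_AveragedOverSeeds()
        {
            int[] seeds = { 1, 2, 3 };   // arbitrary seed values, for illustration only
            double[] totals = null;      // per-feature sum of the change in AUC
            string[] names = null;

            foreach (var seed in seeds)
            {
                var mlContext = new MLContext(seed: seed);
                var data = GetHousingRegressionIDataView(mlContext, out string labelName, out string[] featureNames, binaryPrediction: true);

                // Same pipeline as PFI_BinaryClassification above.
                var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
                    .Append(mlContext.Transforms.Normalize("Features"))
                    .Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
                        labelColumn: labelName, featureColumn: "Features"));
                var model = pipeline.Fit(data);

                var transformedData = model.Transform(data);
                var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance(
                    model.LastTransformer, transformedData, label: labelName, features: "Features");

                // Accumulate this run's per-feature change in AUC.
                var aucChanges = permutationMetrics.Select(m => m.Auc.Mean).ToArray();
                names = featureNames;
                totals = totals ?? new double[aucChanges.Length];
                for (int i = 0; i < aucChanges.Length; i++)
                    totals[i] += aucChanges[i];
            }

            // Features whose AUC only increased by chance should now sit near zero on average.
            for (int i = 0; i < names.Length; i++)
                Console.WriteLine($"{names[i]}\taverage change in AUC: {totals[i] / seeds.Length:G4}");
        }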