-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Adding a binary classification PFI Example #1793
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Learners; | ||
using Microsoft.ML.Trainers.HalLearners; | ||
using System; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance | ||
{ | ||
public class PfiHelper | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So you add BinaryClassificationExample and Regression to our documentation. |
||
{ | ||
public static IDataView GetHousingRegressionIDataView(MLContext mlContext, out string labelName, out string[] featureNames, bool binaryPrediction = false) | ||
{ | ||
// Download the dataset from github.com/dotnet/machinelearning. | ||
// This will create a housing.txt file in the filesystem. | ||
// You can open this file to see the data. | ||
string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset(); | ||
|
||
// Read the data as an IDataView. | ||
// First, we define the reader: specify the data columns and where to find them in the text file. | ||
// The data file is composed of rows of data, with each row having 11 numerical columns | ||
// separated by whitespace. | ||
var reader = mlContext.Data.CreateTextReader( | ||
columns: new[] | ||
{ | ||
// Read the first column (indexed by 0) in the data file as an R4 (float) | ||
new TextLoader.Column("MedianHomeValue", DataKind.R4, 0), | ||
new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1), | ||
new TextLoader.Column("PercentResidental", DataKind.R4, 2), | ||
new TextLoader.Column("PercentNonRetail", DataKind.R4, 3), | ||
new TextLoader.Column("CharlesRiver", DataKind.R4, 4), | ||
new TextLoader.Column("NitricOxides", DataKind.R4, 5), | ||
new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6), | ||
new TextLoader.Column("PercentPre40s", DataKind.R4, 7), | ||
new TextLoader.Column("EmploymentDistance", DataKind.R4, 8), | ||
new TextLoader.Column("HighwayDistance", DataKind.R4, 9), | ||
new TextLoader.Column("TaxRate", DataKind.R4, 10), | ||
new TextLoader.Column("TeacherRatio", DataKind.R4, 11), | ||
}, | ||
hasHeader: true | ||
); | ||
|
||
// Read the data | ||
var data = reader.Read(dataFile); | ||
var labelColumn = "MedianHomeValue"; | ||
|
||
if (binaryPrediction) | ||
{ | ||
labelColumn = nameof(BinaryOutputRow.AboveAverage); | ||
data = mlContext.Transforms.CustomMappingTransformer(GreaterThanAverage, null).Transform(data); | ||
data = mlContext.Transforms.DropColumns("MedianHomeValue").Fit(data).Transform(data); | ||
} | ||
|
||
labelName = labelColumn; | ||
featureNames = data.Schema.AsEnumerable() | ||
.Select(column => column.Name) // Get the column names | ||
.Where(name => name != labelColumn) // Drop the Label | ||
.ToArray(); | ||
|
||
return data; | ||
} | ||
|
||
// Define a class for all the input columns that we intend to consume. | ||
private class ContinuousInputRow | ||
{ | ||
public float MedianHomeValue { get; set; } | ||
} | ||
|
||
// Define a class for all output columns that we intend to produce. | ||
private class BinaryOutputRow | ||
{ | ||
public bool AboveAverage { get; set; } | ||
} | ||
|
||
// Define an Action to apply a custom mapping from one object to the other | ||
private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output) | ||
=> output.AboveAverage = input.MedianHomeValue > 22.6; | ||
|
||
public static float[] GetLinearModelWeights(OlsLinearRegressionModelParameters linearModel) | ||
{ | ||
return linearModel.Weights.ToArray(); | ||
} | ||
|
||
public static float[] GetLinearModelWeights(LinearBinaryModelParameters linearModel) | ||
{ | ||
return linearModel.Weights.ToArray(); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
using Microsoft.ML.Runtime.Learners; | ||
using System; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you need to merge with master anyway... You already inside PermutationFeatureImportance namespace, I don't see much reason behind adding Pfi prefix to names, classes and methods. |
||
{ | ||
public class PfiBinaryClassificationExample | ||
{ | ||
public static void RunExample() | ||
{ | ||
// Create a new context for ML.NET operations. It can be used for exception tracking and logging, | ||
// as a catalog of available operations and as the source of randomness. | ||
var mlContext = new MLContext(seed:999123); | ||
|
||
// Step 1: Read the data | ||
var data = PfiHelper.GetHousingRegressionIDataView(mlContext, | ||
out string labelName, out string[] featureNames, binaryPrediction: true); | ||
|
||
// Step 2: Pipeline | ||
// Concatenate the features to create a Feature vector. | ||
// Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0. | ||
// Then append a logistic regression trainer. | ||
var pipeline = mlContext.Transforms.Concatenate("Features", featureNames) | ||
.Append(mlContext.Transforms.Normalize("Features")) | ||
.Append(mlContext.BinaryClassification.Trainers.LogisticRegression( | ||
labelColumn: labelName, featureColumn: "Features")); | ||
var model = pipeline.Fit(data); | ||
|
||
// Extract the model from the pipeline | ||
var linearPredictor = model.LastTransformer; | ||
// Linear models for binary classification are wrapped by a calibrator as a generic predictor | ||
// To access it directly, we must extract it out and cast it to the proper class | ||
var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model.SubPredictor as LinearBinaryModelParameters); | ||
|
||
// Compute the permutation metrics using the properly normalized data. | ||
var transformedData = model.Transform(data); | ||
var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance( | ||
linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3); | ||
|
||
// Now let's look at which features are most important to the model overall | ||
// Get the feature indices sorted by their impact on AUC | ||
var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.Auc }) | ||
.OrderByDescending(feature => Math.Abs(feature.Auc.Mean)) | ||
.Select(feature => feature.index); | ||
|
||
// Print out the permutation results, with the model weights, in order of their impact: | ||
// Expected console output (for 100 permutations): | ||
// Feature Model Weight Change in AUC 95% Confidence in the Mean Change in AUC | ||
// PercentPre40s -1.96 -0.06316 0.002377 | ||
// RoomsPerDwelling 3.71 -0.04385 0.001245 | ||
// EmploymentDistance -1.31 -0.02139 0.0006867 | ||
// TeacherRatio -2.46 -0.0203 0.0009566 | ||
// PercentNonRetail -1.58 -0.01846 0.001586 | ||
// CharlesRiver 0.66 -0.008605 0.0005136 | ||
// PercentResidental 0.60 0.002483 0.0004818 | ||
// TaxRate -0.95 -0.00221 0.0007394 | ||
// NitricOxides -0.32 0.00101 0.0001428 | ||
// CrimesPerCapita -0.04 -3.029E-05 1.678E-05 | ||
// HighwayDistance 0.00 0 0 | ||
// Let's look at these results. | ||
// First, if you look at the weights of the model, they generally correlate with the results of PFI, | ||
// but there are some significant misorderings. See the discussion in the Regression example for an | ||
// explanation of why this happens and how to interpret it. | ||
// Second, the logistic regression learner uses L1 regularization by default. Here, it causes the "HighWay Distance" | ||
// feature to be zeroed out from the model. PFI assigns zero importance to this variable, as expected. | ||
// Third, some features show an *increase* in AUC. This means that the model actually improved | ||
// when these features were shuffled. This is a sign to investigate these features further. | ||
Console.WriteLine("Feature\tModel Weight\tChange in AUC\t95% Confidence in the Mean Change in AUC"); | ||
var auc = permutationMetrics.Select(x => x.Auc).ToArray(); // Fetch AUC as an array | ||
foreach (int i in sortedIndices) | ||
{ | ||
Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{auc[i].Mean:G4}\t{1.96 * auc[i].StandardError:G4}"); | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make it internal as well?