Add a sample for Permutation Feature Importance #1728
Merged
9 commits
e1a5d3a
Adding an example for using PFI
b81965a
Adding in the expected console output.
89b6ffc
Addresing PR comments.
6c9c01a
removing blank line (nit)
3fe43f7
Change normalization comment
wschin 2ec6b6a
Fixing comments
5eec96c
Merge branch '1723_PFI_Documentation' of https://github.com/rogancarr…
03bb415
Cleaning up the language in comments, and variable names
648473a
Updating comments around data load
128 changes: 128 additions & 0 deletions
docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Learners; | ||
using System; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Samples.Dynamic | ||
{ | ||
public class PFI_RegressionExample | ||
{ | ||
public static void PFI_Regression() | ||
{ | ||
// Download the dataset from github.com/dotnet/machinelearning. | ||
// This will create a housing.txt file in the filesystem. | ||
// You can open this file to see the data. | ||
string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset(); | ||
|
||
// Create a new context for ML.NET operations. It can be used for exception tracking and logging, | ||
// as a catalog of available operations and as the source of randomness. | ||
var mlContext = new MLContext(); | ||
|
||
// Step 1: Read the data as an IDataView. | ||
// First, we define the reader: specify the data columns and where to find them in the text file. | ||
// The data looks like this: | ||
// 24.00 0.00632 18.00 2.310 0 0.5380 6.5750 65.20 4.0900 1 296.0 15.30 | ||
// 21.60 0.02731 0.00 7.070 0 0.4690 6.4210 78.90 4.9671 2 242.0 17.80 | ||
var reader = mlContext.Data.TextReader(new TextLoader.Arguments() | ||
{ | ||
Separator = "tab", | ||
HasHeader = true, | ||
Column = new[] | ||
{ | ||
new TextLoader.Column("MedianHomeValue", DataKind.R4, 0), | ||
new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1), | ||
new TextLoader.Column("PercentResidental", DataKind.R4, 2), | ||
new TextLoader.Column("PercentNonRetail", DataKind.R4, 3), | ||
new TextLoader.Column("CharlesRiver", DataKind.R4, 4), | ||
new TextLoader.Column("NitricOxides", DataKind.R4, 5), | ||
new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6), | ||
new TextLoader.Column("PercentPre40s", DataKind.R4, 7), | ||
new TextLoader.Column("EmploymentDistance", DataKind.R4, 8), | ||
new TextLoader.Column("HighwayDistance", DataKind.R4, 9), | ||
new TextLoader.Column("TaxRate", DataKind.R4, 10), | ||
new TextLoader.Column("TeacherRatio", DataKind.R4, 11), | ||
} | ||
}); | ||
|
||
// Read the data | ||
var data = reader.Read(dataFile); | ||
|
||
// Step 2: Pipeline | ||
// Concatenate the features to create a Feature vector. | ||
// Normalize the data set so that for each feature, its maximum value is 1 while its minimum value is 0. | ||
// Then append a linear regression trainer, setting the "MedianHomeValue" column as the label of the dataset | ||
// and the "Features" column produced by concatenation as the features of the dataset. | ||
var labelName = "MedianHomeValue"; | ||
var pipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", | ||
"PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s", | ||
"EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio") | ||
.Append(mlContext.Transforms.Normalize("Features")) | ||
.Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent( | ||
labelColumn: labelName, featureColumn: "Features")); | ||
var model = pipeline.Fit(data); | ||
|
||
// Extract the model from the pipeline | ||
var linearPredictor = model.LastTransformer; | ||
var weights = GetLinearModelWeights(linearPredictor.Model); | ||
|
||
// Compute the permutation metrics on the featurized data, since the predictor expects the "Features" column. | ||
var transformedData = model.Transform(data); | ||
var permutationMetrics = mlContext.Regression.PermutationFeatureImportance( | ||
linearPredictor, transformedData, label: labelName, features: "Features"); | ||
|
||
// Now let's look at which features are most important to the model overall. | ||
// First, we have to prepare the data: | ||
// Get the feature names as an array, in the same order in which they were concatenated into "Features". | ||
var featureNames = data.Schema.GetColumns() | ||
.Select(tuple => tuple.column.Name) // Get the column names | ||
.Where(name => name != labelName) // Drop the Label | ||
.ToArray(); | ||
|
||
// Get the feature indices sorted by their impact on R-Squared | ||
var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared }) | ||
.OrderByDescending(feature => Math.Abs(feature.RSquared)) | ||
.Select(feature => feature.index); | ||
|
||
// Print out the permutation results, with the model weights, in order of their impact: | ||
// Expected console output: | ||
// Feature Model Weight Change in R-Squared | ||
// RoomsPerDwelling 50.80 -0.3695 | ||
// EmploymentDistance -17.79 -0.2238 | ||
// TeacherRatio -19.83 -0.1228 | ||
// TaxRate -8.60 -0.1042 | ||
// NitricOxides -15.95 -0.1025 | ||
// HighwayDistance 5.37 -0.09345 | ||
// CrimesPerCapita -15.05 -0.05797 | ||
// PercentPre40s -4.64 -0.0385 | ||
// PercentResidental 3.98 -0.02184 | ||
// CharlesRiver 3.38 -0.01487 | ||
// PercentNonRetail -1.94 -0.007231 | ||
// | ||
// Let's dig into these results a little bit. First, the weights of the model generally correlate with the | ||
// results of PFI, but there are some significant misorderings. For example, "TaxRate" has a smaller weight | ||
// (in absolute value) than "NitricOxides" and "CrimesPerCapita", yet the permutation analysis shows that it has | ||
// a larger effect on the accuracy of the model. To understand why, we need to go back to the basics of linear | ||
// models: reading the weights as feature importance is only reliable when the features are uncorrelated. | ||
// The features in this dataset are clearly correlated: the tax rate for a house and the student-to-teacher | ||
// ratio at the nearest school, for example, are often coupled through school levies. The tax rate, the presence | ||
// of pollution (e.g. nitric oxides), and the crime rate also seem to be correlated with each other through | ||
// social dynamics. We could draw out similar relationships for all the variables in this dataset. When features | ||
// are correlated, the solution to the linear model redistributes weight between them in unpredictable ways, | ||
// so the weights themselves are no longer a good measure of feature importance. | ||
Console.WriteLine("Feature\tModel Weight\tChange in R-Squared"); | ||
var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array | ||
foreach (int i in sortedIndices) | ||
{ | ||
Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i]:G4}"); | ||
} | ||
} | ||
|
||
private static float[] GetLinearModelWeights(LinearRegressionPredictor linearModel) | ||
{ | ||
var weights = new VBuffer<float>(); | ||
linearModel.GetFeatureWeights(ref weights); | ||
return weights.GetValues().ToArray(); | ||
} | ||
} | ||
} |
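
For readers who want to see what a "change in R-squared" under permutation means, here is a minimal, library-free sketch of the idea behind the sample. It is not how `mlContext.Regression.PermutationFeatureImportance` is implemented; the class `PfiSketch`, its method names, and the fixed linear scorer standing in for a trained model are all hypothetical and exist only for illustration. For each feature, the sketch shuffles that one column across rows, rescores the data, and reports the permuted R-squared minus the baseline R-squared.

```csharp
using System;
using System.Linq;

// Hypothetical illustration only; not part of ML.NET.
public static class PfiSketch
{
    // Coefficient of determination: 1 - SS_res / SS_tot.
    public static double RSquared(double[] labels, double[] predictions)
    {
        double mean = labels.Average();
        double ssRes = labels.Zip(predictions, (y, p) => (y - p) * (y - p)).Sum();
        double ssTot = labels.Sum(y => (y - mean) * (y - mean));
        return 1.0 - ssRes / ssTot;
    }

    // Stand-in for a trained model: a fixed linear scorer (assumed weights and bias).
    public static double[] Predict(double[][] rows, double[] weights, double bias)
    {
        return rows.Select(r => bias + r.Zip(weights, (x, w) => x * w).Sum()).ToArray();
    }

    // For each feature j, returns R^2(with column j permuted) - R^2(baseline).
    public static double[] PermutationImportance(
        double[][] rows, double[] labels, double[] weights, double bias, int seed = 0)
    {
        var rng = new Random(seed);
        double baseline = RSquared(labels, Predict(rows, weights, bias));
        var deltas = new double[weights.Length];

        for (int j = 0; j < weights.Length; j++)
        {
            // Fisher-Yates shuffle of the row indices used to fill column j.
            int[] order = Enumerable.Range(0, rows.Length).ToArray();
            for (int i = order.Length - 1; i > 0; i--)
            {
                int k = rng.Next(i + 1);
                int tmp = order[i];
                order[i] = order[k];
                order[k] = tmp;
            }

            // Copy every row, replacing only column j with its shuffled value.
            int column = j;
            double[][] permuted = rows.Select((r, i) =>
            {
                var copy = (double[])r.Clone();
                copy[column] = rows[order[i]][column];
                return copy;
            }).ToArray();

            deltas[j] = RSquared(labels, Predict(permuted, weights, bias)) - baseline;
        }
        return deltas;
    }
}
```

With this sign convention, a negative value means the score got worse when the feature was scrambled, which matches the negative numbers in the sample's expected output. A feature the scorer ignores (weight 0) yields a change of exactly zero, because permuting it leaves every prediction unchanged.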