Confidence Intervals for Permutation Feature Importance #1844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 10 commits · Dec 20, 2018
@@ -1,5 +1,6 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
+ using Microsoft.ML.Trainers.HalLearners;
using System;
using System.Linq;

@@ -57,18 +58,18 @@ public static void PFI_Regression()
"PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s",
"EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
.Append(mlContext.Transforms.Normalize("Features"))
- .Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent(
+ .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares(
labelColumn: labelName, featureColumn: "Features"));
- var model = pipeline.Fit(data);

+ var model = pipeline.Fit(data);
// Extract the model from the pipeline
var linearPredictor = model.LastTransformer;
var weights = GetLinearModelWeights(linearPredictor.Model);

// Compute the permutation metrics using the properly-featurized data.
var transformedData = model.Transform(data);
var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(
- linearPredictor, transformedData, label: labelName, features: "Features");
+ linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3);

// Now let's look at which features are most important to the model overall
// First, we have to prepare the data:
@@ -80,49 +81,47 @@ public static void PFI_Regression()

// Get the feature indices sorted by their impact on R-Squared
var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared })
- .OrderByDescending(feature => Math.Abs(feature.RSquared))
+ .OrderByDescending(feature => Math.Abs(feature.RSquared.Mean))
.Select(feature => feature.index);

// Print out the permutation results, with the model weights, in order of their impact:
- // Expected console output:
- //  Feature             Model Weight    Change in R-Squared
- //  RoomsPerDwelling    50.80           -0.3695
- //  EmploymentDistance  -17.79          -0.2238
- //  TeacherRatio        -19.83          -0.1228
- //  TaxRate             -8.60           -0.1042
- //  NitricOxides        -15.95          -0.1025
- //  HighwayDistance     5.37            -0.09345
- //  CrimesPerCapita     -15.05          -0.05797
- //  PercentPre40s       -4.64           -0.0385
- //  PercentResidental   3.98            -0.02184
- //  CharlesRiver        3.38            -0.01487
- //  PercentNonRetail    -1.94           -0.007231
+ // Expected console output for 100 permutations:
+ //  Feature             Model Weight    Change in R-Squared    Standard Error of the Mean Change in R-Squared
+ //  RoomsPerDwelling    53.35           -0.4294                0.003252
+ //  EmploymentDistance  -19.21          -0.2666                0.001997
+ //  NitricOxides        -19.32          -0.1543                0.001559
+ //  HighwayDistance     6.11            -0.118                 0.001168
+ //  TeacherRatio        -21.92          -0.1079                0.001392
+ //  TaxRate             -8.68           -0.1004                0.001191
+ //  CrimesPerCapita     -16.37          -0.05994               0.001023
+ //  PercentPre40s       -4.52           -0.0375                0.0007154
+ //  PercentResidental   3.91            -0.01961               0.000492
+ //  CharlesRiver        3.49            -0.01845               0.0004909
+ //  PercentNonRetail    -1.17           -0.001916              0.0001628
+ //
// Let's dig into these results a little bit. First, if you look at the weights of the model, they generally correlate
- // with the results of PFI, but there are some significant misorderings. For example, "Tax Rate" is weighted lower than
- // "Nitric Oxides" and "Crimes Per Capita", but the permutation analysis shows this feature to have a larger effect
- // on the accuracy of the model even though it has a relatively small weight. To understand why the weights don't
- // reflect the same feature importance as PFI, we need to go back to the basics of linear models: one of the
- // assumptions of a linear model is that the features are uncorrelated. Now, the features in this dataset are clearly
- // correlated: the tax rate for a house and the student-to-teacher ratio at the nearest school, for example, are often
- // coupled through school levies. The tax rate, presence of pollution (e.g. nitric oxides), and the crime rate would also
- // seem to be correlated with each other through social dynamics. We could draw out similar relationships for all the
- // variables in this dataset. The reason why the linear model weights don't reflect the same feature importance as PFI
- // is that the solution to the linear model redistributes weights between correlated variables in unpredictable ways, so
- // that the weights themselves are no longer a good measure of feature importance.
- Console.WriteLine("Feature\tModel Weight\tChange in R-Squared");
+ // with the results of PFI, but there are some significant misorderings. For example, "Tax Rate" and "Highway Distance"
+ // have relatively small model weights, but the permutation analysis shows these features to have a larger effect
+ // on the accuracy of the model than higher-weighted features. To understand why the weights don't reflect the same
+ // feature importance as PFI, we need to go back to the basics of linear models: one of the assumptions of a linear
+ // model is that the features are uncorrelated. Now, the features in this dataset are clearly correlated: the tax rate
+ // for a house and the student-to-teacher ratio at the nearest school, for example, are often coupled through school
+ // levies. The tax rate, distance to a highway, and the crime rate would also seem to be correlated through social
+ // dynamics. We could draw out similar relationships for all variables in this dataset. The reason why the linear
+ // model weights don't reflect the same feature importance as PFI is that the solution to the linear model redistributes
+ // weights between correlated variables in unpredictable ways, so that the weights themselves are no longer a good
+ // measure of feature importance.
+ Console.WriteLine("Feature\tModel Weight\tChange in R-Squared\tStandard Error of the Mean Change in R-Squared");
var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array
foreach (int i in sortedIndices)
{
Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i]:G4}");
Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i].Mean:G4}\t{rSquared[i].StandardError:G4}");
}
}

- private static float[] GetLinearModelWeights(LinearRegressionPredictor linearModel)
+ private static float[] GetLinearModelWeights(OlsLinearRegressionPredictor linearModel)
{
- var weights = new VBuffer<float>();
- linearModel.GetFeatureWeights(ref weights);
- return weights.GetValues().ToArray();
+ return linearModel.Weights2.ToArray();
}
}
}
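A note on using the output of this sample: after this change, PFI reports per-feature summary statistics (a mean and a standard error over `permutationCount` evaluations) rather than a single delta. Below is a minimal sketch of how a caller might turn those two numbers into a 95% confidence interval. It assumes the per-metric statistics expose `Mean` and `StandardError` as used in the sample above, and that the per-permutation deltas are roughly normal; `PfiIntervals` is a hypothetical helper, not part of ML.NET.

```csharp
public static class PfiIntervals
{
    // Hypothetical helper (not part of ML.NET): builds a 95% confidence
    // interval from the Mean and StandardError that PFI reports per feature.
    // The 1.96 multiplier assumes the per-permutation deltas are roughly normal.
    public static (double Lower, double Upper) ConfidenceInterval95(double mean, double standardError)
    {
        double halfWidth = 1.96 * standardError;
        return (mean - halfWidth, mean + halfWidth);
    }
}
```

Inside the sample's print loop this could be used as, for example, `var (low, high) = PfiIntervals.ConfidenceInterval95(rSquared[i].Mean, rSquared[i].StandardError);` to report an interval alongside each mean change in R-Squared.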
src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs — 44 additions, 31 deletions
@@ -15,21 +15,22 @@

namespace Microsoft.ML.Transforms
{
- internal static class PermutationFeatureImportance<TResult>
+ internal static class PermutationFeatureImportance<TMetric, TResult> where TResult : MetricsStatisticsBase<TMetric>, new()
{
public static ImmutableArray<TResult>
GetImportanceMetricsMatrix(
IHostEnvironment env,
IPredictionTransformer<IPredictor> model,
IDataView data,
- Func<IDataView, TResult> evaluationFunc,
- Func<TResult, TResult, TResult> deltaFunc,
+ Func<IDataView, TMetric> evaluationFunc,
+ Func<TMetric, TMetric, TMetric> deltaFunc,
string features,
+ int permutationCount,
bool useFeatureWeightFilter = false,
int? topExamples = null)
{
Contracts.CheckValue(env, nameof(env));
- var host = env.Register(nameof(PermutationFeatureImportance<TResult>));
+ var host = env.Register(nameof(PermutationFeatureImportance<TMetric, TResult>));
host.CheckValue(model, nameof(model));
host.CheckValue(data, nameof(data));
host.CheckNonEmpty(features, nameof(features));
@@ -168,7 +169,7 @@ public static ImmutableArray<TResult>
// Now iterate through all the working slots, do permutation and calc the delta of metrics.
int processedCnt = 0;
int nextFeatureIndex = 0;
- int shuffleSeed = host.Rand.Next();
+ var shuffleRand = RandomUtils.Create(host.Rand.Next());
using (var pch = host.StartProgressChannel("SDCA preprocessing with lookup"))
{
pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt));
@@ ... @@
if (processedCnt < workingFeatureIndices.Count - 1)
nextFeatureIndex = workingFeatureIndices[processedCnt + 1];

// Used for pre-caching the next feature
int nextValuesIndex = 0;

- Utils.Shuffle<float>(RandomUtils.Create(shuffleSeed), featureValuesBuffer);

- Action<FeaturesBuffer, FeaturesBuffer, PermuterState> permuter =
- (src, dst, state) =>
- {
- src.Features.CopyTo(ref dst.Features);
- VBufferUtils.ApplyAt(ref dst.Features, workingIndx,
- (int ii, ref float d) =>
- d = featureValuesBuffer[state.SampleIndex++]);
-
- if (processedCnt < workingFeatureIndices.Count - 1)
- {
- // This is the reason I need PermuterState in LambdaTransform.CreateMap.
- nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex);
- if (nextValuesIndex < valuesRowCount - 1)
- nextValuesIndex++;
- }
- };

SchemaDefinition input = SchemaDefinition.Create(typeof(FeaturesBuffer));
Contracts.Assert(input.Count == 1);
input[0].ColumnName = features;
@@ -208,15 +191,45 @@
output[0].ColumnName = features;
output[0].ColumnType = featuresColumn.Type;

- IDataView viewPermuted = LambdaTransform.CreateMap(
- host, data, permuter, null, input, output);
- if (valuesRowCount == topExamples)
- viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeArguments() { Count = valuesRowCount }, viewPermuted);
+ // Perform multiple permutations for one feature to build a confidence interval
+ var metricsDeltaForFeature = new TResult();
+ for (int permutationIteration = 0; permutationIteration < permutationCount; permutationIteration++)
+ {
+ Utils.Shuffle<float>(shuffleRand, featureValuesBuffer);

- var metrics = evaluationFunc(model.Transform(viewPermuted));
+ Action<FeaturesBuffer, FeaturesBuffer, PermuterState> permuter =
+ (src, dst, state) =>
+ {
+ src.Features.CopyTo(ref dst.Features);
+ VBufferUtils.ApplyAt(ref dst.Features, workingIndx,
+ (int ii, ref float d) =>
+ d = featureValuesBuffer[state.SampleIndex++]);
+
+ // Is it time to pre-cache the next feature?
+ if (permutationIteration == permutationCount - 1 &&
+ processedCnt < workingFeatureIndices.Count - 1)
+ {
+ // Fill out the featureValuesBuffer for the next feature while updating the current feature.
+ // This is the reason I need PermuterState in LambdaTransform.CreateMap.
+ nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex);
+ if (nextValuesIndex < valuesRowCount - 1)
+ nextValuesIndex++;
+ }
+ };

+ IDataView viewPermuted = LambdaTransform.CreateMap(
+ host, data, permuter, null, input, output);
+ if (valuesRowCount == topExamples)
+ viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeArguments() { Count = valuesRowCount }, viewPermuted);

+ var metrics = evaluationFunc(model.Transform(viewPermuted));

+ var delta = deltaFunc(metrics, baselineMetrics);
+ metricsDeltaForFeature.Add(delta);
+ }

- var delta = deltaFunc(metrics, baselineMetrics);
- metricsDelta.Add(delta);
+ // Add the metrics delta to the list
+ metricsDelta.Add(metricsDeltaForFeature);

// Swap values for next iteration of permutation.
Array.Clear(featureValuesBuffer, 0, featureValuesBuffer.Length);
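For context on the loop above: each permutation's metric delta is folded into `metricsDeltaForFeature` via `MetricsStatisticsBase<TMetric>.Add`, whose implementation is not shown in this diff. The sketch below illustrates the kind of running-statistics accumulator that call implies, using Welford's online algorithm; the class is illustrative only, not the actual ML.NET type.

```csharp
using System;

// Illustrative accumulator in the spirit of MetricsStatisticsBase<TMetric>.Add:
// builds a mean and a standard error of the mean incrementally across the
// permutationCount iterations, without storing each delta.
public sealed class RunningMetricStatistics
{
    private int _count;
    private double _mean;
    private double _m2; // Sum of squared deviations from the running mean.

    public void Add(double value)
    {
        _count++;
        double delta = value - _mean;
        _mean += delta / _count;
        _m2 += delta * (value - _mean);
    }

    public double Mean => _mean;

    // Standard error of the mean: sample standard deviation / sqrt(n).
    public double StandardError =>
        _count > 1 ? Math.Sqrt(_m2 / (_count - 1)) / Math.Sqrt(_count) : 0;
}
```

With an accumulator of this shape, the per-feature loop above reduces to one `Add` per permutation, and the reported `Mean` and `StandardError` are exactly the quantities printed in the regression sample's expected output.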