Skip to content

Commit e1a5d3a

Browse files
author
Rogan Carr
committed
Adding an example for using PFI
1 parent 78cad14 commit e1a5d3a

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,97 @@
1+
using Microsoft.ML.Runtime.Data;
2+
using Microsoft.ML.Runtime.Learners;
3+
using System;
4+
using System.Linq;
5+
6+
namespace Microsoft.ML.Samples.Dynamic
7+
{
8+
/// <summary>
/// Sample showing how to compute permutation feature importance (PFI) for a
/// linear regression model trained on the housing regression dataset.
/// </summary>
public class PFI_RegressionExample
{
    /// <summary>
    /// Loads the housing dataset, trains an SDCA linear regressor on it, computes
    /// permutation feature importance, and prints each feature with its model
    /// weight and its change in R-squared, ordered by the magnitude of that change.
    /// </summary>
    public static void PFI_Regression()
    {
        // Download the dataset from github.com/dotnet/machinelearning.
        // This will create a housing.txt file in the filesystem.
        // You can open this file to see the data.
        string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset();

        // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
        // as a catalog of available operations and as the source of randomness.
        var mlContext = new MLContext();

        // Step 1: Read the data as an IDataView.
        // First, we define the reader: specify the data columns and where to find them in the text file.
        var reader = mlContext.Data.TextReader(new TextLoader.Arguments()
        {
            Separator = "tab",
            HasHeader = true,
            Column = new[]
            {
                new TextLoader.Column("MedianHomeValue", DataKind.R4, 0),
                new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1),
                new TextLoader.Column("PercentResidental", DataKind.R4, 2),
                new TextLoader.Column("PercentNonRetail", DataKind.R4, 3),
                new TextLoader.Column("CharlesRiver", DataKind.R4, 4),
                new TextLoader.Column("NitricOxides", DataKind.R4, 5),
                new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6),
                new TextLoader.Column("PercentPre40s", DataKind.R4, 7),
                new TextLoader.Column("EmploymentDistance", DataKind.R4, 8),
                new TextLoader.Column("HighwayDistance", DataKind.R4, 9),
                new TextLoader.Column("TaxRate", DataKind.R4, 10),
                new TextLoader.Column("TeacherRatio", DataKind.R4, 11),
            }
        });

        // Read the data
        var data = reader.Read(dataFile);

        // Step 2: Pipeline
        // Concatenate the features to create a Feature vector.
        // Then append an SDCA (stochastic dual coordinate ascent) linear regressor, setting the
        // "MedianHomeValue" column as the label of the dataset, and the "Features" column
        // produced by concatenation as the features column.
        var labelName = "MedianHomeValue";
        var pipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental",
                "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s",
                "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
            .Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent(
                labelColumn: labelName, featureColumn: "Features"));
        var fitPipeline = pipeline.Fit(data);

        // Extract the model from the pipeline
        var linearPredictor = fitPipeline.LastTransformer;
        var weights = GetLinearModelWeights(linearPredictor.Model);

        // Compute the permutation metrics using the properly-featurized data.
        var transformedData = fitPipeline.Transform(data);
        var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(
            linearPredictor, transformedData, label: labelName, features: "Features");

        // Now let's look at which features are most important to the model overall
        // First, we have to prepare the data:
        // Get the feature names as an IEnumerable.
        // The names are taken from the untransformed data's schema, so position i here lines up
        // with weights[i] and rSquared[i] only because the loader's column order matches the
        // order in which the columns are passed to Concatenate above.
        var featureNames = data.Schema.GetColumns()
            .Select(tuple => tuple.column.Name) // Get the column names
            .Where(name => name != labelName) // Drop the Label
            .ToArray();

        // Get the feature indices sorted by the magnitude of their impact on R-Squared
        var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared })
            .OrderByDescending(feature => Math.Abs(feature.RSquared))
            .Select(feature => feature.index);

        // Print out the permutation results, with the model weights, in order of their impact
        Console.WriteLine("Feature\tModel Weight\tChange in R-Squared");
        var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array
        foreach (int i in sortedIndices)
        {
            Console.WriteLine("{0}\t{1:0.00}\t{2:G4}", featureNames[i], weights[i], rSquared[i]);
        }
    }

    /// <summary>
    /// Copies the feature weights out of a linear regression predictor into a dense array.
    /// </summary>
    /// <param name="linearModel">The trained linear model to read the weights from.</param>
    /// <returns>An array with one weight per feature slot, in feature order.</returns>
    private static float[] GetLinearModelWeights(LinearRegressionPredictor linearModel)
    {
        var weights = new VBuffer<float>();
        linearModel.GetFeatureWeights(ref weights);

        // DenseValues() yields one entry per feature slot even when the buffer is stored
        // sparsely; GetValues() would return only the explicitly stored entries, which
        // breaks indexing the result by feature position.
        return weights.DenseValues().ToArray();
    }
}
97+
}

0 commit comments

Comments
 (0)