Skip to content

Commit 512493a

Browse files
authored
Adding functional tests for explainability (#2584)
* Adding functional tests for explainability
1 parent 01a362b commit 512493a

File tree

6 files changed

+387
-5
lines changed

6 files changed

+387
-5
lines changed

test/Microsoft.ML.Functional.Tests/Common.cs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,38 @@ public static void AssertEqual(TypeTestData testType1, TypeTestData testType2)
164164
        /// <summary>
        /// Check that a <see cref="RegressionMetrics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetrics(RegressionMetrics metrics)
        {
            // Perform sanity checks on the metrics: the error metrics (Rms, L1, L2)
            // must be non-negative, and RSquared can never exceed 1.
            Assert.True(metrics.Rms >= 0);
            Assert.True(metrics.L1 >= 0);
            Assert.True(metrics.L2 >= 0);
            Assert.True(metrics.RSquared <= 1);
        }
175+
176+
        /// <summary>
        /// Check that a <see cref="MetricStatistics"/> object is valid.
        /// </summary>
        /// <param name="metric">The <see cref="MetricStatistics"/> object.</param>
        public static void AssertMetricStatistics(MetricStatistics metric)
        {
            // Perform sanity checks on the metric: standard deviation and
            // standard error must both be non-negative.
            Assert.True(metric.StandardDeviation >= 0);
            Assert.True(metric.StandardError >= 0);
        }
186+
187+
        /// <summary>
        /// Check that a <see cref="RegressionMetricsStatistics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetricsStatistics(RegressionMetricsStatistics metrics)
        {
            // The mean can be any float; the standard deviation and error must be >= 0,
            // which AssertMetricStatistics verifies for each summarized metric.
            AssertMetricStatistics(metrics.Rms);
            AssertMetricStatistics(metrics.L1);
            AssertMetricStatistics(metrics.L2);
            AssertMetricStatistics(metrics.RSquared);
            AssertMetricStatistics(metrics.LossFn);
        }
175200
}
176201
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Data;
6+
7+
namespace Microsoft.ML.Functional.Tests.Datasets
8+
{
9+
    /// <summary>
    /// A class to hold the output of FeatureContributionCalculator.
    /// </summary>
    internal sealed class FeatureContributionOutput
    {
        // Per-row feature contribution scores; the tests assert the length
        // matches the number of input features.
        public float[] FeatureContributions { get; set; }
    }
16+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Data;
6+
7+
namespace Microsoft.ML.Functional.Tests.Datasets
8+
{
9+
    /// <summary>
    /// A schematized class for loading the HousingRegression dataset.
    /// </summary>
    internal sealed class HousingRegression
    {
        // Column 0 is the regression target, exposed under the column name "Label".
        [LoadColumn(0), ColumnName("Label")]
        public float MedianHomeValue { get; set; }

        [LoadColumn(1)]
        public float CrimesPerCapita { get; set; }

        // NOTE(review): "PercentResidental" (sic) — the misspelling is part of the
        // column name referenced by the Features array below; renaming it would
        // change the schema seen by the tests.
        [LoadColumn(2)]
        public float PercentResidental { get; set; }

        [LoadColumn(3)]
        public float PercentNonRetail { get; set; }

        [LoadColumn(4)]
        public float CharlesRiver { get; set; }

        [LoadColumn(5)]
        public float NitricOxides { get; set; }

        [LoadColumn(6)]
        public float RoomsPerDwelling { get; set; }

        [LoadColumn(7)]
        public float PercentPre40s { get; set; }

        [LoadColumn(8)]
        public float EmploymentDistance { get; set; }

        [LoadColumn(9)]
        public float HighwayDistance { get; set; }

        [LoadColumn(10)]
        public float TaxRate { get; set; }

        [LoadColumn(11)]
        public float TeacherRatio { get; set; }

        /// <summary>
        /// The list of columns commonly used as features.
        /// </summary>
        public static readonly string[] Features = new string[] {"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
            "RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"};
    }
56+
}
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Data.DataView;
6+
using Microsoft.ML.Data;
7+
using Microsoft.ML.Functional.Tests.Datasets;
8+
using Microsoft.ML.RunTests;
9+
using Microsoft.ML.TestFramework;
10+
using Microsoft.ML.Trainers;
11+
using Microsoft.ML.Trainers.FastTree;
12+
using Xunit;
13+
using Xunit.Abstractions;
14+
15+
namespace Microsoft.ML.Functional.Tests
16+
{
17+
    /// <summary>
    /// Test explainability features.
    /// </summary>
    public class Explainability : BaseTestClass
    {
        // Passes the xUnit output helper through to the shared test base class.
        public Explainability(ITestOutputHelper output) : base(output)
        {
        }
25+
26+
/// <summary>
27+
/// GlobalFeatureImportance: PFI can be used to compute global feature importance.
28+
/// </summary>
29+
[Fact]
30+
public void GlobalFeatureImportanceWithPermutationFeatureImportance()
31+
{
32+
var mlContext = new MLContext(seed: 1, conc: 1);
33+
34+
// Get the dataset
35+
var data = mlContext.Data.ReadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
36+
37+
// Create a pipeline to train on the housing data.
38+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
39+
.Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent());
40+
41+
// Fit the pipeline and transform the data.
42+
var model = pipeline.Fit(data);
43+
var transformedData = model.Transform(data);
44+
45+
// Compute the permutation feature importance to look at global feature importance.
46+
var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(model.LastTransformer, transformedData);
47+
48+
// Make sure the correct number of features came back.
49+
Assert.Equal(HousingRegression.Features.Length, permutationMetrics.Length);
50+
foreach (var metricsStatistics in permutationMetrics)
51+
Common.AssertMetricsStatistics(metricsStatistics);
52+
}
53+
54+
/// <summary>
55+
/// GlobalFeatureImportance: A linear model's feature importance can be viewed through its weight coefficients.
56+
/// </summary>
57+
/// <remarks>
58+
/// Note that this isn't recommended, as there are quite a few statistical issues with interpreting coefficients
59+
/// as weights, but it's common practice, so it's a supported scenario.
60+
/// </remarks>
61+
[Fact]
62+
public void GlobalFeatureImportanceForLinearModelThroughWeights()
63+
{
64+
var mlContext = new MLContext(seed: 1, conc: 1);
65+
66+
// Get the dataset.
67+
var data = mlContext.Data.ReadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
68+
69+
// Create a pipeline to train on the housing data.
70+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
71+
.Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent());
72+
73+
// Fit the pipeline and transform the data.
74+
var model = pipeline.Fit(data);
75+
var linearModel = model.LastTransformer.Model;
76+
77+
// Make sure the number of model weights returned matches the length of the input feature vector.
78+
var weights = linearModel.Weights;
79+
Assert.Equal(HousingRegression.Features.Length, weights.Count);
80+
}
81+
82+
/// <summary>
83+
/// GlobalFeatureImportance: A FastTree model can give back global feature importance through feature gain.
84+
/// </summary>
85+
[Fact]
86+
public void GlobalFeatureImportanceForFastTreeThroughFeatureGain()
87+
{
88+
var mlContext = new MLContext(seed: 1, conc: 1);
89+
90+
// Get the dataset
91+
var data = mlContext.Data.ReadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
92+
93+
// Create a pipeline to train on the housing data.
94+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
95+
.Append(mlContext.Regression.Trainers.FastTree());
96+
97+
// Fit the pipeline and transform the data.
98+
var model = pipeline.Fit(data);
99+
var treeModel = model.LastTransformer.Model;
100+
101+
// Get the feature gain.
102+
var weights = new VBuffer<float>();
103+
treeModel.GetFeatureWeights(ref weights);
104+
105+
// Make sure the number of feature gains returned matches the length of the input feature vector.
106+
Assert.Equal(HousingRegression.Features.Length, weights.Length);
107+
}
108+
109+
/// <summary>
110+
/// GlobalFeatureImportance: A FastForest model can give back global feature importance through feature gain.
111+
/// </summary>
112+
[Fact]
113+
public void GlobalFeatureImportanceForFastForestThroughFeatureGain()
114+
{
115+
var mlContext = new MLContext(seed: 1, conc: 1);
116+
117+
// Get the dataset
118+
var data = mlContext.Data.ReadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
119+
120+
// Create a pipeline to train on the housing data.
121+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
122+
.Append(mlContext.Regression.Trainers.FastForest());
123+
124+
// Fit the pipeline and transform the data.
125+
var model = pipeline.Fit(data);
126+
var treeModel = model.LastTransformer.Model;
127+
128+
// Get the feature gain
129+
var weights = new VBuffer<float>();
130+
treeModel.GetFeatureWeights(ref weights);
131+
132+
// Make sure the number of feature gains returned matches the length of the input feature vector.
133+
Assert.Equal(HousingRegression.Features.Length, weights.Length);
134+
}
135+
136+
/// <summary>
137+
/// LocalFeatureImportance: Per-row feature importance can be computed through FeatureContributionCalculator for a linear model.
138+
/// </summary>
139+
[Fact]
140+
public void LocalFeatureImportanceForLinearModel()
141+
{
142+
var mlContext = new MLContext(seed: 1, conc: 1);
143+
144+
// Get the dataset
145+
var data = mlContext.Data.ReadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
146+
147+
// Create a pipeline to train on the housing data.
148+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
149+
.Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent());
150+
151+
// Fit the pipeline and transform the data.
152+
var model = pipeline.Fit(data);
153+
var scoredData = model.Transform(data);
154+
155+
// Create a Feature Contribution Calculator.
156+
var predictor = model.LastTransformer;
157+
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumn, normalize: false);
158+
159+
// Compute the contributions
160+
var outputData = featureContributions.Fit(scoredData).Transform(scoredData);
161+
162+
// Validate that the contributions are there
163+
var shuffledSubset = mlContext.Data.TakeRows(mlContext.Data.ShuffleRows(outputData), 10);
164+
var scoringEnumerator = mlContext.CreateEnumerable<FeatureContributionOutput>(shuffledSubset, true);
165+
166+
// Make sure the number of feature contributions returned matches the length of the input feature vector.
167+
foreach (var row in scoringEnumerator)
168+
{
169+
Assert.Equal(HousingRegression.Features.Length, row.FeatureContributions.Length);
170+
}
171+
}
172+
173+
/// <summary>
174+
/// LocalFeatureImportance: Per-row feature importance can be computed through FeatureContributionCalculator for a <see cref="FastTree"/> model.
175+
/// </summary>
176+
[Fact]
177+
public void LocalFeatureImportanceForFastTreeModel()
178+
{
179+
var mlContext = new MLContext(seed: 1, conc: 1);
180+
181+
// Get the dataset
182+
var data = mlContext.Data.ReadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
183+
184+
// Create a pipeline to train on the housing data.
185+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
186+
.Append(mlContext.Regression.Trainers.FastTree());
187+
188+
// Fit the pipeline and transform the data.
189+
var model = pipeline.Fit(data);
190+
var scoredData = model.Transform(data);
191+
192+
// Create a Feature Contribution Calculator.
193+
var predictor = model.LastTransformer;
194+
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumn, normalize: false);
195+
196+
// Compute the contributions
197+
var outputData = featureContributions.Fit(scoredData).Transform(scoredData);
198+
199+
// Validate that the contributions are there
200+
var shuffledSubset = mlContext.Data.TakeRows(mlContext.Data.ShuffleRows(outputData), 10);
201+
var scoringEnumerator = mlContext.CreateEnumerable<FeatureContributionOutput>(shuffledSubset, true);
202+
203+
// Make sure the number of feature contributions returned matches the length of the input feature vector.
204+
foreach (var row in scoringEnumerator)
205+
{
206+
Assert.Equal(HousingRegression.Features.Length, row.FeatureContributions.Length);
207+
}
208+
}
209+
210+
/// <summary>
211+
/// LocalFeatureImportance: Per-row feature importance can be computed through FeatureContributionCalculator for a <see cref="FastForest"/>model.
212+
/// </summary>
213+
[Fact]
214+
public void LocalFeatureImportanceForFastForestModel()
215+
{
216+
var mlContext = new MLContext(seed: 1, conc: 1);
217+
218+
// Get the dataset
219+
var data = mlContext.Data.ReadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
220+
221+
// Create a pipeline to train on the housing data.
222+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
223+
.Append(mlContext.Regression.Trainers.FastForest());
224+
225+
// Fit the pipeline and transform the data.
226+
var model = pipeline.Fit(data);
227+
var scoredData = model.Transform(data);
228+
229+
// Create a Feature Contribution Calculator.
230+
var predictor = model.LastTransformer;
231+
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumn, normalize: false);
232+
233+
// Compute the contributions
234+
var outputData = featureContributions.Fit(scoredData).Transform(scoredData);
235+
236+
// Validate that the contributions are there
237+
var shuffledSubset = mlContext.Data.TakeRows(mlContext.Data.ShuffleRows(outputData), 10);
238+
var scoringEnumerator = mlContext.CreateEnumerable<FeatureContributionOutput>(shuffledSubset, true);
239+
240+
// Make sure the number of feature contributions returned matches the length of the input feature vector.
241+
foreach (var row in scoringEnumerator)
242+
{
243+
Assert.Equal(HousingRegression.Features.Length, row.FeatureContributions.Length);
244+
}
245+
}
246+
247+
/// <summary>
248+
/// LocalFeatureImportance: Per-row feature importance can be computed through FeatureContributionCalculator for a <see cref="GamModelParametersBase" />
249+
/// (Generalized Additive Model) model.
250+
/// </summary>
251+
[Fact]
252+
public void LocalFeatureImportanceForGamModel()
253+
{
254+
var mlContext = new MLContext(seed: 1, conc: 1);
255+
256+
// Get the dataset
257+
var data = mlContext.Data.ReadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
258+
259+
// Create a pipeline to train on the housing data.
260+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
261+
.Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels(numIterations: 2));
262+
263+
// Fit the pipeline and transform the data.
264+
var model = pipeline.Fit(data);
265+
var scoredData = model.Transform(data);
266+
267+
// Create a Feature Contribution Calculator.
268+
var predictor = model.LastTransformer;
269+
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumn, normalize: false);
270+
271+
// Compute the contributions
272+
var outputData = featureContributions.Fit(scoredData).Transform(scoredData);
273+
274+
// Validate that the contributions are there
275+
var shuffledSubset = mlContext.Data.TakeRows(mlContext.Data.ShuffleRows(outputData), 10);
276+
var scoringEnumerator = mlContext.CreateEnumerable<FeatureContributionOutput>(shuffledSubset, true);
277+
278+
// Make sure the number of feature contributions returned matches the length of the input feature vector.
279+
foreach (var row in scoringEnumerator)
280+
{
281+
Assert.Equal(HousingRegression.Features.Length, row.FeatureContributions.Length);
282+
}
283+
}
284+
}
285+
}

0 commit comments

Comments
 (0)