diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs index 8bf8c6a291..4d98bb87ab 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs @@ -1,5 +1,6 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Trainers.HalLearners; using System; using System.Linq; @@ -55,10 +56,10 @@ public static void PFI_Regression() "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio") .Append(mlContext.Transforms.Normalize("Features")) - .Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent( + .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares( labelColumn: labelName, featureColumn: "Features")); - var model = pipeline.Fit(data); + var model = pipeline.Fit(data); // Extract the model from the pipeline var linearPredictor = model.LastTransformer; var weights = GetLinearModelWeights(linearPredictor.Model); @@ -66,7 +67,7 @@ public static void PFI_Regression() // Compute the permutation metrics using the properly-featurized data. var transformedData = model.Transform(data); var permutationMetrics = mlContext.Regression.PermutationFeatureImportance( - linearPredictor, transformedData, label: labelName, features: "Features"); + linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3); // Now let's look at which features are most important to the model overall // First, we have to prepare the data: @@ -78,49 +79,47 @@ public static void PFI_Regression() // Get the feature indices sorted by their impact on R-Squared var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared }) - .OrderByDescending(feature => Math.Abs(feature.RSquared)) + .OrderByDescending(feature => Math.Abs(feature.RSquared.Mean)) .Select(feature => feature.index); // Print out the permutation results, with the model weights, in order of their impact: - // Expected console output: - // Feature Model Weight Change in R - Squared - // RoomsPerDwelling 50.80 -0.3695 - // EmploymentDistance -17.79 -0.2238 - // TeacherRatio -19.83 -0.1228 - // TaxRate -8.60 -0.1042 - // NitricOxides -15.95 -0.1025 - // HighwayDistance 5.37 -0.09345 - // CrimesPerCapita -15.05 -0.05797 - // PercentPre40s -4.64 -0.0385 - // PercentResidental 3.98 -0.02184 - // CharlesRiver 3.38 -0.01487 - // PercentNonRetail -1.94 -0.007231 + // Expected console output for 100 permutations: + // Feature Model Weight Change in R-Squared 95% Confidence Interval of the Mean + // RoomsPerDwelling 53.35 -0.4298 0.005705 + // EmploymentDistance -19.21 -0.2609 0.004591 + // NitricOxides -19.32 -0.1569 0.003701 + // HighwayDistance 6.11 -0.1173 0.0025 + // TeacherRatio -21.92 -0.1106 0.002207 + // TaxRate -8.68 -0.1008 0.002083 + // CrimesPerCapita -16.37 -0.05988 0.00178 + // PercentPre40s -4.52 -0.03836 0.001432 + // PercentResidental 3.91 -0.02006 0.001079 + // CharlesRiver 3.49 -0.01839 0.000841 + // PercentNonRetail -1.17 -0.002111 0.0003176 // // Let's dig into these results a little bit. First, if you look at the weights of the model, they generally correlate - // with the results of PFI, but there are some significant misorderings. For example, "Tax Rate" is weighted lower than - // "Nitric Oxides" and "Crimes Per Capita", but the permutation analysis shows this feature to have a larger effect - // on the accuracy of the model even though it has a relatively small weight. To understand why the weights don't - // reflect the same feature importance as PFI, we need to go back to the basics of linear models: one of the - // assumptions of a linear model is that the features are uncorrelated. Now, the features in this dataset are clearly - // correlated: the tax rate for a house and the student-to-teacher ratio at the nearest school, for example, are often - // coupled through school levies. The tax rate, presence of pollution (e.g. nitric oxides), and the crime rate would also - // seem to be correlated with each other through social dynamics. We could draw out similar relationships for all the - // variables in this dataset. The reason why the linear model weights don't reflect the same feature importance as PFI - // is that the solution to the linear model redistributes weights between correlated variables in unpredictable ways, so - // that the weights themselves are no longer a good measure of feature importance. - Console.WriteLine("Feature\tModel Weight\tChange in R-Squared"); + // with the results of PFI, but there are some significant misorderings. For example, "Tax Rate" and "Highway Distance" + // have relatively small model weights, but the permutation analysis shows these feature to have a larger effect + // on the accuracy of the model than higher-weighted features. To understand why the weights don't reflect the same + // feature importance as PFI, we need to go back to the basics of linear models: one of the assumptions of a linear + // model is that the features are uncorrelated. Now, the features in this dataset are clearly correlated: the tax rate + // for a house and the student-to-teacher ratio at the nearest school, for example, are often coupled through school + // levies. The tax rate, distance to a highway, and the crime rate would also seem to be correlated through social + // dynamics. We could draw out similar relationships for all variables in this dataset. The reason why the linear + // model weights don't reflect the same feature importance as PFI is that the solution to the linear model redistributes + // weights between correlated variables in unpredictable ways, so that the weights themselves are no longer a good + // measure of feature importance. + Console.WriteLine("Feature\tModel Weight\tChange in R-Squared\t95% Confidence Interval of the Mean"); var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array foreach (int i in sortedIndices) { - Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i]:G4}"); + Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i].Mean:G4}\t{1.96 * rSquared[i].StandardError:G4}"); } } - private static float[] GetLinearModelWeights(LinearRegressionModelParameters linearModel) + private static float[] GetLinearModelWeights(OlsLinearRegressionModelParameters linearModel) { - var weights = new VBuffer(); - linearModel.GetFeatureWeights(ref weights); - return weights.GetValues().ToArray(); + return linearModel.Weights.ToArray(); } } } diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs index 17928c2c9d..2c24f343ed 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs @@ -15,21 +15,22 @@ namespace Microsoft.ML.Transforms { - internal static class PermutationFeatureImportance + internal static class PermutationFeatureImportance where TResult : MetricsStatisticsBase, new() { public static ImmutableArray GetImportanceMetricsMatrix( IHostEnvironment env, IPredictionTransformer model, IDataView data, - Func evaluationFunc, - Func deltaFunc, + Func evaluationFunc, + Func deltaFunc, string features, + int permutationCount, bool useFeatureWeightFilter = false, int? topExamples = null) { Contracts.CheckValue(env, nameof(env)); - var host = env.Register(nameof(PermutationFeatureImportance)); + var host = env.Register(nameof(PermutationFeatureImportance)); host.CheckValue(model, nameof(model)); host.CheckValue(data, nameof(data)); host.CheckNonEmpty(features, nameof(features)); @@ -168,7 +169,7 @@ public static ImmutableArray // Now iterate through all the working slots, do permutation and calc the delta of metrics. int processedCnt = 0; int nextFeatureIndex = 0; - int shuffleSeed = host.Rand.Next(); + var shuffleRand = RandomUtils.Create(host.Rand.Next()); using (var pch = host.StartProgressChannel("SDCA preprocessing with lookup")) { pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt)); @@ -178,27 +179,9 @@ public static ImmutableArray if (processedCnt < workingFeatureIndices.Count - 1) nextFeatureIndex = workingFeatureIndices[processedCnt + 1]; + // Used for pre-caching the next feature int nextValuesIndex = 0; - Utils.Shuffle(RandomUtils.Create(shuffleSeed), featureValuesBuffer); - - Action permuter = - (src, dst, state) => - { - src.Features.CopyTo(ref dst.Features); - VBufferUtils.ApplyAt(ref dst.Features, workingIndx, - (int ii, ref float d) => - d = featureValuesBuffer[state.SampleIndex++]); - - if (processedCnt < workingFeatureIndices.Count - 1) - { - // This is the reason I need PermuterState in LambdaTransform.CreateMap. - nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex); - if (nextValuesIndex < valuesRowCount - 1) - nextValuesIndex++; - } - }; - SchemaDefinition input = SchemaDefinition.Create(typeof(FeaturesBuffer)); Contracts.Assert(input.Count == 1); input[0].ColumnName = features; @@ -208,20 +191,53 @@ public static ImmutableArray output[0].ColumnName = features; output[0].ColumnType = featuresColumn.Type; - IDataView viewPermuted = LambdaTransform.CreateMap( - host, data, permuter, null, input, output); - if (valuesRowCount == topExamples) - viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeArguments() { Count = valuesRowCount }, viewPermuted); + // Perform multiple permutations for one feature to build a confidence interval + var metricsDeltaForFeature = new TResult(); + for (int permutationIteration = 0; permutationIteration < permutationCount; permutationIteration++) + { + Utils.Shuffle(shuffleRand, featureValuesBuffer); - var metrics = evaluationFunc(model.Transform(viewPermuted)); + Action permuter = + (src, dst, state) => + { + src.Features.CopyTo(ref dst.Features); + VBufferUtils.ApplyAt(ref dst.Features, workingIndx, + (int ii, ref float d) => + d = featureValuesBuffer[state.SampleIndex++]); + + // Is it time to pre-cache the next feature? + if (permutationIteration == permutationCount - 1 && + processedCnt < workingFeatureIndices.Count - 1) + { + // Fill out the featureValueBuffer for the next feature while updating the current feature + // This is the reason I need PermuterState in LambdaTransform.CreateMap. + nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex); + if (nextValuesIndex < valuesRowCount - 1) + nextValuesIndex++; + } + }; + + IDataView viewPermuted = LambdaTransform.CreateMap( + host, data, permuter, null, input, output); + if (valuesRowCount == topExamples) + viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeArguments() { Count = valuesRowCount }, viewPermuted); + + var metrics = evaluationFunc(model.Transform(viewPermuted)); + + var delta = deltaFunc(metrics, baselineMetrics); + metricsDeltaForFeature.Add(delta); + } - var delta = deltaFunc(metrics, baselineMetrics); - metricsDelta.Add(delta); + // Add the metrics delta to the list + metricsDelta.Add(metricsDeltaForFeature); // Swap values for next iteration of permutation. - Array.Clear(featureValuesBuffer, 0, featureValuesBuffer.Length); - nextValues.CopyTo(featureValuesBuffer, 0); - Array.Clear(nextValues, 0, nextValues.Length); + if (processedCnt < workingFeatureIndices.Count - 1) + { + Array.Clear(featureValuesBuffer, 0, featureValuesBuffer.Length); + nextValues.CopyTo(featureValuesBuffer, 0); + Array.Clear(nextValues, 0, nextValues.Length); + } processedCnt++; } pch.Checkpoint(processedCnt, processedCnt); diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs index 611379f320..f4e373650b 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Transforms; using System; using System.Collections.Immutable; @@ -53,8 +54,9 @@ public static class PermutationFeatureImportanceExtensions /// Feature column name. /// Use features weight to pre-filter features. /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. + /// The number of permutations to perform. /// Array of per-feature 'contributions' to the score. - public static ImmutableArray + public static ImmutableArray PermutationFeatureImportance( this RegressionContext ctx, IPredictionTransformer model, @@ -62,15 +64,17 @@ public static ImmutableArray string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, bool useFeatureWeightFilter = false, - int? topExamples = null) + int? topExamples = null, + int permutationCount = 1) { - return PermutationFeatureImportance.GetImportanceMetricsMatrix( + return PermutationFeatureImportance.GetImportanceMetricsMatrix( CatalogUtils.GetEnvironment(ctx), model, data, idv => ctx.Evaluate(idv, label), RegressionDelta, features, + permutationCount, useFeatureWeightFilter, topExamples); } @@ -127,8 +131,9 @@ private static RegressionMetrics RegressionDelta( /// Feature column name. /// Use features weight to pre-filter features. /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. + /// The number of permutations to perform. /// Array of per-feature 'contributions' to the score. - public static ImmutableArray + public static ImmutableArray PermutationFeatureImportance( this BinaryClassificationContext ctx, IPredictionTransformer model, @@ -136,15 +141,17 @@ public static ImmutableArray string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, bool useFeatureWeightFilter = false, - int? topExamples = null) + int? topExamples = null, + int permutationCount = 1) { - return PermutationFeatureImportance.GetImportanceMetricsMatrix( + return PermutationFeatureImportance.GetImportanceMetricsMatrix( CatalogUtils.GetEnvironment(ctx), model, data, idv => ctx.Evaluate(idv, label), BinaryClassifierDelta, features, + permutationCount, useFeatureWeightFilter, topExamples); } @@ -205,8 +212,9 @@ private static BinaryClassificationMetrics BinaryClassifierDelta( /// Feature column name. /// Use features weight to pre-filter features. /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. + /// The number of permutations to perform. /// Array of per-feature 'contributions' to the score. - public static ImmutableArray + public static ImmutableArray PermutationFeatureImportance( this MulticlassClassificationContext ctx, IPredictionTransformer model, @@ -214,15 +222,17 @@ public static ImmutableArray string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, bool useFeatureWeightFilter = false, - int? topExamples = null) + int? topExamples = null, + int permutationCount = 1) { - return PermutationFeatureImportance.GetImportanceMetricsMatrix( + return PermutationFeatureImportance.GetImportanceMetricsMatrix( CatalogUtils.GetEnvironment(ctx), model, data, idv => ctx.Evaluate(idv, label), MulticlassClassificationDelta, features, + permutationCount, useFeatureWeightFilter, topExamples); } @@ -231,7 +241,7 @@ private static MultiClassClassifierMetrics MulticlassClassificationDelta( MultiClassClassifierMetrics a, MultiClassClassifierMetrics b) { if (a.TopK != b.TopK) - Contracts.Assert(a.TopK== b.TopK, "TopK to compare must be the same length."); + Contracts.Assert(a.TopK == b.TopK, "TopK to compare must be the same length."); var perClassLogLoss = ComputeArrayDeltas(a.PerClassLogLoss, b.PerClassLogLoss); @@ -289,8 +299,9 @@ private static MultiClassClassifierMetrics MulticlassClassificationDelta( /// Feature column name. /// Use features weight to pre-filter features. /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. + /// The number of permutations to perform. /// Array of per-feature 'contributions' to the score. - public static ImmutableArray + public static ImmutableArray PermutationFeatureImportance( this RankingContext ctx, IPredictionTransformer model, @@ -299,15 +310,17 @@ public static ImmutableArray string groupId = DefaultColumnNames.GroupId, string features = DefaultColumnNames.Features, bool useFeatureWeightFilter = false, - int? topExamples = null) + int? topExamples = null, + int permutationCount = 1) { - return PermutationFeatureImportance.GetImportanceMetricsMatrix( + return PermutationFeatureImportance.GetImportanceMetricsMatrix( CatalogUtils.GetEnvironment(ctx), model, data, idv => ctx.Evaluate(idv, label, groupId), RankingDelta, features, + permutationCount, useFeatureWeightFilter, topExamples); } @@ -323,91 +336,320 @@ private static RankerMetrics RankingDelta( #endregion - #region Clustering + #region Helpers + + private static double[] ComputeArrayDeltas(double[] a, double[] b) + { + Contracts.Assert(a.Length == b.Length, "Arrays to compare must be of the same length."); + + var delta = new double[a.Length]; + for (int i = 0; i < a.Length; i++) + delta[i] = a[i] - b[i]; + return delta; + } + + #endregion + } + + #region MetricsStatistics + + /// + /// The MetricsStatistics class computes summary statistics over multiple observations of a metric. + /// + public sealed class MetricStatistics + { + private readonly SummaryStatistics _statistic; + /// - /// Permutation Feature Importance (PFI) for Clustering + /// Get the mean value for the metric /// - /// - /// - /// Permutation feature importance (PFI) is a technique to determine the global importance of features in a trained - /// machine learning model. PFI is a simple yet powerful technique motivated by Breiman in his Random Forest paper, section 10 - /// (Breiman. "Random Forests." Machine Learning, 2001.) - /// The advantage of the PFI method is that it is model agnostic -- it works with any model that can be - /// evaluated -- and it can use any dataset, not just the training set, to compute feature importance metrics. - /// - /// - /// PFI works by taking a labeled dataset, choosing a feature, and permuting the values - /// for that feature across all the examples, so that each example now has a random value for the feature and - /// the original values for all other features. The evalution metric (e.g. normalized mutual information) is then calculated - /// for this modified dataset, and the change in the evaluation metric from the original dataset is computed. - /// The larger the change in the evaluation metric, the more important the feature is to the model. - /// PFI works by performing this permutation analysis across all the features of a model, one after another. - /// - /// - /// In this implementation, PFI computes the change in all possible clustering evaluation metrics for each feature, and an - /// ImmutableArray of ClusteringMetrics objects is returned. See the sample below for an - /// example of working with these results to analyze the feature importance of a model. - /// - /// - /// - /// - /// - /// - /// - /// The clustering context. - /// The model to evaluate. - /// The evaluation data set. - /// Label column name. - /// Feature column name. - /// Use features weight to pre-filter features. - /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. - /// Array of per-feature 'contributions' to the score. - public static ImmutableArray - PermutationFeatureImportance( - this ClusteringContext ctx, - IPredictionTransformer model, - IDataView data, - string label = DefaultColumnNames.Label, - string features = DefaultColumnNames.Features, - bool useFeatureWeightFilter = false, - int? topExamples = null) + public double Mean => _statistic.Mean; + + /// + /// Get the standard deviation for the metric + /// + public double StandardDeviation => (_statistic.RawCount <= 1) ? 0 : _statistic.SampleStdDev; + + /// + /// Get the standard error of the mean for the metric + /// + public double StandardError => (_statistic.RawCount <= 1) ? 0 : _statistic.StandardErrorMean; + + /// + /// Get the count for the number of samples used. Useful for interpreting + /// the standard deviation and the stardard error and building confidence intervals. + /// + public int Count => (int) _statistic.RawCount; + + internal MetricStatistics() { - return PermutationFeatureImportance.GetImportanceMetricsMatrix( - CatalogUtils.GetEnvironment(ctx), - model, - data, - idv => ctx.Evaluate(idv, label), - ClusteringDelta, - features, - useFeatureWeightFilter, - topExamples); + _statistic = new SummaryStatistics(); } - private static ClusteringMetrics ClusteringDelta( - ClusteringMetrics a, ClusteringMetrics b) + /// + /// Add another metric to the set of observations + /// + /// The metric being accumulated + internal void Add(double metric) { - return new ClusteringMetrics( - nmi: a.Nmi - b.Nmi, - avgMinScore: a.AvgMinScore- b.AvgMinScore, - dbi: a.Dbi - b.Dbi); + _statistic.Add(metric); } + } - #endregion Clustering + /// + /// The MetricsStatisticsBase class is the base class for computing summary + /// statistics over multiple observations of model evaluation metrics. + /// + /// The EvaluationMetric type, such as RegressionMetrics + public abstract class MetricsStatisticsBase{ + internal MetricsStatisticsBase() + { + } - #region Helpers + public abstract void Add(T metrics); - private static double[] ComputeArrayDeltas(double[] a, double[] b) + protected static void AddArray(double[] src, MetricStatistics[] dest) { - Contracts.Assert(a.Length == b.Length, "Arrays to compare must be of the same length."); + Contracts.Assert(src.Length == dest.Length, "Array sizes do not match."); - var delta = new double[a.Length]; - for (int i = 0; i < a.Length; i++) - delta[i] = a[i] - b[i]; - return delta; + for (int i = 0; i < dest.Length; i++) + dest[i].Add(src[i]); } - #endregion + protected MetricStatistics[] InitializeArray(int length) + { + var array = new MetricStatistics[length]; + for (int i = 0; i < array.Length; i++) + array[i] = new MetricStatistics(); + + return array; + } } + + /// + /// The RegressionMetricsStatistics class is computes summary + /// statistics over multiple observations of regression evaluation metrics. + /// + public sealed class RegressionMetricsStatistics : MetricsStatisticsBase + { + /// + /// Summary Statistics for L1 + /// + public MetricStatistics L1 { get; } + + /// + /// Summary Statistics for L2 + /// + public MetricStatistics L2 { get; } + + /// + /// Summary statistics for the root mean square loss (or RMS). + /// + public MetricStatistics Rms { get; } + + /// + /// Summary statistics for the user-supplied loss function. + /// + public MetricStatistics LossFn { get; } + + /// + /// Summary statistics for the R squared value. + /// + public MetricStatistics RSquared { get; } + + public RegressionMetricsStatistics() + { + L1 = new MetricStatistics(); + L2 = new MetricStatistics(); + Rms = new MetricStatistics(); + LossFn = new MetricStatistics(); + RSquared = new MetricStatistics(); + } + + /// + /// Add a set of evaluation metrics to the set of observations. + /// + /// The observed regression evaluation metric + public override void Add(RegressionMetrics metrics) + { + L1.Add(metrics.L1); + L2.Add(metrics.L2); + Rms.Add(metrics.Rms); + LossFn.Add(metrics.LossFn); + RSquared.Add(metrics.RSquared); + } + } + + /// + /// The BinaryClassificationMetricsStatistics class is computes summary + /// statistics over multiple observations of binary classification evaluation metrics. + /// + public sealed class BinaryClassificationMetricsStatistics : MetricsStatisticsBase + { + /// + /// Summary Statistics for AUC + /// + public MetricStatistics Auc { get; } + + /// + /// Summary Statistics for Accuracy + /// + public MetricStatistics Accuracy { get; } + + /// + /// Summary statistics for Positive Precision + /// + public MetricStatistics PositivePrecision { get; } + + /// + /// Summary statistics for Positive Recall + /// + public MetricStatistics PositiveRecall { get; } + + /// + /// Summary statistics for Negative Precision. + /// + public MetricStatistics NegativePrecision { get; } + + /// + /// Summary statistics for Negative Recall. + /// + public MetricStatistics NegativeRecall { get; } + + /// + /// Summary statistics for F1Score. + /// + public MetricStatistics F1Score { get; } + + /// + /// Summary statistics for AUPRC. + /// + public MetricStatistics Auprc { get; } + + public BinaryClassificationMetricsStatistics() + { + Auc = new MetricStatistics(); + Accuracy = new MetricStatistics(); + PositivePrecision = new MetricStatistics(); + PositiveRecall = new MetricStatistics(); + NegativePrecision = new MetricStatistics(); + NegativeRecall = new MetricStatistics(); + F1Score = new MetricStatistics(); + Auprc = new MetricStatistics(); + } + + /// + /// Add a set of evaluation metrics to the set of observations. + /// + /// The observed binary classification evaluation metric + public override void Add(BinaryClassificationMetrics metrics) + { + Auc.Add(metrics.Auc); + Accuracy.Add(metrics.Accuracy); + PositivePrecision.Add(metrics.PositivePrecision); + PositiveRecall.Add(metrics.PositiveRecall); + NegativePrecision.Add(metrics.NegativePrecision); + NegativeRecall.Add(metrics.NegativeRecall); + F1Score.Add(metrics.F1Score); + Auprc.Add(metrics.Auprc); + } + } + + /// + /// The MultiClassClassifierMetricsStatistics class is computes summary + /// statistics over multiple observations of binary classification evaluation metrics. + /// + public sealed class MultiClassClassifierMetricsStatistics : MetricsStatisticsBase + { + /// + /// Summary Statistics for Micro-Accuracy + /// + public MetricStatistics AccuracyMacro { get; } + + /// + /// Summary Statistics for Micro-Accuracy + /// + public MetricStatistics AccuracyMicro { get; } + + /// + /// Summary statistics for Log Loss + /// + public MetricStatistics LogLoss { get; } + + /// + /// Summary statistics for Log Loss Reduction + /// + public MetricStatistics LogLossReduction { get; } + + /// + /// Summary statistics for Top K Accuracy + /// + public MetricStatistics TopKAccuracy { get; } + + /// + /// Summary statistics for Per Class Log Loss + /// + public MetricStatistics[] PerClassLogLoss { get; private set; } + + public MultiClassClassifierMetricsStatistics() + { + AccuracyMacro = new MetricStatistics(); + AccuracyMicro = new MetricStatistics(); + LogLoss = new MetricStatistics(); + LogLossReduction = new MetricStatistics(); + TopKAccuracy = new MetricStatistics(); + } + + /// + /// Add a set of evaluation metrics to the set of observations. + /// + /// The observed binary classification evaluation metric + public override void Add(MultiClassClassifierMetrics metrics) + { + AccuracyMacro.Add(metrics.AccuracyMacro); + AccuracyMicro.Add(metrics.AccuracyMicro); + LogLoss.Add(metrics.LogLoss); + LogLossReduction.Add(metrics.LogLossReduction); + TopKAccuracy.Add(metrics.TopKAccuracy); + + if (PerClassLogLoss == null) + PerClassLogLoss = InitializeArray(metrics.PerClassLogLoss.Length); + AddArray(metrics.PerClassLogLoss, PerClassLogLoss); + } + } + + /// + /// The RankerMetricsStatistics class is computes summary + /// statistics over multiple observations of regression evaluation metrics. + /// + public sealed class RankerMetricsStatistics : MetricsStatisticsBase + { + /// + /// Summary Statistics for DCG + /// + public MetricStatistics[] Dcg { get; private set; } + + /// + /// Summary Statistics for L2 + /// + public MetricStatistics[] Ndcg { get; private set; } + + /// + /// Add a set of evaluation metrics to the set of observations. + /// + /// The observed regression evaluation metric + public override void Add(RankerMetrics metrics) + { + if (Dcg == null) + Dcg = InitializeArray(metrics.Dcg.Length); + + if (Ndcg == null) + Ndcg = InitializeArray(metrics.Ndcg.Length); + + AddArray(metrics.Dcg, Dcg); + AddArray(metrics.Ndcg, Ndcg); + } + } + + #endregion } diff --git a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs index 8815842c6e..85428e0888 100644 --- a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs +++ b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs @@ -40,18 +40,70 @@ public void TestPfiRegressionOnDenseFeatures() // X4Rand: 3 // For the following metrics lower is better, so maximum delta means more important feature, and vice versa - Assert.Equal(3, MinDeltaIndex(pfi, m => m.L1)); - Assert.Equal(1, MaxDeltaIndex(pfi, m => m.L1)); + Assert.Equal(3, MinDeltaIndex(pfi, m => m.L1.Mean)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.L1.Mean)); - Assert.Equal(3, MinDeltaIndex(pfi, m => m.L2)); - Assert.Equal(1, MaxDeltaIndex(pfi, m => m.L2)); + Assert.Equal(3, MinDeltaIndex(pfi, m => m.L2.Mean)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.L2.Mean)); - Assert.Equal(3, MinDeltaIndex(pfi, m => m.Rms)); - Assert.Equal(1, MaxDeltaIndex(pfi, m => m.Rms)); + Assert.Equal(3, MinDeltaIndex(pfi, m => m.Rms.Mean)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.Rms.Mean)); // For the following metrics higher is better, so minimum delta means more important feature, and vice versa - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.RSquared)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.RSquared)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.RSquared.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.RSquared.Mean)); + + Done(); + } + + /// + /// Test PFI Regression Standard Deviation and Standard Error for Dense Features + /// + [Fact] + public void TestPfiRegressionStandardDeviationAndErrorOnDenseFeatures() + { + var data = GetDenseDataset(); + var model = ML.Regression.Trainers.OnlineGradientDescent().Fit(data); + var pfi = ML.Regression.PermutationFeatureImportance(model, data, permutationCount: 20); + // Keep the permutation count high so fluctuations are kept to a minimum + // but not high enough to slow down the tests + // (fluctuations lead to random test failures) + + // Pfi Indices: + // X1: 0 + // X2Important: 1 + // X3: 2 + // X4Rand: 3 + + // For these metrics, the magnitude of the difference will be greatest for 1, least for 3 + // Stardard Deviation will scale with the magnitude of the measure + Assert.Equal(3, MinDeltaIndex(pfi, m => m.L1.StandardDeviation)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.L1.StandardDeviation)); + + Assert.Equal(3, MinDeltaIndex(pfi, m => m.L2.StandardDeviation)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.L2.StandardDeviation)); + + Assert.Equal(3, MinDeltaIndex(pfi, m => m.Rms.StandardDeviation)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.Rms.StandardDeviation)); + + Assert.Equal(3, MinDeltaIndex(pfi, m => m.RSquared.StandardDeviation)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.RSquared.StandardDeviation)); + + // Stardard Error will scale with the magnitude of the measure (as it's SD/sqrt(N)) + Assert.Equal(3, MinDeltaIndex(pfi, m => m.L1.StandardError)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.L1.StandardError)); + + Assert.Equal(3, MinDeltaIndex(pfi, m => m.L2.StandardError)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.L2.StandardError)); + + Assert.Equal(3, MinDeltaIndex(pfi, m => m.Rms.StandardError)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.Rms.StandardError)); + + Assert.Equal(3, MinDeltaIndex(pfi, m => m.RSquared.StandardError)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.RSquared.StandardError)); + + // And test that the Standard Deviation and Standard Error are related as we expect + Assert.Equal(pfi[0].Rms.StandardError, pfi[0].Rms.StandardDeviation / Math.Sqrt(pfi[0].Rms.Count)); Done(); } @@ -76,18 +128,18 @@ public void TestPfiRegressionOnSparseFeatures() // Permuted X2VBuffer-Slot-1 lot (f2) should have min impact on SGD metrics, X3Important -- max impact. // For the following metrics lower is better, so maximum delta means more important feature, and vice versa - Assert.Equal(2, MinDeltaIndex(results, m => m.L1)); - Assert.Equal(5, MaxDeltaIndex(results, m => m.L1)); + Assert.Equal(2, MinDeltaIndex(results, m => m.L1.Mean)); + Assert.Equal(5, MaxDeltaIndex(results, m => m.L1.Mean)); - Assert.Equal(2, MinDeltaIndex(results, m => m.L2)); - Assert.Equal(5, MaxDeltaIndex(results, m => m.L2)); + Assert.Equal(2, MinDeltaIndex(results, m => m.L2.Mean)); + Assert.Equal(5, MaxDeltaIndex(results, m => m.L2.Mean)); - Assert.Equal(2, MinDeltaIndex(results, m => m.Rms)); - Assert.Equal(5, MaxDeltaIndex(results, m => m.Rms)); + Assert.Equal(2, MinDeltaIndex(results, m => m.Rms.Mean)); + Assert.Equal(5, MaxDeltaIndex(results, m => m.Rms.Mean)); // For the following metrics higher is better, so minimum delta means more important feature, and vice versa - Assert.Equal(2, MaxDeltaIndex(results, m => m.RSquared)); - Assert.Equal(5, MinDeltaIndex(results, m => m.RSquared)); + Assert.Equal(2, MaxDeltaIndex(results, m => m.RSquared.Mean)); + Assert.Equal(5, MinDeltaIndex(results, m => m.RSquared.Mean)); } #endregion @@ -110,22 +162,22 @@ public void TestPfiBinaryClassificationOnDenseFeatures() // X4Rand: 3 // For the following metrics higher is better, so minimum delta means more important feature, and vice versa - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Auc)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.Auc)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Accuracy)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.Accuracy)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositivePrecision)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositivePrecision)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositiveRecall)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositiveRecall)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativePrecision)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativePrecision)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativeRecall)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativeRecall)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.F1Score)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.F1Score)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Auprc)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.Auprc)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Auc.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.Auc.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Accuracy.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.Accuracy.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositivePrecision.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositivePrecision.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositiveRecall.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositiveRecall.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativePrecision.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativePrecision.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativeRecall.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativeRecall.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.F1Score.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.F1Score.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Auprc.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.Auprc.Mean)); Done(); } @@ -149,22 +201,22 @@ public void TestPfiBinaryClassificationOnSparseFeatures() // X3Important: 5 // For the following metrics higher is better, so minimum delta means more important feature, and vice versa - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Auc)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.Auc)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Accuracy)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.Accuracy)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.PositivePrecision)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.PositivePrecision)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.PositiveRecall)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.PositiveRecall)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.NegativePrecision)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.NegativePrecision)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.NegativeRecall)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.NegativeRecall)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.F1Score)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.F1Score)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Auprc)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.Auprc)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Auc.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.Auc.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Accuracy.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.Accuracy.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.PositivePrecision.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.PositivePrecision.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.PositiveRecall.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.PositiveRecall.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.NegativePrecision.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.NegativePrecision.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.NegativeRecall.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.NegativeRecall.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.F1Score.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.F1Score.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Auprc.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.Auprc.Mean)); Done(); } @@ -188,21 +240,21 @@ public void TestPfiMulticlassClassificationOnDenseFeatures() // X4Rand: 3 // For the following metrics higher is better, so minimum delta means more important feature, and vice versa - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.AccuracyMicro)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.AccuracyMicro)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.AccuracyMacro)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.AccuracyMacro)); - Assert.Equal(3, MaxDeltaIndex(pfi, m => m.LogLossReduction)); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.LogLossReduction)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.AccuracyMicro.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.AccuracyMicro.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.AccuracyMacro.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.AccuracyMacro.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.LogLossReduction.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.LogLossReduction.Mean)); // For the following metrics-delta lower is better, so maximum delta means more important feature, and vice versa // Because they are _negative_, the difference will be positive for worse classifiers. - Assert.Equal(1, MaxDeltaIndex(pfi, m => m.LogLoss)); - Assert.Equal(3, MinDeltaIndex(pfi, m => m.LogLoss)); + Assert.Equal(1, MaxDeltaIndex(pfi, m => m.LogLoss.Mean)); + Assert.Equal(3, MinDeltaIndex(pfi, m => m.LogLoss.Mean)); for (int i = 0; i < pfi[0].PerClassLogLoss.Length; i++) { - Assert.Equal(1, MaxDeltaIndex(pfi, m => m.PerClassLogLoss[i])); - Assert.Equal(3, MinDeltaIndex(pfi, m => m.PerClassLogLoss[i])); + Assert.True(MaxDeltaIndex(pfi, m => m.PerClassLogLoss[i].Mean) == 1); + Assert.True(MinDeltaIndex(pfi, m => m.PerClassLogLoss[i].Mean) == 3); } Done(); @@ -227,21 +279,21 @@ public void TestPfiMulticlassClassificationOnSparseFeatures() // X3Important: 5 // Most important // For the following metrics higher is better, so minimum delta means more important feature, and vice versa - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.AccuracyMicro)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.AccuracyMicro)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.AccuracyMacro)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.AccuracyMacro)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.LogLossReduction)); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.LogLossReduction)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.AccuracyMicro.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.AccuracyMicro.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.AccuracyMacro.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.AccuracyMacro.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.LogLossReduction.Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.LogLossReduction.Mean)); // For the following metrics-delta lower is better, so maximum delta means more important feature, and vice versa // Because they are negative metrics, the _difference_ will be positive for worse classifiers. - Assert.Equal(5, MaxDeltaIndex(pfi, m => m.LogLoss)); - Assert.Equal(2, MinDeltaIndex(pfi, m => m.LogLoss)); + Assert.Equal(5, MaxDeltaIndex(pfi, m => m.LogLoss.Mean)); + Assert.Equal(2, MinDeltaIndex(pfi, m => m.LogLoss.Mean)); for (int i = 0; i < pfi[0].PerClassLogLoss.Length; i++) { - Assert.Equal(5, MaxDeltaIndex(pfi, m => m.PerClassLogLoss[i])); - Assert.Equal(2, MinDeltaIndex(pfi, m => m.PerClassLogLoss[i])); + Assert.Equal(5, MaxDeltaIndex(pfi, m => m.PerClassLogLoss[i].Mean)); + Assert.Equal(2, MinDeltaIndex(pfi, m => m.PerClassLogLoss[i].Mean)); } Done(); @@ -268,13 +320,13 @@ public void TestPfiRankingOnDenseFeatures() // For the following metrics higher is better, so minimum delta means more important feature, and vice versa for (int i = 0; i < pfi[0].Dcg.Length; i++) { - Assert.Equal(0, MaxDeltaIndex(pfi, m => m.Dcg[i])); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.Dcg[i])); + Assert.Equal(0, MaxDeltaIndex(pfi, m => m.Dcg[i].Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.Dcg[i].Mean)); } for (int i = 0; i < pfi[0].Ndcg.Length; i++) { - Assert.Equal(0, MaxDeltaIndex(pfi, m => m.Ndcg[i])); - Assert.Equal(1, MinDeltaIndex(pfi, m => m.Ndcg[i])); + Assert.Equal(0, MaxDeltaIndex(pfi, m => m.Ndcg[i].Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.Ndcg[i].Mean)); } Done(); @@ -301,52 +353,19 @@ public void TestPfiRankingOnSparseFeatures() // For the following metrics higher is better, so minimum delta means more important feature, and vice versa for (int i = 0; i < pfi[0].Dcg.Length; i++) { - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Dcg[i])); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.Dcg[i])); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Dcg[i].Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.Dcg[i].Mean)); } for (int i = 0; i < pfi[0].Ndcg.Length; i++) { - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Ndcg[i])); - Assert.Equal(5, MinDeltaIndex(pfi, m => m.Ndcg[i])); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.Ndcg[i].Mean)); + Assert.Equal(5, MinDeltaIndex(pfi, m => m.Ndcg[i].Mean)); } Done(); } #endregion - #region Clustering Tests - /// - /// Test PFI Clustering for Dense Features - /// - [Fact] - public void TestPfiClusteringOnDenseFeatures() - { - var data = GetDenseClusteringDataset(); - - var preview = data.Preview(); - - var model = ML.Clustering.Trainers.KMeans("Features", clustersCount: 5, - advancedSettings: args =>{ args.NormalizeFeatures = NormalizeOption.No; args.NumThreads = 1; }) - .Fit(data); - var pfi = ML.Clustering.PermutationFeatureImportance(model, data); - - // Pfi Indices: - // X1: 0 -- middling importance for clustering (middling range) - // X2Important: 1 -- most important for clustering (largest range) - // X3: 2 -- Least important for clustering (smallest range) - - // For the following metrics lower is better, so maximum delta means more important feature, and vice versa - Assert.Equal(0, MinDeltaIndex(pfi, m => m.AvgMinScore)); - Assert.Equal(2, MaxDeltaIndex(pfi, m => m.AvgMinScore)); - - // For the following metrics higher is better, so minimum delta means more important feature, and vice versa - Assert.Equal(2, MinDeltaIndex(pfi, m => m.Nmi)); - Assert.Equal(0, MaxDeltaIndex(pfi, m => m.Nmi)); - - Done(); - } - #endregion - #region Helpers /// /// Features: x1, x2, x3, xRand; y = 10*x1 + 20x2 + 5.5x3 + e, xRand- random and Label y is to dependant on xRand. @@ -355,8 +374,6 @@ public void TestPfiClusteringOnDenseFeatures() /// private IDataView GetDenseDataset(TaskType task = TaskType.Regression) { - Contracts.Assert(task != TaskType.Clustering, $"TaskType {nameof(TaskType.Clustering)} not supported."); - // Setup synthetic dataset. const int numberOfInstances = 1000; var rand = new Random(10); @@ -487,59 +504,6 @@ private IDataView GetSparseDataset(TaskType task = TaskType.Regression) return pipeline.Fit(srcDV).Transform(srcDV); } - /// - /// Features: x1, x2, x3, xRand; y = 10*x1 + 20x2 + 5.5x3 + e, xRand- random and Label y is to dependant on xRand. - /// xRand has the least importance: Evaluation metrics do not change a lot when xRand is permuted. - /// x2 has the biggest importance. - /// - private IDataView GetDenseClusteringDataset() - { - // Define the cluster centers - const int clusterCount = 5; - float[][] clusterCenters = new float[clusterCount][]; - for (int i = 0; i < clusterCount; i++) - { - clusterCenters[i] = new float[3] { i, i, i }; - } - - // Create rows of data sampled from these clusters - const int numberOfInstances = 1000; - var rand = new Random(10); - - // The cluster spacing is 1 - float x1Scale = 0.01f; - float x2Scale = 0.1f; - float x3Scale = 1f; - - float[] yArray = new float[numberOfInstances]; - float[] x1Array = new float[numberOfInstances]; - float[] x2Array = new float[numberOfInstances]; - float[] x3Array = new float[numberOfInstances]; - - for (var i = 0; i < numberOfInstances; i++) - { - // Assign a cluster - var clusterLabel = rand.Next(clusterCount); - yArray[i] = clusterLabel; - - x1Array[i] = clusterCenters[clusterLabel][0] + x1Scale * (float)(rand.NextDouble() - 0.5); - x2Array[i] = clusterCenters[clusterLabel][1] + x2Scale * (float)(rand.NextDouble() - 0.5); - x3Array[i] = clusterCenters[clusterLabel][2] + x3Scale * (float)(rand.NextDouble() - 0.5); - } - - // Create data view. - var bldr = new ArrayDataViewBuilder(Env); - bldr.AddColumn("Label", NumberType.Float, yArray); - bldr.AddColumn("X1", NumberType.Float, x1Array); - bldr.AddColumn("X2", NumberType.Float, x2Array); - bldr.AddColumn("X3", NumberType.Float, x3Array); - var srcDV = bldr.GetDataView(); - - var pipeline = ML.Transforms.Concatenate("Features", "X1", "X2", "X3"); - - return pipeline.Fit(srcDV).Transform(srcDV); - } - private int MinDeltaIndex( ImmutableArray metricsDelta, Func metricSelector) @@ -610,9 +574,8 @@ private enum TaskType Regression, BinaryClassification, MulticlassClassification, - Ranking, - Clustering + Ranking } #endregion } -} +} \ No newline at end of file