Skip to content

Commit 53b748d

Browse files
authored
Add Cluster evaluator (#316)
* Add Cluster evaluator * fix copypaste * address comments * formatting
1 parent 1bb1249 commit 53b748d

File tree

6 files changed

+197
-7
lines changed

6 files changed

+197
-7
lines changed

src/Microsoft.ML/Models/ClassificationEvaluator.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLo
5757
IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
5858
if (overallMetrics == null)
5959
{
60-
throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
60+
throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
6161
}
6262

6363
IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
6464
if (confusionMatrix == null)
6565
{
66-
throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
66+
throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
6767
}
6868

6969
var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Transforms;

namespace Microsoft.ML.Models
{
    public sealed partial class ClusterEvaluator
    {
        /// <summary>
        /// Computes the quality metrics for the PredictionModel using the specified data set.
        /// </summary>
        /// <param name="model">
        /// The trained PredictionModel to be evaluated.
        /// </param>
        /// <param name="testData">
        /// The test data that will be predicted and used to evaluate the model.
        /// </param>
        /// <returns>
        /// A ClusterMetrics instance that describes how well the model performed against the test data.
        /// </returns>
        public ClusterMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
        {
            using (var environment = new TlcEnvironment())
            {
                environment.CheckValue(model, nameof(model));
                environment.CheckValue(testData, nameof(testData));

                Experiment experiment = environment.CreateExperiment();

                // The loader must hand back a data step so its output can feed the scorer.
                ILearningPipelineStep loaderStep = testData.ApplyStep(previousStep: null, experiment);
                if (!(loaderStep is ILearningPipelineDataStep loaderOutput))
                {
                    throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
                }

                // Score the test data with the trained model, then evaluate the scored output.
                var scorer = new DatasetTransformScorer
                {
                    Data = loaderOutput.Data,
                };
                DatasetTransformScorer.Output scoringOutput = experiment.Add(scorer);

                Data = scoringOutput.ScoredData;
                Output evaluationOutput = experiment.Add(this);

                experiment.Compile();

                // Bind the trained model and the test data into the compiled graph, then execute it.
                experiment.SetInput(scorer.TransformModel, model.PredictorModel);
                testData.SetInput(environment, experiment);

                experiment.Run();

                IDataView overallMetrics = experiment.GetOutput(evaluationOutput.OverallMetrics);
                if (overallMetrics == null)
                {
                    throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClusterEvaluator)} Evaluate.");
                }

                // A single evaluation run is expected to produce exactly one metrics row.
                var metric = ClusterMetrics.FromOverallMetrics(environment, overallMetrics);
                Contracts.Assert(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics");

                return metric[0];
            }
        }
    }
}
+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.Models
{
    /// <summary>
    /// This class contains the overall metrics computed by cluster evaluators.
    /// </summary>
    public sealed class ClusterMetrics
    {
        private ClusterMetrics()
        {
        }

        /// <summary>
        /// Materializes one ClusterMetrics instance per row of the overall-metrics IDataView.
        /// </summary>
        /// <param name="env">The host environment used for row-to-object binding and error reporting.</param>
        /// <param name="overallMetrics">The overall metrics produced by a clustering evaluator.</param>
        /// <returns>A non-empty list with one entry per metrics row.</returns>
        internal static List<ClusterMetrics> FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);

            // Enumerate the view exactly once: probing for emptiness with a separate
            // GetEnumerator() call would leak the enumerator and re-execute the
            // (lazily evaluated) pipeline a second time for the foreach.
            var metrics = new List<ClusterMetrics>();
            foreach (var metric in overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true))
            {
                metrics.Add(new ClusterMetrics()
                {
                    AvgMinScore = metric.AvgMinScore,
                    Nmi = metric.Nmi,
                    Dbi = metric.Dbi,
                });
            }

            if (metrics.Count == 0)
            {
                throw env.Except("The overall ClusteringMetrics didn't have any rows.");
            }

            return metrics;
        }

        /// <summary>
        /// Davies-Bouldin Index.
        /// </summary>
        /// <remarks>
        /// DBI is a measure of how much scatter is in the cluster and the cluster separation.
        /// </remarks>
        public double Dbi { get; private set; }

        /// <summary>
        /// Normalized Mutual Information.
        /// </summary>
        /// <remarks>
        /// NMI is a measure of the mutual dependence between the true and predicted cluster labels for instances in the dataset.
        /// NMI ranges between 0 and 1 where "0" indicates clustering is random and "1" indicates clustering is perfect w.r.t. true labels.
        /// </remarks>
        public double Nmi { get; private set; }

        /// <summary>
        /// Average minimum score.
        /// </summary>
        /// <remarks>
        /// AvgMinScore is the average squared-distance of examples from the respective cluster centroids.
        /// It is defined as
        /// AvgMinScore = (1/m) * sum ((xi - c(xi))^2)
        /// where m is the number of instances in the dataset.
        /// xi is the i'th instance and c(xi) is the centroid of the predicted cluster for xi.
        /// </remarks>
        public double AvgMinScore { get; private set; }

        /// <summary>
        /// This class contains the public fields necessary to deserialize from IDataView.
        /// </summary>
        private sealed class SerializationClass
        {
#pragma warning disable 649 // never assigned
            [ColumnName(Runtime.Data.ClusteringEvaluator.Dbi)]
            public Double Dbi;

            [ColumnName(Runtime.Data.ClusteringEvaluator.Nmi)]
            public Double Nmi;

            [ColumnName(Runtime.Data.ClusteringEvaluator.AvgMinScore)]
            public Double AvgMinScore;
#pragma warning restore 649 // never assigned
        }
    }
}

src/Microsoft.ML/Models/CrossValidator.cs

+10-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public sealed partial class CrossValidator
1919
/// <typeparam name="TOutput">Class type that represents prediction schema.</typeparam>
2020
/// <param name="pipeline">Machine learning pipeline may contain loader, transforms and at least one trainer.</param>
2121
/// <returns>List containing metrics and predictor model for each fold</returns>
22-
public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(LearningPipeline pipeline)
22+
public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(LearningPipeline pipeline)
2323
where TInput : class
2424
where TOutput : class, new()
2525
{
@@ -76,7 +76,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
7676
{
7777
PredictorModel = predictorModel
7878
};
79-
79+
8080
var scorerOutput = subGraph.Add(scorer);
8181
lastTransformModel = scorerOutput.ScoringTransform;
8282
step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform);
@@ -129,7 +129,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
129129
experiment.GetOutput(crossValidateOutput.OverallMetrics),
130130
experiment.GetOutput(crossValidateOutput.ConfusionMatrix), 2);
131131
}
132-
else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
132+
else if (Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
133133
{
134134
cvOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics(
135135
environment,
@@ -142,6 +142,12 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
142142
environment,
143143
experiment.GetOutput(crossValidateOutput.OverallMetrics));
144144
}
145+
else if (Kind == MacroUtilsTrainerKinds.SignatureClusteringTrainer)
146+
{
147+
cvOutput.ClusterMetrics = ClusterMetrics.FromOverallMetrics(
148+
environment,
149+
experiment.GetOutput(crossValidateOutput.OverallMetrics));
150+
}
145151
else
146152
{
147153
//Implement metrics for ranking, clustering and anomaly detection.
@@ -174,6 +180,7 @@ public class CrossValidationOutput<TInput, TOutput>
174180
public List<BinaryClassificationMetrics> BinaryClassificationMetrics;
175181
public List<ClassificationMetrics> ClassificationMetrics;
176182
public List<RegressionMetrics> RegressionMetrics;
183+
public List<ClusterMetrics> ClusterMetrics;
177184
public PredictionModel<TInput, TOutput>[] PredictorModels;
178185

179186
//REVIEW: Add warnings and per instance results and implement

src/Microsoft.ML/Models/TrainTestEvaluator.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOutp
102102
}
103103

104104
var experiment = environment.CreateExperiment();
105-
105+
106106
TrainingData = (loaders[0].ApplyStep(null, experiment) as ILearningPipelineDataStep).Data;
107107
TestingData = (testData.ApplyStep(null, experiment) as ILearningPipelineDataStep).Data;
108108
Nodes = subGraph;
@@ -140,6 +140,12 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOutp
140140
environment,
141141
experiment.GetOutput(crossValidateOutput.OverallMetrics)).FirstOrDefault();
142142
}
143+
else if (Kind == MacroUtilsTrainerKinds.SignatureClusteringTrainer)
144+
{
145+
trainTestOutput.ClusterMetrics = ClusterMetrics.FromOverallMetrics(
146+
environment,
147+
experiment.GetOutput(crossValidateOutput.OverallMetrics)).FirstOrDefault();
148+
}
143149
else
144150
{
145151
//Implement metrics for ranking, clustering and anomaly detection.
@@ -158,7 +164,7 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOutp
158164

159165
trainTestOutput.PredictorModels = new PredictionModel<TInput, TOutput>(predictor, memoryStream);
160166
}
161-
167+
162168
return trainTestOutput;
163169
}
164170
}
@@ -171,6 +177,7 @@ public class TrainTestEvaluatorOutput<TInput, TOutput>
171177
public BinaryClassificationMetrics BinaryClassificationMetrics;
172178
public ClassificationMetrics ClassificationMetrics;
173179
public RegressionMetrics RegressionMetrics;
180+
public ClusterMetrics ClusterMetrics;
174181
public PredictionModel<TInput, TOutput> PredictorModels;
175182

176183
//REVIEW: Add warnings and per instance results and implement

test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs

+11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.ML.Data;
2+
using Microsoft.ML.Models;
23
using Microsoft.ML.Runtime;
34
using Microsoft.ML.Runtime.Api;
45
using Microsoft.ML.Trainers;
@@ -116,6 +117,16 @@ public void PredictClusters()
116117
Assert.True(!labels.Contains(scores.SelectedClusterId));
117118
labels.Add(scores.SelectedClusterId);
118119
}
120+
121+
var evaluator = new ClusterEvaluator();
122+
var testData = CollectionDataSource.Create(clusters);
123+
ClusterMetrics metrics = evaluator.Evaluate(model, testData);
124+
125+
//Label is not specified, so NMI would be equal to NaN
126+
Assert.Equal(metrics.Nmi, double.NaN);
127+
//Calculate dbi is false by default so Dbi would be 0
128+
Assert.Equal(metrics.Dbi, (double)0.0);
129+
Assert.Equal(metrics.AvgMinScore, (double)0.0, 5);
119130
}
120131
}
121132
}

0 commit comments

Comments
 (0)