Skip to content

Commit 6a413ed

Browse files
jasallen, antoniovs1029, and justinormont
authored
Perf improvement for TopK Accuracy and return all topK in Classification Evaluator (#5395)
* Fix for issue 744 * cleanup * fixing report output * fixedTestReferenceOutputs * Fixed test reference outputs for NetCore31 * change top k acc output string format * Ranking algorithm now uses first appearance in dataset rather than worstCase * fixed benchmark * various minor changes from code review * limit TopK to OutputTopKAcc parameter * top k output name changes * make old TopK readOnly * restored old baselineOutputs since respecting outputTopK param means no topK in most test output * fix test fails, re-add names parameter * Clean up commented code * that'll teach me to edit from the github webpage * use existing method, fix nits * Slight comment change * Comment change / Touch to kick off build pipeline * fix whitespace * Added new test * Code formatting nits * Code formatting nit * Fixed undefined rankofCorrectLabel and trailing whitespace warning * Removed _numUnknownClassInstances and added test for unknown labels * Add weight to seenRanks * Nits * Removed FastTree import Co-authored-by: Antonio Velazquez <[email protected]> Co-authored-by: Justin Ormont <[email protected]>
1 parent 0c6238e commit 6a413ed

File tree

9 files changed

+203
-63
lines changed

9 files changed

+203
-63
lines changed

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
123123
logLoss: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLoss)),
124124
logLossReduction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLossReduction)),
125125
topKPredictionCount: newMetrics.ElementAt(0).TopKPredictionCount,
126-
topKAccuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.TopKAccuracy)),
127-
// Return PerClassLogLoss and ConfusionMatrix from the fold closest to average score
126+
topKAccuracies: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.TopKAccuracyForAllK)),
128127
perClassLogLoss: (metricsClosestToAvg as MulticlassClassificationMetrics).PerClassLogLoss.ToArray(),
129128
confusionMatrix: (metricsClosestToAvg as MulticlassClassificationMetrics).ConfusionMatrix);
130129
return result as TMetrics;
@@ -163,7 +162,6 @@ private static double[] GetAverageOfNonNaNScoresInNestedEnumerable(IEnumerable<I
163162
double[] arr = new double[results.ElementAt(0).Count()];
164163
for (int i = 0; i < arr.Length; i++)
165164
{
166-
Contracts.Assert(arr.Length == results.ElementAt(i).Count());
167165
arr[i] = GetAverageOfNonNaNScores(results.Select(x => x.ElementAt(i)));
168166
}
169167
return arr;

src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs

+7-1
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,13 @@ private static List<string> GetMetricNames(IChannel ch, DataViewSchema schema, D
10351035
names = editor.Commit();
10361036
}
10371037
foreach (var name in names.Items(all: true))
1038-
metricNames.Add(string.Format("{0}{1}", metricName, name.Value));
1038+
{
1039+
var tryNaming = string.Format(metricName, name.Value);
1040+
if (tryNaming == metricName) // metricName wasn't a format string, so just append slotname
1041+
tryNaming = (string.Format("{0}{1}", metricName, name.Value));
1042+
1043+
metricNames.Add(tryNaming);
1044+
}
10391045
}
10401046
}
10411047
ch.Assert(metricNames.Count == metricCount);

src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs

+18-9
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using System.Collections.Immutable;
8+
using System.Linq;
79
using Microsoft.ML.Runtime;
810

911
namespace Microsoft.ML.Data
@@ -71,16 +73,22 @@ public sealed class MulticlassClassificationMetrics
7173
public double MicroAccuracy { get; }
7274

7375
/// <summary>
74-
/// If <see cref="TopKPredictionCount"/> is positive, this is the relative number of examples where
75-
/// the true label is one of the top-k predicted labels by the predictor.
76+
/// Convenience method for "TopKAccuracyForAllK[TopKPredictionCount - 1]". If <see cref="TopKPredictionCount"/> is positive,
77+
/// this is the relative number of examples where
78+
/// the true label is one of the top K predicted labels by the predictor.
7679
/// </summary>
77-
public double TopKAccuracy { get; }
80+
public double TopKAccuracy => TopKAccuracyForAllK?.LastOrDefault() ?? 0;
7881

7982
/// <summary>
80-
/// If positive, this indicates the K in <see cref="TopKAccuracy"/>.
83+
/// If positive, this indicates the K in <see cref="TopKAccuracy"/> and <see cref="TopKAccuracyForAllK"/>.
8184
/// </summary>
8285
public int TopKPredictionCount { get; }
8386

87+
/// <summary>
88+
/// Returns the top K accuracy for all K from 1 to the value of TopKPredictionCount.
89+
/// </summary>
90+
public IReadOnlyList<double> TopKAccuracyForAllK { get; }
91+
8492
/// <summary>
8593
/// Gets the log-loss of the classifier for each class. Log-loss measures the performance of a classifier
8694
/// with respect to how much the predicted probabilities diverge from the true class label. Lower
@@ -115,29 +123,30 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult,
115123
LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss);
116124
LogLossReduction = FetchDouble(MulticlassClassificationEvaluator.LogLossReduction);
117125
TopKPredictionCount = topKPredictionCount;
126+
118127
if (topKPredictionCount > 0)
119-
TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy);
128+
TopKAccuracyForAllK = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.AllTopKAccuracy).DenseValues().ToImmutableArray();
120129

121130
var perClassLogLoss = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss);
122131
PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray();
123132
ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix, binary: false, perClassLogLoss.Length);
124133
}
125134

126135
internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
127-
int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss)
136+
int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss)
128137
{
129138
MicroAccuracy = accuracyMicro;
130139
MacroAccuracy = accuracyMacro;
131140
LogLoss = logLoss;
132141
LogLossReduction = logLossReduction;
133142
TopKPredictionCount = topKPredictionCount;
134-
TopKAccuracy = topKAccuracy;
143+
TopKAccuracyForAllK = topKAccuracies;
135144
PerClassLogLoss = perClassLogLoss.ToImmutableArray();
136145
}
137146

138147
internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
139-
int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss, ConfusionMatrix confusionMatrix)
140-
: this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracy, perClassLogLoss)
148+
int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss, ConfusionMatrix confusionMatrix)
149+
: this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracies, perClassLogLoss)
141150
{
142151
ConfusionMatrix = confusionMatrix;
143152
}

0 commit comments

Comments (0)