|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license.
|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
| 5 | +using System; |
5 | 6 | using System.Collections.Generic;
|
6 | 7 | using System.Collections.Immutable;
|
| 8 | +using System.Linq; |
7 | 9 | using Microsoft.ML.Runtime;
|
8 | 10 |
|
9 | 11 | namespace Microsoft.ML.Data
|
@@ -71,16 +73,22 @@ public sealed class MulticlassClassificationMetrics
|
71 | 73 | public double MicroAccuracy { get; }
|
72 | 74 |
|
73 | 75 | /// <summary>
|
74 |
| - /// If <see cref="TopKPredictionCount"/> is positive, this is the relative number of examples where |
75 |
| - /// the true label is one of the top-k predicted labels by the predictor. |
| 76 | + /// Convenience method for "TopKAccuracyForAllK[TopKPredictionCount - 1]". If <see cref="TopKPredictionCount"/> is positive, |
| 77 | + /// this is the relative number of examples where |
| 78 | + /// the true label is one of the top K predicted labels by the predictor. |
76 | 79 | /// </summary>
|
77 |
| - public double TopKAccuracy { get; } |
| 80 | + public double TopKAccuracy => TopKAccuracyForAllK?.LastOrDefault() ?? 0; |
78 | 81 |
|
79 | 82 | /// <summary>
|
80 |
| - /// If positive, this indicates the K in <see cref="TopKAccuracy"/>. |
| 83 | + /// If positive, this indicates the K in <see cref="TopKAccuracy"/> and <see cref="TopKAccuracyForAllK"/>. |
81 | 84 | /// </summary>
|
82 | 85 | public int TopKPredictionCount { get; }
|
83 | 86 |
|
| 87 | + /// <summary> |
| 88 | + /// Returns the top K accuracy for all K from 1 to the value of TopKPredictionCount. |
| 89 | + /// </summary> |
| 90 | + public IReadOnlyList<double> TopKAccuracyForAllK { get; } |
| 91 | + |
84 | 92 | /// <summary>
|
85 | 93 | /// Gets the log-loss of the classifier for each class. Log-loss measures the performance of a classifier
|
86 | 94 | /// with respect to how much the predicted probabilities diverge from the true class label. Lower
|
@@ -115,29 +123,30 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult,
|
115 | 123 | LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss);
|
116 | 124 | LogLossReduction = FetchDouble(MulticlassClassificationEvaluator.LogLossReduction);
|
117 | 125 | TopKPredictionCount = topKPredictionCount;
|
| 126 | + |
118 | 127 | if (topKPredictionCount > 0)
|
119 |
| - TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy); |
| 128 | + TopKAccuracyForAllK = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.AllTopKAccuracy).DenseValues().ToImmutableArray(); |
120 | 129 |
|
121 | 130 | var perClassLogLoss = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss);
|
122 | 131 | PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray();
|
123 | 132 | ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix, binary: false, perClassLogLoss.Length);
|
124 | 133 | }
|
125 | 134 |
|
126 | 135 | internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
|
127 |
| - int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss) |
| 136 | + int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss) |
128 | 137 | {
|
129 | 138 | MicroAccuracy = accuracyMicro;
|
130 | 139 | MacroAccuracy = accuracyMacro;
|
131 | 140 | LogLoss = logLoss;
|
132 | 141 | LogLossReduction = logLossReduction;
|
133 | 142 | TopKPredictionCount = topKPredictionCount;
|
134 |
| - TopKAccuracy = topKAccuracy; |
| 143 | + TopKAccuracyForAllK = topKAccuracies; |
135 | 144 | PerClassLogLoss = perClassLogLoss.ToImmutableArray();
|
136 | 145 | }
|
137 | 146 |
|
138 | 147 | internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
|
139 |
| - int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss, ConfusionMatrix confusionMatrix) |
140 |
| - : this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracy, perClassLogLoss) |
| 148 | + int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss, ConfusionMatrix confusionMatrix) |
| 149 | + : this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracies, perClassLogLoss) |
141 | 150 | {
|
142 | 151 | ConfusionMatrix = confusionMatrix;
|
143 | 152 | }
|
|
0 commit comments