Skip to content

Commit 627ad79

Browse files
authored
Static pipeline column indexers, binary/regression evaluators (#869)
Static pipeline column indexers, binary/regression evaluators (#869)

* Infrastructure for column indexers
* Binary classification and regression evaluators
1 parent ef5dbc5 commit 627ad79

File tree

8 files changed

+548
-27
lines changed

8 files changed

+548
-27
lines changed

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

Lines changed: 251 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8+
using Microsoft.ML.Data.StaticPipe;
9+
using Microsoft.ML.Data.StaticPipe.Runtime;
810
using Microsoft.ML.Runtime;
911
using Microsoft.ML.Runtime.CommandLine;
1012
using Microsoft.ML.Runtime.Data;
@@ -410,66 +412,58 @@ public sealed class Counters
410412

411413
/// <summary>
/// Gets the accuracy: the fraction of correctly classified examples,
/// (TN + TP) / (TP + TN + FN + FP).
/// </summary>
public Double Acc => (NumTrueNeg + NumTruePos) / (NumTruePos + NumTrueNeg + NumFalseNeg + NumFalsePos);
418419

419420
/// <summary>
/// Gets the positive recall: TP / (TP + FN), or 0 when no positive examples were counted.
/// </summary>
public Double RecallPos => (NumTruePos + NumFalseNeg > 0) ? NumTruePos / (NumTruePos + NumFalseNeg) : 0;
426426

427427
/// <summary>
/// Gets the positive precision: TP / (TP + FP), or 0 when nothing was predicted positive.
/// </summary>
public Double PrecisionPos => (NumTruePos + NumFalsePos > 0) ? NumTruePos / (NumTruePos + NumFalsePos) : 0;
434433

435434
/// <summary>
/// Gets the negative recall: TN / (TN + FP), or 0 when no negative examples were counted.
/// </summary>
public Double RecallNeg => (NumTrueNeg + NumFalsePos > 0) ? NumTrueNeg / (NumTrueNeg + NumFalsePos) : 0;
442440

443441
/// <summary>
/// Gets the negative precision: TN / (TN + FN), or 0 when nothing was predicted negative.
/// </summary>
public Double PrecisionNeg => (NumTrueNeg + NumFalseNeg > 0) ? NumTrueNeg / (NumTrueNeg + NumFalseNeg) : 0;
450447

451448
/// <summary>
/// Gets the test-set entropy, computed from the fraction of positive examples,
/// (TP + FN) / (TP + TN + FN + FP).
/// </summary>
public Double Entropy =>
    MathUtils.Entropy((NumTruePos + NumFalseNeg) /
        (NumTruePos + NumTrueNeg + NumFalseNeg + NumFalsePos));
459455

460456
/// <summary>
/// Gets the mean log-loss over the counted examples: the accumulated log-loss divided by the
/// number of examples it was accumulated over. NaN propagates from the accumulator; 0 is
/// returned when no examples contributed to the log-loss.
/// </summary>
public Double LogLoss =>
    Double.IsNaN(_logLoss) ? Double.NaN : (_numLogLossPositives + _numLogLossNegatives > 0)
        ? _logLoss / (_numLogLossPositives + _numLogLossNegatives) : 0;
468463

469464
public Double LogLossReduction
470465
{
471-
get
472-
{
466+
get {
473467
if (_numLogLossPositives + _numLogLossNegatives == 0)
474468
return 0;
475469
var logLoss = _logLoss / (_numLogLossPositives + _numLogLossNegatives);
@@ -787,6 +781,246 @@ private void ComputePrCurves()
787781
}
788782
}
789783
}
784+
785+
/// <summary>
/// Evaluation results for binary classifiers, excluding probabilistic metrics.
/// </summary>
public class Result
{
    /// <summary>
    /// Gets the area under the ROC curve.
    /// </summary>
    /// <remarks>
    /// The area under the ROC curve is equal to the probability that the classifier ranks
    /// a randomly chosen positive instance higher than a randomly chosen negative one
    /// (assuming 'positive' ranks higher than 'negative').
    /// </remarks>
    public double Auc { get; }

    /// <summary>
    /// Gets the accuracy of a classifier which is the proportion of correct predictions in the test set.
    /// </summary>
    public double Accuracy { get; }

    /// <summary>
    /// Gets the positive precision of a classifier which is the proportion of correctly predicted
    /// positive instances among all the positive predictions (i.e., the number of positive instances
    /// predicted as positive, divided by the total number of instances predicted as positive).
    /// </summary>
    public double PositivePrecision { get; }

    /// <summary>
    /// Gets the positive recall of a classifier which is the proportion of correctly predicted
    /// positive instances among all the positive instances (i.e., the number of positive instances
    /// predicted as positive, divided by the total number of positive instances).
    /// </summary>
    // Get-only (no private setter) for consistency with the sibling metric properties; the value
    // is assigned exactly once, in the constructor.
    public double PositiveRecall { get; }

    /// <summary>
    /// Gets the negative precision of a classifier which is the proportion of correctly predicted
    /// negative instances among all the negative predictions (i.e., the number of negative instances
    /// predicted as negative, divided by the total number of instances predicted as negative).
    /// </summary>
    public double NegativePrecision { get; }

    /// <summary>
    /// Gets the negative recall of a classifier which is the proportion of correctly predicted
    /// negative instances among all the negative instances (i.e., the number of negative instances
    /// predicted as negative, divided by the total number of negative instances).
    /// </summary>
    public double NegativeRecall { get; }

    /// <summary>
    /// Gets the F1 score of the classifier.
    /// </summary>
    /// <remarks>
    /// F1 score is the harmonic mean of precision and recall: 2 * precision * recall / (precision + recall).
    /// </remarks>
    public double F1Score { get; }

    /// <summary>
    /// Gets the area under the precision/recall curve of the classifier.
    /// </summary>
    /// <remarks>
    /// The area under the precision/recall curve is a single number summary of the information in the
    /// precision/recall curve. It is increasingly used in the machine learning community, particularly
    /// for imbalanced datasets where one class is observed more frequently than the other. On these
    /// datasets, AUPRC can highlight performance differences that are lost with AUC.
    /// </remarks>
    public double Auprc { get; }

    /// <summary>
    /// Reads a single value from the named column of a metrics row, throwing through
    /// <paramref name="ectx"/> when the column does not exist.
    /// </summary>
    /// <param name="ectx">The exception context used to surface errors.</param>
    /// <param name="row">The row holding the overall metrics.</param>
    /// <param name="name">The name of the metric column to read.</param>
    /// <returns>The value of the named column at the row's current position.</returns>
    // 'private protected' (C# 7.2) is the conventional modifier order for this accessibility.
    private protected static T Fetch<T>(IExceptionContext ectx, IRow row, string name)
    {
        if (!row.Schema.TryGetColumnIndex(name, out int col))
            throw ectx.Except($"Could not find column '{name}'");
        T val = default;
        row.GetGetter<T>(col)(ref val);
        return val;
    }

    internal Result(IExceptionContext ectx, IRow overallResult)
    {
        // Local shorthand: every metric in this result is a double column in the overall-metrics row.
        double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
        Auc = Fetch(BinaryClassifierEvaluator.Auc);
        Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy);
        PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName);
        PositiveRecall = Fetch(BinaryClassifierEvaluator.PosRecallName);
        NegativePrecision = Fetch(BinaryClassifierEvaluator.NegPrecName);
        NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName);
        F1Score = Fetch(BinaryClassifierEvaluator.F1);
        Auprc = Fetch(BinaryClassifierEvaluator.AuPrc);
    }
}
874+
875+
/// <summary>
/// Evaluation results for binary classifiers, including probabilistic metrics.
/// </summary>
public sealed class CalibratedResult : Result
{
    /// <summary>
    /// Gets the log-loss of the classifier.
    /// </summary>
    /// <remarks>
    /// The log-loss metric is computed as LL = -(1/m) * sum(log(p[i])), where m is the number of
    /// instances in the test set, and p[i] is the probability returned by the classifier if the
    /// instance belongs to class 1, and one minus that probability if the instance belongs to class 0.
    /// </remarks>
    public double LogLoss { get; }

    /// <summary>
    /// Gets the log-loss reduction (also known as relative log-loss, or reduction in information
    /// gain - RIG) of the classifier.
    /// </summary>
    /// <remarks>
    /// The log-loss reduction is scaled relative to a classifier that predicts the prior for every
    /// example: (LL(prior) - LL(classifier)) / LL(prior). It can be interpreted as the advantage of
    /// the classifier over a random prediction; e.g., a RIG of 20 means the probability of a correct
    /// prediction is 20% better than random guessing.
    /// </remarks>
    public double LogLossReduction { get; }

    /// <summary>
    /// Gets the test-set entropy (prior Log-Loss/instance) of the classifier.
    /// </summary>
    public double Entropy { get; }

    internal CalibratedResult(IExceptionContext ectx, IRow overallResult)
        : base(ectx, overallResult)
    {
        // The base constructor has already populated the non-probabilistic metrics;
        // read the three probabilistic ones via the shared base helper.
        LogLoss = Fetch<double>(ectx, overallResult, BinaryClassifierEvaluator.LogLoss);
        LogLossReduction = Fetch<double>(ectx, overallResult, BinaryClassifierEvaluator.LogLossReduction);
        Entropy = Fetch<double>(ectx, overallResult, BinaryClassifierEvaluator.Entropy);
    }
}
919+
920+
/// <summary>
/// Evaluates scored binary classification data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <returns>The evaluation results for these calibrated outputs.</returns>
public static CalibratedResult Evaluate<T>(
    DataView<T> data,
    Func<T, Scalar<bool>> label,
    Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred)
{
    Contracts.CheckValue(data, nameof(data));
    var env = StaticPipeUtils.GetEnvironment(data);
    Contracts.AssertValue(env);
    env.CheckValue(label, nameof(label));
    env.CheckValue(pred, nameof(pred));

    // Map the static-pipeline column handles back to their underlying column names.
    var indexer = StaticPipeUtils.GetIndexer(data);
    string labelName = indexer.Get(label(indexer.Indices));
    var (scoreCol, probCol, predCol) = pred(indexer.Indices);
    Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
    Contracts.CheckParam(probCol != null, nameof(pred), "Indexing delegate resulted in null probability column.");
    Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
    string scoreName = indexer.Get(scoreCol);
    string probName = indexer.Get(probCol);
    string predName = indexer.Get(predCol);

    var evaluator = new BinaryClassifierEvaluator(env, new Arguments() { });

    // Bind label/score/probability/predicted-label roles so the dynamic evaluator can find them.
    var roleMapped = new RoleMappedData(data.AsDynamic, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(labelName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predName));

    var metrics = evaluator.Evaluate(roleMapped);
    env.Assert(metrics.ContainsKey(MetricKinds.OverallMetrics));
    var overallMetrics = metrics[MetricKinds.OverallMetrics];

    // The overall-metrics view is expected to hold exactly one row; the MoveNext results are
    // stored in a local (not passed inline to Assert) so cursor advancement is unconditional.
    CalibratedResult result;
    using (var cursor = overallMetrics.GetRowCursor(i => true))
    {
        var moved = cursor.MoveNext();
        env.Assert(moved);
        result = new CalibratedResult(env, cursor);
        moved = cursor.MoveNext();
        env.Assert(!moved);
    }
    return result;
}
973+
974+
/// <summary>
/// Evaluates scored binary classification data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from the (uncalibrated) prediction of a binary
/// classifier, that is, a score and a predicted label with no probability column.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
public static Result Evaluate<T>(
    DataView<T> data,
    Func<T, Scalar<bool>> label,
    Func<T, (Scalar<float> score, Scalar<bool> predictedLabel)> pred)
{
    Contracts.CheckValue(data, nameof(data));
    var env = StaticPipeUtils.GetEnvironment(data);
    Contracts.AssertValue(env);
    env.CheckValue(label, nameof(label));
    env.CheckValue(pred, nameof(pred));

    // Map the static-pipeline column handles back to their underlying column names.
    var indexer = StaticPipeUtils.GetIndexer(data);
    string labelName = indexer.Get(label(indexer.Indices));
    (var scoreCol, var predCol) = pred(indexer.Indices);
    Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
    Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
    string scoreName = indexer.Get(scoreCol);
    string predName = indexer.Get(predCol);

    var eval = new BinaryClassifierEvaluator(env, new Arguments());

    // No probability role here: this overload evaluates uncalibrated outputs, so only the
    // non-probabilistic metrics (Result, not CalibratedResult) are produced.
    var roles = new RoleMappedData(data.AsDynamic, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(labelName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predName));

    var resultDict = eval.Evaluate(roles);
    env.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
    var overall = resultDict[MetricKinds.OverallMetrics];

    // The overall-metrics view is expected to hold exactly one row; MoveNext results are stored
    // in a local (not passed inline to Assert) so cursor advancement is unconditional.
    Result result;
    using (var cursor = overall.GetRowCursor(i => true))
    {
        var moved = cursor.MoveNext();
        env.Assert(moved);
        result = new Result(env, cursor);
        moved = cursor.MoveNext();
        env.Assert(!moved);
    }
    return result;
}
7901024
}
7911025

7921026
public sealed class BinaryPerInstanceEvaluator : PerInstanceEvaluatorBase
@@ -1526,4 +1760,4 @@ private static IDataView ExtractConfusionMatrix(IHost host, Dictionary<string, I
15261760
return confusionMatrix;
15271761
}
15281762
}
1529-
}
1763+
}

0 commit comments

Comments (0)