Skip to content

Commit 610ffcb

Browse files
authored
Exposing the confusion matrix (#3250)
* Exposing the confusion matrix
1 parent d987294 commit 610ffcb

File tree

11 files changed

+323
-59
lines changed

11 files changed

+323
-59
lines changed

src/Microsoft.ML.Core/Data/AnnotationUtils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
441441
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
442442
{
443443
var cols = new List<SchemaShape.Column>();
444-
if (labelColumn.HasValue && labelColumn.Value.IsKey)
444+
if (labelColumn != null && labelColumn.Value.IsKey)
445445
{
446446
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
447447
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

+11-5
Original file line numberDiff line numberDiff line change
@@ -815,16 +815,18 @@ public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string lab
815815
var resultDict = ((IEvaluator)this).Evaluate(roles);
816816
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
817817
var overall = resultDict[MetricKinds.OverallMetrics];
818+
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];
818819

819820
CalibratedBinaryClassificationMetrics result;
820821
using (var cursor = overall.GetRowCursorForAllColumns())
821822
{
822823
var moved = cursor.MoveNext();
823824
Host.Assert(moved);
824-
result = new CalibratedBinaryClassificationMetrics(Host, cursor);
825+
result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix);
825826
moved = cursor.MoveNext();
826827
Host.Assert(!moved);
827828
}
829+
828830
return result;
829831
}
830832

@@ -879,13 +881,14 @@ public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve(
879881
}
880882
}
881883
prCurve = prCurveResult;
884+
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];
882885

883886
CalibratedBinaryClassificationMetrics result;
884887
using (var cursor = overall.GetRowCursorForAllColumns())
885888
{
886889
var moved = cursor.MoveNext();
887890
Host.Assert(moved);
888-
result = new CalibratedBinaryClassificationMetrics(Host, cursor);
891+
result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix);
889892
moved = cursor.MoveNext();
890893
Host.Assert(!moved);
891894
}
@@ -939,16 +942,18 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string label, string
939942
var resultDict = ((IEvaluator)this).Evaluate(roles);
940943
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
941944
var overall = resultDict[MetricKinds.OverallMetrics];
945+
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];
942946

943947
BinaryClassificationMetrics result;
944948
using (var cursor = overall.GetRowCursorForAllColumns())
945949
{
946950
var moved = cursor.MoveNext();
947951
Host.Assert(moved);
948-
result = new BinaryClassificationMetrics(Host, cursor);
952+
result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix);
949953
moved = cursor.MoveNext();
950954
Host.Assert(!moved);
951955
}
956+
952957
return result;
953958
}
954959

@@ -985,6 +990,7 @@ public BinaryClassificationMetrics EvaluateWithPRCurve(
985990
var prCurveView = resultDict[MetricKinds.PrCurve];
986991
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
987992
var overall = resultDict[MetricKinds.OverallMetrics];
993+
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];
988994

989995
var prCurveResult = new List<BinaryPrecisionRecallDataPoint>();
990996
using (var cursor = prCurveView.GetRowCursorForAllColumns())
@@ -1007,7 +1013,7 @@ public BinaryClassificationMetrics EvaluateWithPRCurve(
10071013
{
10081014
var moved = cursor.MoveNext();
10091015
Host.Assert(moved);
1010-
result = new BinaryClassificationMetrics(Host, cursor);
1016+
result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix);
10111017
moved = cursor.MoveNext();
10121018
Host.Assert(!moved);
10131019
}
@@ -1377,7 +1383,7 @@ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<str
13771383
fold = ColumnSelectingTransformer.CreateKeep(Host, fold, colsToKeep.ToArray());
13781384

13791385
string weightedConf;
1380-
var unweightedConf = MetricWriter.GetConfusionTable(Host, conf, out weightedConf);
1386+
var unweightedConf = MetricWriter.GetConfusionTableAsFormattedString(Host, conf, out weightedConf);
13811387
string weightedFold;
13821388
var unweightedFold = MetricWriter.GetPerFoldResults(Host, fold, out weightedFold);
13831389
ch.Assert(string.IsNullOrEmpty(weightedConf) == string.IsNullOrEmpty(weightedFold));

src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs

+69-42
Original file line numberDiff line numberDiff line change
@@ -1348,20 +1348,49 @@ internal static class MetricWriter
13481348
/// is assigned the string representation of the weighted confusion table. Otherwise it is assigned null.</param>
13491349
/// <param name="binary">Indicates whether the confusion table is for binary classification.</param>
13501350
/// <param name="sample">Indicates how many classes to sample from the confusion table (-1 indicates no sampling)</param>
1351-
public static string GetConfusionTable(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1)
1351+
public static string GetConfusionTableAsFormattedString(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1)
13521352
{
13531353
host.CheckValue(confusionDataView, nameof(confusionDataView));
13541354
host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");
13551355

1356-
// Get the class names.
1357-
int countCol;
1358-
host.Check(confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Count, out countCol), "Did not find the count column");
1359-
var type = confusionDataView.Schema[countCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
1360-
host.Check(type != null && type.IsKnownSize && type.ItemType is TextDataViewType, "The Count column does not have a text vector metadata of kind SlotNames.");
1356+
var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight);
1357+
bool isWeighted = weightColumn.HasValue;
13611358

1359+
var confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, false);
1360+
var confusionTableString = GetConfusionTableAsString(confusionMatrix, false);
1361+
1362+
// If there is a Weight column, return the weighted confusionMatrix as well, from this function.
1363+
if (isWeighted)
1364+
{
1365+
confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, true);
1366+
weightedConfusionTable = GetConfusionTableAsString(confusionMatrix, true);
1367+
}
1368+
else
1369+
weightedConfusionTable = null;
1370+
1371+
return confusionTableString;
1372+
}
1373+
1374+
public static ConfusionMatrix GetConfusionMatrix(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false)
1375+
{
1376+
host.CheckValue(confusionDataView, nameof(confusionDataView));
1377+
host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");
1378+
1379+
// check that there is a Weight column, if isWeighted parameter is set to true.
1380+
var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight);
1381+
if (getWeighted)
1382+
host.CheckParam(weightColumn.HasValue, nameof(getWeighted), "There is no Weight column in the confusionMatrix data view.");
1383+
1384+
// Get the counts names.
1385+
var countColumn = confusionDataView.Schema[MetricKinds.ColumnNames.Count];
1386+
var type = countColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
1387+
//"The Count column does not have a text vector metadata of kind SlotNames."
1388+
host.Assert(type != null && type.IsKnownSize && type.ItemType is TextDataViewType);
1389+
1390+
// Get the class names
13621391
var labelNames = default(VBuffer<ReadOnlyMemory<char>>);
1363-
confusionDataView.Schema[countCol].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames);
1364-
host.Check(labelNames.IsDense, "Slot names vector must be dense");
1392+
countColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames);
1393+
host.Assert(labelNames.IsDense, "Slot names vector must be dense");
13651394

13661395
int numConfusionTableLabels = sample < 0 ? labelNames.Length : Math.Min(labelNames.Length, sample);
13671396

@@ -1387,32 +1416,32 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView,
13871416

13881417
double[] precisionSums;
13891418
double[] recallSums;
1390-
var confusionTable = GetConfusionTableAsArray(confusionDataView, countCol, labelNames.Length,
1391-
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
1419+
double[][] confusionTable;
13921420

1393-
var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap);
1394-
var confusionTableString = GetConfusionTableAsString(confusionTable, recallSums, precisionSums,
1395-
predictedLabelNames,
1396-
sampled: numConfusionTableLabels < labelNames.Length, binary: binary);
1421+
if (getWeighted)
1422+
confusionTable = GetConfusionTableAsArray(confusionDataView, weightColumn.Value.Index, labelNames.Length,
1423+
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
1424+
else
1425+
confusionTable = GetConfusionTableAsArray(confusionDataView, countColumn.Index, labelNames.Length,
1426+
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
13971427

1398-
int weightIndex;
1399-
if (confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Weight, out weightIndex))
1428+
double[] precision = new double[numConfusionTableLabels];
1429+
double[] recall = new double[numConfusionTableLabels];
1430+
for (int i = 0; i < numConfusionTableLabels; i++)
14001431
{
1401-
confusionTable = GetConfusionTableAsArray(confusionDataView, weightIndex, labelNames.Length,
1402-
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
1403-
weightedConfusionTable = GetConfusionTableAsString(confusionTable, recallSums, precisionSums,
1404-
predictedLabelNames,
1405-
sampled: numConfusionTableLabels < labelNames.Length, prefix: "Weighted ", binary: binary);
1432+
recall[i] = recallSums[i] > 0 ? confusionTable[i][i] / recallSums[i] : 0;
1433+
precision[i] = precisionSums[i] > 0 ? confusionTable[i][i] / precisionSums[i] : 0;
14061434
}
1407-
else
1408-
weightedConfusionTable = null;
14091435

1410-
return confusionTableString;
1436+
var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap);
1437+
bool sampled = numConfusionTableLabels < labelNames.Length;
1438+
1439+
return new ConfusionMatrix(host, precision, recall, confusionTable, predictedLabelNames, sampled, binary);
14111440
}
14121441

14131442
private static List<ReadOnlyMemory<char>> GetPredictedLabelNames(in VBuffer<ReadOnlyMemory<char>> labelNames, int[] labelIndexToConfIndexMap)
14141443
{
1415-
List<ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>();
1444+
List <ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>();
14161445
var values = labelNames.GetValues();
14171446
for (int i = 0; i < values.Length; i++)
14181447
{
@@ -1553,13 +1582,13 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat
15531582
}
15541583

15551584
// Get a string representation of a confusion table.
1556-
private static string GetConfusionTableAsString(double[][] confusionTable, double[] rowSums, double[] columnSums,
1557-
List<ReadOnlyMemory<char>> predictedLabelNames, string prefix = "", bool sampled = false, bool binary = true)
1585+
internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted)
15581586
{
1559-
int numLabels = Utils.Size(confusionTable);
1587+
string prefix = isWeighted ? "Weighted " : "";
1588+
int numLabels = confusionMatrix?.Counts == null? 0: confusionMatrix.Counts.Count;
15601589

15611590
int colWidth = numLabels == 2 ? 8 : 5;
1562-
int maxNameLen = predictedLabelNames.Max(name => name.Length);
1591+
int maxNameLen = confusionMatrix.PredictedClassesIndicators.Max(name => name.Length);
15631592
// If the names are too long to fit in the column header, we back off to using class indices
15641593
// in the header. This will also require putting the indices in the row, but it's better than
15651594
// the alternative of having ambiguous abbreviated column headers, or having a table potentially
@@ -1572,7 +1601,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
15721601
{
15731602
// The row label will also include the index, so a user can easily match against the header.
15741603
// In such a case, a label like "Foo" would be presented as something like "5. Foo".
1575-
rowDigitLen = Math.Max(predictedLabelNames.Count - 1, 0).ToString().Length;
1604+
rowDigitLen = Math.Max(confusionMatrix.PredictedClassesIndicators.Count - 1, 0).ToString().Length;
15761605
Contracts.Assert(rowDigitLen >= 1);
15771606
rowLabelLen += rowDigitLen + 2;
15781607
}
@@ -1591,10 +1620,11 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
15911620
else
15921621
rowLabelFormat = string.Format("{{1,{0}}} ||", paddingLen);
15931622

1623+
var confusionTable = confusionMatrix.Counts;
15941624
var sb = new StringBuilder();
1595-
if (numLabels == 2 && binary)
1625+
if (numLabels == 2 && confusionMatrix.IsBinary)
15961626
{
1597-
var positiveCaps = predictedLabelNames[0].ToString().ToUpper();
1627+
var positiveCaps = confusionMatrix.PredictedClassesIndicators[0].ToString().ToUpper();
15981628

15991629
var numTruePos = confusionTable[0][0];
16001630
var numFalseNeg = confusionTable[0][1];
@@ -1607,7 +1637,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
16071637

16081638
sb.AppendLine();
16091639
sb.AppendFormat("{0}Confusion table", prefix);
1610-
if (sampled)
1640+
if (confusionMatrix.IsSampled)
16111641
sb.AppendLine(" (sampled)");
16121642
else
16131643
sb.AppendLine();
@@ -1619,7 +1649,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
16191649
sb.AppendFormat("PREDICTED {0}||", pad);
16201650
string format = string.Format(" {{{0},{1}}} |", useNumbersInHeader ? 0 : 1, colWidth);
16211651
for (int i = 0; i < numLabels; i++)
1622-
sb.AppendFormat(format, i, predictedLabelNames[i]);
1652+
sb.AppendFormat(format, i, confusionMatrix.PredictedClassesIndicators[i]);
16231653
sb.AppendLine(" Recall");
16241654
sb.AppendFormat("TRUTH {0}||", pad);
16251655
for (int i = 0; i < numLabels; i++)
@@ -1631,11 +1661,10 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
16311661
string.IsNullOrWhiteSpace(prefix) ? "N0" : "F1");
16321662
for (int i = 0; i < numLabels; i++)
16331663
{
1634-
sb.AppendFormat(rowLabelFormat, i, predictedLabelNames[i]);
1664+
sb.AppendFormat(rowLabelFormat, i, confusionMatrix.PredictedClassesIndicators[i]);
16351665
for (int j = 0; j < numLabels; j++)
16361666
sb.AppendFormat(format2, confusionTable[i][j]);
1637-
Double recall = rowSums[i] > 0 ? confusionTable[i][i] / rowSums[i] : 0;
1638-
sb.AppendFormat(" {0,5:F4}", recall);
1667+
sb.AppendFormat(" {0,5:F4}", confusionMatrix.PerClassRecall[i]);
16391668
sb.AppendLine();
16401669
}
16411670
sb.AppendFormat(" {0}||", pad);
@@ -1645,10 +1674,8 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
16451674
sb.AppendFormat("Precision {0}||", pad);
16461675
format = string.Format("{{0,{0}:N4}} |", colWidth + 1);
16471676
for (int i = 0; i < numLabels; i++)
1648-
{
1649-
Double precision = columnSums[i] > 0 ? confusionTable[i][i] / columnSums[i] : 0;
1650-
sb.AppendFormat(format, precision);
1651-
}
1677+
sb.AppendFormat(format, confusionMatrix.PerClassPrecision[i]);
1678+
16521679
sb.AppendLine();
16531680
return sb.ToString();
16541681
}
@@ -1701,7 +1728,7 @@ public static void PrintWarnings(IChannel ch, Dictionary<string, IDataView> metr
17011728
if (metrics.TryGetValue(MetricKinds.Warnings, out warnings))
17021729
{
17031730
var warningTextColumn = warnings.Schema.GetColumnOrNull(MetricKinds.ColumnNames.WarningText);
1704-
if (warningTextColumn !=null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType)
1731+
if (warningTextColumn != null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType)
17051732
{
17061733
using (var cursor = warnings.GetRowCursor(warnings.Schema[MetricKinds.ColumnNames.WarningText]))
17071734
{

src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ public class BinaryClassificationMetrics
7474
/// </remarks>
7575
public double AreaUnderPrecisionRecallCurve { get; }
7676

77+
/// <summary>
78+
/// The <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a> giving the counts of the
79+
/// true positives, true negatives, false positives and false negatives for the two classes of data.
80+
/// </summary>
81+
public ConfusionMatrix ConfusionMatrix { get; }
82+
7783
private protected static T Fetch<T>(IExceptionContext ectx, DataViewRow row, string name)
7884
{
7985
var column = row.Schema.GetColumnOrNull(name);
@@ -84,9 +90,9 @@ private protected static T Fetch<T>(IExceptionContext ectx, DataViewRow row, str
8490
return val;
8591
}
8692

87-
internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult)
93+
internal BinaryClassificationMetrics(IHost host, DataViewRow overallResult, IDataView confusionMatrix)
8894
{
89-
double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
95+
double Fetch(string name) => Fetch<double>(host, overallResult, name);
9096
AreaUnderRocCurve = Fetch(BinaryClassifierEvaluator.Auc);
9197
Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy);
9298
PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName);
@@ -95,6 +101,7 @@ internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overall
95101
NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName);
96102
F1Score = Fetch(BinaryClassifierEvaluator.F1);
97103
AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc);
104+
ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix);
98105
}
99106

100107
[BestFriend]

0 commit comments

Comments
 (0)