@@ -1348,20 +1348,49 @@ internal static class MetricWriter
1348
1348
/// is assigned the string representation of the weighted confusion table. Otherwise it is assigned null.</param>
1349
1349
/// <param name="binary">Indicates whether the confusion table is for binary classification.</param>
1350
1350
/// <param name="sample">Indicates how many classes to sample from the confusion table (-1 indicates no sampling; otherwise the value must be at least 2).</param>
1351
- public static string GetConfusionTable ( IHost host , IDataView confusionDataView , out string weightedConfusionTable , bool binary = true , int sample = - 1 )
1351
+ public static string GetConfusionTableAsFormattedString ( IHost host , IDataView confusionDataView , out string weightedConfusionTable , bool binary = true , int sample = - 1 )
1352
1352
{
1353
1353
host . CheckValue ( confusionDataView , nameof ( confusionDataView ) ) ;
1354
1354
host . CheckParam ( sample == - 1 || sample >= 2 , nameof ( sample ) , "Should be -1 to indicate no sampling, or at least 2" ) ;
1355
1355
1356
- // Get the class names.
1357
- int countCol ;
1358
- host . Check ( confusionDataView . Schema . TryGetColumnIndex ( MetricKinds . ColumnNames . Count , out countCol ) , "Did not find the count column" ) ;
1359
- var type = confusionDataView . Schema [ countCol ] . Annotations . Schema . GetColumnOrNull ( AnnotationUtils . Kinds . SlotNames ) ? . Type as VectorDataViewType ;
1360
- host . Check ( type != null && type . IsKnownSize && type . ItemType is TextDataViewType , "The Count column does not have a text vector metadata of kind SlotNames." ) ;
1356
+ var weightColumn = confusionDataView . Schema . GetColumnOrNull ( MetricKinds . ColumnNames . Weight ) ;
1357
+ bool isWeighted = weightColumn . HasValue ;
1361
1358
1359
+ var confusionMatrix = GetConfusionMatrix ( host , confusionDataView , binary , sample , false ) ;
1360
+ var confusionTableString = GetConfusionTableAsString ( confusionMatrix , false ) ;
1361
+
1362
+ // If there is a Weight column, also build the weighted confusion table string and return it through the out parameter.
1363
+ if ( isWeighted )
1364
+ {
1365
+ confusionMatrix = GetConfusionMatrix ( host , confusionDataView , binary , sample , true ) ;
1366
+ weightedConfusionTable = GetConfusionTableAsString ( confusionMatrix , true ) ;
1367
+ }
1368
+ else
1369
+ weightedConfusionTable = null ;
1370
+
1371
+ return confusionTableString ;
1372
+ }
1373
+
1374
+ public static ConfusionMatrix GetConfusionMatrix ( IHost host , IDataView confusionDataView , bool binary = true , int sample = - 1 , bool getWeighted = false )
1375
+ {
1376
+ host . CheckValue ( confusionDataView , nameof ( confusionDataView ) ) ;
1377
+ host . CheckParam ( sample == - 1 || sample >= 2 , nameof ( sample ) , "Should be -1 to indicate no sampling, or at least 2" ) ;
1378
+
1379
+ // Check that there is a Weight column if the getWeighted parameter is set to true.
1380
+ var weightColumn = confusionDataView . Schema . GetColumnOrNull ( MetricKinds . ColumnNames . Weight ) ;
1381
+ if ( getWeighted )
1382
+ host . CheckParam ( weightColumn . HasValue , nameof ( getWeighted ) , "There is no Weight column in the confusionMatrix data view." ) ;
1383
+
1384
+ // Get the Count column and the type of its SlotNames annotation.
1385
+ var countColumn = confusionDataView . Schema [ MetricKinds . ColumnNames . Count ] ;
1386
+ var type = countColumn . Annotations . Schema . GetColumnOrNull ( AnnotationUtils . Kinds . SlotNames ) ? . Type as VectorDataViewType ;
1387
+ // The Count column must have text vector metadata of kind SlotNames with a known size.
1388
+ host . Assert ( type != null && type . IsKnownSize && type . ItemType is TextDataViewType ) ;
1389
+
1390
+ // Get the class names from the SlotNames annotation of the Count column.
1362
1391
var labelNames = default ( VBuffer < ReadOnlyMemory < char > > ) ;
1363
- confusionDataView . Schema [ countCol ] . Annotations . GetValue ( AnnotationUtils . Kinds . SlotNames , ref labelNames ) ;
1364
- host . Check ( labelNames . IsDense , "Slot names vector must be dense" ) ;
1392
+ countColumn . Annotations . GetValue ( AnnotationUtils . Kinds . SlotNames , ref labelNames ) ;
1393
+ host . Assert ( labelNames . IsDense , "Slot names vector must be dense" ) ;
1365
1394
1366
1395
int numConfusionTableLabels = sample < 0 ? labelNames . Length : Math . Min ( labelNames . Length , sample ) ;
1367
1396
@@ -1387,32 +1416,32 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView,
1387
1416
1388
1417
double [ ] precisionSums ;
1389
1418
double [ ] recallSums ;
1390
- var confusionTable = GetConfusionTableAsArray ( confusionDataView , countCol , labelNames . Length ,
1391
- labelIndexToConfIndexMap , numConfusionTableLabels , out precisionSums , out recallSums ) ;
1419
+ double [ ] [ ] confusionTable ;
1392
1420
1393
- var predictedLabelNames = GetPredictedLabelNames ( in labelNames , labelIndexToConfIndexMap ) ;
1394
- var confusionTableString = GetConfusionTableAsString ( confusionTable , recallSums , precisionSums ,
1395
- predictedLabelNames ,
1396
- sampled : numConfusionTableLabels < labelNames . Length , binary : binary ) ;
1421
+ if ( getWeighted )
1422
+ confusionTable = GetConfusionTableAsArray ( confusionDataView , weightColumn . Value . Index , labelNames . Length ,
1423
+ labelIndexToConfIndexMap , numConfusionTableLabels , out precisionSums , out recallSums ) ;
1424
+ else
1425
+ confusionTable = GetConfusionTableAsArray ( confusionDataView , countColumn . Index , labelNames . Length ,
1426
+ labelIndexToConfIndexMap , numConfusionTableLabels , out precisionSums , out recallSums ) ;
1397
1427
1398
- int weightIndex ;
1399
- if ( confusionDataView . Schema . TryGetColumnIndex ( MetricKinds . ColumnNames . Weight , out weightIndex ) )
1428
+ double [ ] precision = new double [ numConfusionTableLabels ] ;
1429
+ double [ ] recall = new double [ numConfusionTableLabels ] ;
1430
+ for ( int i = 0 ; i < numConfusionTableLabels ; i ++ )
1400
1431
{
1401
- confusionTable = GetConfusionTableAsArray ( confusionDataView , weightIndex , labelNames . Length ,
1402
- labelIndexToConfIndexMap , numConfusionTableLabels , out precisionSums , out recallSums ) ;
1403
- weightedConfusionTable = GetConfusionTableAsString ( confusionTable , recallSums , precisionSums ,
1404
- predictedLabelNames ,
1405
- sampled : numConfusionTableLabels < labelNames . Length , prefix : "Weighted " , binary : binary ) ;
1432
+ recall [ i ] = recallSums [ i ] > 0 ? confusionTable [ i ] [ i ] / recallSums [ i ] : 0 ;
1433
+ precision [ i ] = precisionSums [ i ] > 0 ? confusionTable [ i ] [ i ] / precisionSums [ i ] : 0 ;
1406
1434
}
1407
- else
1408
- weightedConfusionTable = null ;
1409
1435
1410
- return confusionTableString ;
1436
+ var predictedLabelNames = GetPredictedLabelNames ( in labelNames , labelIndexToConfIndexMap ) ;
1437
+ bool sampled = numConfusionTableLabels < labelNames . Length ;
1438
+
1439
+ return new ConfusionMatrix ( host , precision , recall , confusionTable , predictedLabelNames , sampled , binary ) ;
1411
1440
}
1412
1441
1413
1442
private static List < ReadOnlyMemory < char > > GetPredictedLabelNames ( in VBuffer < ReadOnlyMemory < char > > labelNames , int [ ] labelIndexToConfIndexMap )
1414
1443
{
1415
- List < ReadOnlyMemory < char > > result = new List < ReadOnlyMemory < char > > ( ) ;
1444
+ List < ReadOnlyMemory < char > > result = new List < ReadOnlyMemory < char > > ( ) ;
1416
1445
var values = labelNames . GetValues ( ) ;
1417
1446
for ( int i = 0 ; i < values . Length ; i ++ )
1418
1447
{
@@ -1553,13 +1582,13 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat
1553
1582
}
1554
1583
1555
1584
// Get a string representation of a confusion table.
1556
- private static string GetConfusionTableAsString ( double [ ] [ ] confusionTable , double [ ] rowSums , double [ ] columnSums ,
1557
- List < ReadOnlyMemory < char > > predictedLabelNames , string prefix = "" , bool sampled = false , bool binary = true )
1585
+ internal static string GetConfusionTableAsString ( ConfusionMatrix confusionMatrix , bool isWeighted )
1558
1586
{
1559
- int numLabels = Utils . Size ( confusionTable ) ;
1587
+ string prefix = isWeighted ? "Weighted " : "" ;
1588
+ int numLabels = confusionMatrix ? . Counts == null ? 0 : confusionMatrix . Counts . Count ;
1560
1589
1561
1590
int colWidth = numLabels == 2 ? 8 : 5 ;
1562
- int maxNameLen = predictedLabelNames . Max ( name => name . Length ) ;
1591
+ int maxNameLen = confusionMatrix . PredictedClassesIndicators . Max ( name => name . Length ) ;
1563
1592
// If the names are too long to fit in the column header, we back off to using class indices
1564
1593
// in the header. This will also require putting the indices in the row, but it's better than
1565
1594
// the alternative of having ambiguous abbreviated column headers, or having a table potentially
@@ -1572,7 +1601,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
1572
1601
{
1573
1602
// The row label will also include the index, so a user can easily match against the header.
1574
1603
// In such a case, a label like "Foo" would be presented as something like "5. Foo".
1575
- rowDigitLen = Math . Max ( predictedLabelNames . Count - 1 , 0 ) . ToString ( ) . Length ;
1604
+ rowDigitLen = Math . Max ( confusionMatrix . PredictedClassesIndicators . Count - 1 , 0 ) . ToString ( ) . Length ;
1576
1605
Contracts . Assert ( rowDigitLen >= 1 ) ;
1577
1606
rowLabelLen += rowDigitLen + 2 ;
1578
1607
}
@@ -1591,10 +1620,11 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
1591
1620
else
1592
1621
rowLabelFormat = string . Format ( "{{1,{0}}} ||" , paddingLen ) ;
1593
1622
1623
+ var confusionTable = confusionMatrix . Counts ;
1594
1624
var sb = new StringBuilder ( ) ;
1595
- if ( numLabels == 2 && binary )
1625
+ if ( numLabels == 2 && confusionMatrix . IsBinary )
1596
1626
{
1597
- var positiveCaps = predictedLabelNames [ 0 ] . ToString ( ) . ToUpper ( ) ;
1627
+ var positiveCaps = confusionMatrix . PredictedClassesIndicators [ 0 ] . ToString ( ) . ToUpper ( ) ;
1598
1628
1599
1629
var numTruePos = confusionTable [ 0 ] [ 0 ] ;
1600
1630
var numFalseNeg = confusionTable [ 0 ] [ 1 ] ;
@@ -1607,7 +1637,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
1607
1637
1608
1638
sb . AppendLine ( ) ;
1609
1639
sb . AppendFormat ( "{0}Confusion table" , prefix ) ;
1610
- if ( sampled )
1640
+ if ( confusionMatrix . IsSampled )
1611
1641
sb . AppendLine ( " (sampled)" ) ;
1612
1642
else
1613
1643
sb . AppendLine ( ) ;
@@ -1619,7 +1649,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
1619
1649
sb . AppendFormat ( "PREDICTED {0}||" , pad ) ;
1620
1650
string format = string . Format ( " {{{0},{1}}} |" , useNumbersInHeader ? 0 : 1 , colWidth ) ;
1621
1651
for ( int i = 0 ; i < numLabels ; i ++ )
1622
- sb . AppendFormat ( format , i , predictedLabelNames [ i ] ) ;
1652
+ sb . AppendFormat ( format , i , confusionMatrix . PredictedClassesIndicators [ i ] ) ;
1623
1653
sb . AppendLine ( " Recall" ) ;
1624
1654
sb . AppendFormat ( "TRUTH {0}||" , pad ) ;
1625
1655
for ( int i = 0 ; i < numLabels ; i ++ )
@@ -1631,11 +1661,10 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
1631
1661
string . IsNullOrWhiteSpace ( prefix ) ? "N0" : "F1" ) ;
1632
1662
for ( int i = 0 ; i < numLabels ; i ++ )
1633
1663
{
1634
- sb . AppendFormat ( rowLabelFormat , i , predictedLabelNames [ i ] ) ;
1664
+ sb . AppendFormat ( rowLabelFormat , i , confusionMatrix . PredictedClassesIndicators [ i ] ) ;
1635
1665
for ( int j = 0 ; j < numLabels ; j ++ )
1636
1666
sb . AppendFormat ( format2 , confusionTable [ i ] [ j ] ) ;
1637
- Double recall = rowSums [ i ] > 0 ? confusionTable [ i ] [ i ] / rowSums [ i ] : 0 ;
1638
- sb . AppendFormat ( " {0,5:F4}" , recall ) ;
1667
+ sb . AppendFormat ( " {0,5:F4}" , confusionMatrix . PerClassRecall [ i ] ) ;
1639
1668
sb . AppendLine ( ) ;
1640
1669
}
1641
1670
sb . AppendFormat ( " {0}||" , pad ) ;
@@ -1645,10 +1674,8 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
1645
1674
sb . AppendFormat ( "Precision {0}||" , pad ) ;
1646
1675
format = string . Format ( "{{0,{0}:N4}} |" , colWidth + 1 ) ;
1647
1676
for ( int i = 0 ; i < numLabels ; i ++ )
1648
- {
1649
- Double precision = columnSums [ i ] > 0 ? confusionTable [ i ] [ i ] / columnSums [ i ] : 0 ;
1650
- sb . AppendFormat ( format , precision ) ;
1651
- }
1677
+ sb . AppendFormat ( format , confusionMatrix . PerClassPrecision [ i ] ) ;
1678
+
1652
1679
sb . AppendLine ( ) ;
1653
1680
return sb . ToString ( ) ;
1654
1681
}
@@ -1701,7 +1728,7 @@ public static void PrintWarnings(IChannel ch, Dictionary<string, IDataView> metr
1701
1728
if ( metrics . TryGetValue ( MetricKinds . Warnings , out warnings ) )
1702
1729
{
1703
1730
var warningTextColumn = warnings . Schema . GetColumnOrNull ( MetricKinds . ColumnNames . WarningText ) ;
1704
- if ( warningTextColumn != null && warningTextColumn . HasValue && warningTextColumn . Value . Type is TextDataViewType )
1731
+ if ( warningTextColumn != null && warningTextColumn . HasValue && warningTextColumn . Value . Type is TextDataViewType )
1705
1732
{
1706
1733
using ( var cursor = warnings . GetRowCursor ( warnings . Schema [ MetricKinds . ColumnNames . WarningText ] ) )
1707
1734
{
0 commit comments