5
5
using System ;
6
6
using System . Collections . Generic ;
7
7
using System . Linq ;
8
+ using Microsoft . ML . Data . StaticPipe ;
9
+ using Microsoft . ML . Data . StaticPipe . Runtime ;
8
10
using Microsoft . ML . Runtime ;
9
11
using Microsoft . ML . Runtime . CommandLine ;
10
12
using Microsoft . ML . Runtime . Data ;
@@ -410,66 +412,58 @@ public sealed class Counters
410
412
411
413
/// <summary>
/// Gets the overall accuracy: correctly classified examples (true positives plus
/// true negatives) divided by the total number of examples.
/// </summary>
public Double Acc
    => (NumTrueNeg + NumTruePos) / (NumTruePos + NumTrueNeg + NumFalseNeg + NumFalsePos);
418
419
419
420
/// <summary>
/// Gets the recall on the positive class: true positives divided by all actual
/// positives (true positives plus false negatives), or 0 when there are no positives.
/// </summary>
public Double RecallPos
    => (NumTruePos + NumFalseNeg > 0) ? NumTruePos / (NumTruePos + NumFalseNeg) : 0;
426
426
427
427
/// <summary>
/// Gets the precision on the positive class: true positives divided by all positive
/// predictions (true positives plus false positives), or 0 when there are no positive predictions.
/// </summary>
public Double PrecisionPos
    => (NumTruePos + NumFalsePos > 0) ? NumTruePos / (NumTruePos + NumFalsePos) : 0;
434
433
435
434
/// <summary>
/// Gets the recall on the negative class: true negatives divided by all actual
/// negatives (true negatives plus false positives), or 0 when there are no negatives.
/// </summary>
public Double RecallNeg
    => (NumTrueNeg + NumFalsePos > 0) ? NumTrueNeg / (NumTrueNeg + NumFalsePos) : 0;
442
440
443
441
/// <summary>
/// Gets the precision on the negative class: true negatives divided by all negative
/// predictions (true negatives plus false negatives), or 0 when there are no negative predictions.
/// </summary>
public Double PrecisionNeg
    => (NumTrueNeg + NumFalseNeg > 0) ? NumTrueNeg / (NumTrueNeg + NumFalseNeg) : 0;
450
447
451
448
/// <summary>
/// Gets the entropy of the fraction of positively labeled examples
/// ((TP + FN) over the total count), as computed by <see cref="MathUtils.Entropy"/>.
/// </summary>
public Double Entropy
    => MathUtils.Entropy(
        (NumTruePos + NumFalseNeg) / (NumTruePos + NumTrueNeg + NumFalseNeg + NumFalsePos));
459
455
460
456
/// <summary>
/// Gets the average log-loss per counted example. Propagates NaN when the accumulated
/// log-loss is NaN, and returns 0 when no examples contributed to the log-loss sums.
/// </summary>
public Double LogLoss
{
    get
    {
        if (Double.IsNaN(_logLoss))
            return Double.NaN;
        var count = _numLogLossPositives + _numLogLossNegatives;
        return count > 0 ? _logLoss / count : 0;
    }
}
468
463
469
464
public Double LogLossReduction
470
465
{
471
- get
472
- {
466
+ get {
473
467
if ( _numLogLossPositives + _numLogLossNegatives == 0 )
474
468
return 0 ;
475
469
var logLoss = _logLoss / ( _numLogLossPositives + _numLogLossNegatives ) ;
@@ -787,6 +781,246 @@ private void ComputePrCurves()
787
781
}
788
782
}
789
783
}
784
+
785
/// <summary>
/// Evaluation results for binary classifiers, excluding probabilistic metrics.
/// </summary>
public class Result
{
    /// <summary>
    /// Gets the area under the ROC curve.
    /// </summary>
    /// <remarks>
    /// The area under the ROC curve is equal to the probability that the classifier ranks
    /// a randomly chosen positive instance higher than a randomly chosen negative one
    /// (assuming 'positive' ranks higher than 'negative').
    /// </remarks>
    public double Auc { get; }

    /// <summary>
    /// Gets the accuracy of a classifier which is the proportion of correct predictions in the test set.
    /// </summary>
    public double Accuracy { get; }

    /// <summary>
    /// Gets the positive precision of a classifier which is the proportion of correctly predicted
    /// positive instances among all the positive predictions (i.e., the number of positive instances
    /// predicted as positive, divided by the total number of instances predicted as positive).
    /// </summary>
    public double PositivePrecision { get; }

    /// <summary>
    /// Gets the positive recall of a classifier which is the proportion of correctly predicted
    /// positive instances among all the positive instances (i.e., the number of positive instances
    /// predicted as positive, divided by the total number of positive instances).
    /// </summary>
    // Get-only for consistency with the sibling metric properties; the value is only
    // assigned in the constructor, so the previous private setter was unnecessary.
    public double PositiveRecall { get; }

    /// <summary>
    /// Gets the negative precision of a classifier which is the proportion of correctly predicted
    /// negative instances among all the negative predictions (i.e., the number of negative instances
    /// predicted as negative, divided by the total number of instances predicted as negative).
    /// </summary>
    public double NegativePrecision { get; }

    /// <summary>
    /// Gets the negative recall of a classifier which is the proportion of correctly predicted
    /// negative instances among all the negative instances (i.e., the number of negative instances
    /// predicted as negative, divided by the total number of negative instances).
    /// </summary>
    public double NegativeRecall { get; }

    /// <summary>
    /// Gets the F1 score of the classifier.
    /// </summary>
    /// <remarks>
    /// F1 score is the harmonic mean of precision and recall: 2 * precision * recall / (precision + recall).
    /// </remarks>
    public double F1Score { get; }

    /// <summary>
    /// Gets the area under the precision/recall curve of the classifier.
    /// </summary>
    /// <remarks>
    /// The area under the precision/recall curve is a single number summary of the information in the
    /// precision/recall curve. It is increasingly used in the machine learning community, particularly
    /// for imbalanced datasets where one class is observed more frequently than the other. On these
    /// datasets, AUPRC can highlight performance differences that are lost with AUC.
    /// </remarks>
    public double Auprc { get; }

    /// <summary>
    /// Reads the value of type <typeparamref name="T"/> from the named column of the given row,
    /// throwing via <paramref name="ectx"/> when the column is not present in the row's schema.
    /// </summary>
    // Canonical modifier order is 'private protected' (was 'protected private').
    private protected static T Fetch<T>(IExceptionContext ectx, IRow row, string name)
    {
        if (!row.Schema.TryGetColumnIndex(name, out int col))
            throw ectx.Except($"Could not find column '{name}'");
        T val = default;
        row.GetGetter<T>(col)(ref val);
        return val;
    }

    // Populates the non-probabilistic metrics from a single row of the overall-metrics view.
    internal Result(IExceptionContext ectx, IRow overallResult)
    {
        double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
        Auc = Fetch(BinaryClassifierEvaluator.Auc);
        Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy);
        PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName);
        PositiveRecall = Fetch(BinaryClassifierEvaluator.PosRecallName);
        NegativePrecision = Fetch(BinaryClassifierEvaluator.NegPrecName);
        NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName);
        F1Score = Fetch(BinaryClassifierEvaluator.F1);
        Auprc = Fetch(BinaryClassifierEvaluator.AuPrc);
    }
}
874
+
875
/// <summary>
/// Evaluation results for binary classifiers, including probabilistic metrics.
/// </summary>
public sealed class CalibratedResult : Result
{
    /// <summary>
    /// Gets the log-loss of the classifier.
    /// </summary>
    /// <remarks>
    /// The log-loss metric, is computed as follows:
    /// LL = - (1/m) * sum( log(p[i]))
    /// where m is the number of instances in the test set.
    /// p[i] is the probability returned by the classifier if the instance belongs to class 1,
    /// and 1 minus the probability returned by the classifier if the instance belongs to class 0.
    /// </remarks>
    public double LogLoss { get; }

    /// <summary>
    /// Gets the log-loss reduction (also known as relative log-loss, or reduction in information gain - RIG)
    /// of the classifier.
    /// </summary>
    /// <remarks>
    /// The log-loss reduction is scaled relative to a classifier that predicts the prior for every example:
    /// (LL(prior) - LL(classifier)) / LL(prior)
    /// This metric can be interpreted as the advantage of the classifier over a random prediction.
    /// E.g., if the RIG equals 20, it can be interpreted as "the probability of a correct prediction is
    /// 20% better than random guessing."
    /// </remarks>
    public double LogLossReduction { get; }

    /// <summary>
    /// Gets the test-set entropy (prior Log-Loss/instance) of the classifier.
    /// </summary>
    public double Entropy { get; }

    // Reads the probabilistic metrics directly; the base constructor handles the rest.
    internal CalibratedResult(IExceptionContext ectx, IRow overallResult)
        : base(ectx, overallResult)
    {
        LogLoss = Fetch<double>(ectx, overallResult, BinaryClassifierEvaluator.LogLoss);
        LogLossReduction = Fetch<double>(ectx, overallResult, BinaryClassifierEvaluator.LogLossReduction);
        Entropy = Fetch<double>(ectx, overallResult, BinaryClassifierEvaluator.Entropy);
    }
}
919
+
920
/// <summary>
/// Evaluates scored binary classification data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <returns>The evaluation results for these calibrated outputs.</returns>
public static CalibratedResult Evaluate<T>(
    DataView<T> data,
    Func<T, Scalar<bool>> label,
    Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred)
{
    Contracts.CheckValue(data, nameof(data));
    var env = StaticPipeUtils.GetEnvironment(data);
    Contracts.AssertValue(env);
    env.CheckValue(label, nameof(label));
    env.CheckValue(pred, nameof(pred));

    // Resolve the statically-typed column references to their underlying column names.
    var idx = StaticPipeUtils.GetIndexer(data);
    string labelName = idx.Get(label(idx.Indices));
    var (scoreCol, probCol, predCol) = pred(idx.Indices);
    Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
    Contracts.CheckParam(probCol != null, nameof(pred), "Indexing delegate resulted in null probability column.");
    Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
    string scoreName = idx.Get(scoreCol);
    string probName = idx.Get(probCol);
    string predName = idx.Get(predCol);

    var evaluator = new BinaryClassifierEvaluator(env, new Arguments());

    var roleMapped = new RoleMappedData(data.AsDynamic, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(labelName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predName));

    var metrics = evaluator.Evaluate(roleMapped);
    env.Assert(metrics.ContainsKey(MetricKinds.OverallMetrics));
    var overallMetrics = metrics[MetricKinds.OverallMetrics];

    // The overall-metrics view is expected to contain exactly one row.
    CalibratedResult result;
    using (var cursor = overallMetrics.GetRowCursor(col => true))
    {
        bool hasRow = cursor.MoveNext();
        env.Assert(hasRow);
        result = new CalibratedResult(env, cursor);
        env.Assert(!cursor.MoveNext());
    }
    return result;
}
973
+
974
/// <summary>
/// Evaluates scored binary classification data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
public static Result Evaluate<T>(
    DataView<T> data,
    Func<T, Scalar<bool>> label,
    Func<T, (Scalar<float> score, Scalar<bool> predictedLabel)> pred)
{
    Contracts.CheckValue(data, nameof(data));
    var env = StaticPipeUtils.GetEnvironment(data);
    Contracts.AssertValue(env);
    env.CheckValue(label, nameof(label));
    env.CheckValue(pred, nameof(pred));

    // Resolve the statically-typed column references to their underlying column names.
    var idx = StaticPipeUtils.GetIndexer(data);
    string labelName = idx.Get(label(idx.Indices));
    var (scoreCol, predCol) = pred(idx.Indices);
    Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
    Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
    string scoreName = idx.Get(scoreCol);
    string predName = idx.Get(predCol);

    var evaluator = new BinaryClassifierEvaluator(env, new Arguments());

    // No probability column here, so only score and predicted-label roles are mapped.
    var roleMapped = new RoleMappedData(data.AsDynamic, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(labelName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predName));

    var metrics = evaluator.Evaluate(roleMapped);
    env.Assert(metrics.ContainsKey(MetricKinds.OverallMetrics));
    var overallMetrics = metrics[MetricKinds.OverallMetrics];

    // The overall-metrics view is expected to contain exactly one row.
    Result result;
    using (var cursor = overallMetrics.GetRowCursor(col => true))
    {
        bool hasRow = cursor.MoveNext();
        env.Assert(hasRow);
        result = new Result(env, cursor);
        env.Assert(!cursor.MoveNext());
    }
    return result;
}
790
1024
}
791
1025
792
1026
public sealed class BinaryPerInstanceEvaluator : PerInstanceEvaluatorBase
@@ -1526,4 +1760,4 @@ private static IDataView ExtractConfusionMatrix(IHost host, Dictionary<string, I
1526
1760
return confusionMatrix ;
1527
1761
}
1528
1762
}
1529
- }
1763
+ }
0 commit comments