@@ -51,13 +51,13 @@ public static BinaryClassifierEvaluator.CalibratedResult Evaluate<T>(
51
51
}
52
52
53
53
/// <summary>
54
- /// Evaluates scored binary classification data.
54
+ /// Evaluates scored binary classification data, if the predictions are not calibrated .
55
55
/// </summary>
56
56
/// <typeparam name="T">The shape type for the input data.</typeparam>
57
57
/// <param name="ctx">The binary classification context.</param>
58
58
/// <param name="data">The data to evaluate.</param>
59
59
/// <param name="label">The index delegate for the label column.</param>
60
- /// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier.
60
+ /// <param name="pred">The index delegate for columns from uncalibrated prediction of a binary classifier.
61
61
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
62
62
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
63
63
public static BinaryClassifierEvaluator . Result Evaluate < T > (
@@ -83,5 +83,89 @@ public static BinaryClassifierEvaluator.Result Evaluate<T>(
83
83
var eval = new BinaryClassifierEvaluator ( env , new BinaryClassifierEvaluator . Arguments ( ) { } ) ;
84
84
return eval . Evaluate ( data . AsDynamic , labelName , scoreName , predName ) ;
85
85
}
86
+
87
+ /// <summary>
88
+ /// Evaluates scored multiclass classification data.
89
+ /// </summary>
90
+ /// <typeparam name="T">The shape type for the input data.</typeparam>
91
+ /// <typeparam name="TKey">The value type for the key label.</typeparam>
92
+ /// <param name="ctx">The multiclass classification context.</param>
93
+ /// <param name="data">The data to evaluate.</param>
94
+ /// <param name="label">The index delegate for the label column.</param>
95
+ /// <param name="pred">The index delegate for columns from the prediction of a multiclass classifier.
96
+ /// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
97
+ /// <param name="topK">If given a positive value, the <see cref="MultiClassClassifierEvaluator.Result.TopKAccuracy"/> will be filled with
98
+ /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within
99
+ /// the top-K values as being stored "correctly."</param>
100
+ /// <returns>The evaluation metrics.</returns>
101
+ public static MultiClassClassifierEvaluator . Result Evaluate < T , TKey > (
102
+ this MulticlassClassificationContext ctx ,
103
+ DataView < T > data ,
104
+ Func < T , Key < uint , TKey > > label ,
105
+ Func < T , ( Vector < float > score , Key < uint , TKey > predictedLabel ) > pred ,
106
+ int topK = 0 )
107
+ {
108
+ Contracts . CheckValue ( data , nameof ( data ) ) ;
109
+ var env = StaticPipeUtils . GetEnvironment ( data ) ;
110
+ Contracts . AssertValue ( env ) ;
111
+ env . CheckValue ( label , nameof ( label ) ) ;
112
+ env . CheckValue ( pred , nameof ( pred ) ) ;
113
+ env . CheckParam ( topK >= 0 , nameof ( topK ) , "Must not be negative." ) ;
114
+
115
+ var indexer = StaticPipeUtils . GetIndexer ( data ) ;
116
+ string labelName = indexer . Get ( label ( indexer . Indices ) ) ;
117
+ ( var scoreCol , var predCol ) = pred ( indexer . Indices ) ;
118
+ Contracts . CheckParam ( scoreCol != null , nameof ( pred ) , "Indexing delegate resulted in null score column." ) ;
119
+ Contracts . CheckParam ( predCol != null , nameof ( pred ) , "Indexing delegate resulted in null predicted label column." ) ;
120
+ string scoreName = indexer . Get ( scoreCol ) ;
121
+ string predName = indexer . Get ( predCol ) ;
122
+
123
+ var args = new MultiClassClassifierEvaluator . Arguments ( ) { } ;
124
+ if ( topK > 0 )
125
+ args . OutputTopKAcc = topK ;
126
+
127
+ var eval = new MultiClassClassifierEvaluator ( env , args ) ;
128
+ return eval . Evaluate ( data . AsDynamic , labelName , scoreName , predName ) ;
129
+ }
130
+
131
+ private sealed class TrivialRegressionLossFactory : ISupportRegressionLossFactory
132
+ {
133
+ private readonly IRegressionLoss _loss ;
134
+ public TrivialRegressionLossFactory ( IRegressionLoss loss ) => _loss = loss ;
135
+ public IRegressionLoss CreateComponent ( IHostEnvironment env ) => _loss ;
136
+ }
137
+
138
+ /// <summary>
139
+ /// Evaluates scored multiclass classification data.
140
+ /// </summary>
141
+ /// <typeparam name="T">The shape type for the input data.</typeparam>
142
+ /// <param name="ctx">The regression context.</param>
143
+ /// <param name="data">The data to evaluate.</param>
144
+ /// <param name="label">The index delegate for the label column.</param>
145
+ /// <param name="score">The index delegate for predicted score column.</param>
146
+ /// <param name="loss">Potentially custom loss function. If left unspecified defaults to <see cref="SquaredLoss"/>.</param>
147
+ /// <returns>The evaluation metrics.</returns>
148
+ public static RegressionEvaluator . Result Evaluate < T > (
149
+ this RegressionContext ctx ,
150
+ DataView < T > data ,
151
+ Func < T , Scalar < float > > label ,
152
+ Func < T , Scalar < float > > score ,
153
+ IRegressionLoss loss = null )
154
+ {
155
+ Contracts . CheckValue ( data , nameof ( data ) ) ;
156
+ var env = StaticPipeUtils . GetEnvironment ( data ) ;
157
+ Contracts . AssertValue ( env ) ;
158
+ env . CheckValue ( label , nameof ( label ) ) ;
159
+ env . CheckValue ( score , nameof ( score ) ) ;
160
+
161
+ var indexer = StaticPipeUtils . GetIndexer ( data ) ;
162
+ string labelName = indexer . Get ( label ( indexer . Indices ) ) ;
163
+ string scoreName = indexer . Get ( score ( indexer . Indices ) ) ;
164
+
165
+ var args = new RegressionEvaluator . Arguments ( ) { } ;
166
+ if ( loss != null )
167
+ args . LossFunction = new TrivialRegressionLossFactory ( loss ) ;
168
+ return new RegressionEvaluator ( env , args ) . Evaluate ( data . AsDynamic , labelName , scoreName ) ;
169
+ }
86
170
}
87
171
}
0 commit comments