2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
+ using Microsoft . ML . Data ;
5
6
using Microsoft . ML . Functional . Tests . Datasets ;
6
7
using Microsoft . ML . RunTests ;
7
8
using Microsoft . ML . TestFramework ;
8
9
using Microsoft . ML . TestFramework . Attributes ;
10
+ using Microsoft . ML . Tools ;
9
11
using Microsoft . ML . Trainers ;
10
12
using Microsoft . ML . Trainers . FastTree ;
11
13
using Xunit ;
@@ -165,14 +167,8 @@ public void TrainAndEvaluateMulticlassClassification()
165
167
Common . AssertMetrics ( metrics ) ;
166
168
}
167
169
168
- /// <summary>
169
- /// Train and Evaluate: Ranking.
170
- /// </summary>
171
- [ Fact ]
172
- public void TrainAndEvaluateRanking ( )
170
+ private IDataView GetScoredDataForRankingEvaluation ( MLContext mlContext )
173
171
{
174
- var mlContext = new MLContext ( seed : 1 ) ;
175
-
176
172
var data = Iris . LoadAsRankingProblem ( mlContext ,
177
173
GetDataPath ( TestDatasets . iris . trainFilename ) ,
178
174
hasHeader : TestDatasets . iris . fileHasHeader ,
@@ -187,12 +183,45 @@ public void TrainAndEvaluateRanking()
187
183
188
184
// Evaluate the model.
189
185
var scoredData = model . Transform ( data ) ;
186
+
187
+ return scoredData ;
188
+ }
189
+
190
+ /// <summary>
191
+ /// Train and Evaluate: Ranking.
192
+ /// </summary>
193
+ [ Fact ]
194
+ public void TrainAndEvaluateRanking ( )
195
+ {
196
+ var mlContext = new MLContext ( seed : 1 ) ;
197
+
198
+ var scoredData = GetScoredDataForRankingEvaluation ( mlContext ) ;
190
199
var metrics = mlContext . Ranking . Evaluate ( scoredData , labelColumnName : "Label" , rowGroupColumnName : "GroupId" ) ;
191
200
192
201
// Check that the metrics returned are valid.
193
202
Common . AssertMetrics ( metrics ) ;
194
203
}
195
204
205
+ /// <summary>
206
+ /// Train and Evaluate: Ranking with options.
207
+ /// </summary>
208
+ [ Fact ]
209
+ public void TrainAndEvaluateRankingWithOptions ( )
210
+ {
211
+ var mlContext = new MLContext ( seed : 1 ) ;
212
+ int [ ] tlevels = { 50 , 150 , 100 } ;
213
+ var options = new RankingEvaluatorOptions ( ) ;
214
+ foreach ( int i in tlevels )
215
+ {
216
+ options . DcgTruncationLevel = i ;
217
+ var scoredData = GetScoredDataForRankingEvaluation ( mlContext ) ;
218
+ var metrics = mlContext . Ranking . Evaluate ( scoredData , options , labelColumnName : "Label" , rowGroupColumnName : "GroupId" ) ;
219
+ Common . AssertMetrics ( metrics ) ;
220
+ }
221
+
222
+
223
+ }
224
+
196
225
/// <summary>
197
226
/// Train and Evaluate: Recommendation.
198
227
/// </summary>
0 commit comments