@@ -25,6 +25,31 @@ public abstract class TrainCatalogBase
25
25
[ BestFriend ]
26
26
internal IHostEnvironment Environment => Host ;
27
27
28
+ /// <summary>
29
+ /// A pair of datasets, for the train and test set.
30
+ /// </summary>
31
+ public struct TrainTestData
32
+ {
33
+ /// <summary>
34
+ /// Training set.
35
+ /// </summary>
36
+ public readonly IDataView TrainSet ;
37
+ /// <summary>
38
+ /// Testing set.
39
+ /// </summary>
40
+ public readonly IDataView TestSet ;
41
+ /// <summary>
42
+ /// Create pair of datasets.
43
+ /// </summary>
44
+ /// <param name="trainSet">Training set.</param>
45
+ /// <param name="testSet">Testing set.</param>
46
+ internal TrainTestData ( IDataView trainSet , IDataView testSet )
47
+ {
48
+ TrainSet = trainSet ;
49
+ TestSet = testSet ;
50
+ }
51
+ }
52
+
28
53
/// <summary>
29
54
/// Split the dataset into the train set and test set according to the given fraction.
30
55
/// Respects the <paramref name="stratificationColumn"/> if provided.
@@ -37,8 +62,7 @@ public abstract class TrainCatalogBase
37
62
/// <param name="seed">Optional parameter used in combination with the <paramref name="stratificationColumn"/>.
38
63
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
39
64
/// And if it is not provided, the default value will be used.</param>
40
- /// <returns>A pair of datasets, for the train and test set.</returns>
41
- public ( IDataView trainSet , IDataView testSet ) TrainTestSplit ( IDataView data , double testFraction = 0.1 , string stratificationColumn = null , uint ? seed = null )
65
+ public TrainTestData TrainTestSplit ( IDataView data , double testFraction = 0.1 , string stratificationColumn = null , uint ? seed = null )
42
66
{
43
67
Host . CheckValue ( data , nameof ( data ) ) ;
44
68
Host . CheckParam ( 0 < testFraction && testFraction < 1 , nameof ( testFraction ) , "Must be between 0 and 1 exclusive" ) ;
@@ -61,14 +85,71 @@ public abstract class TrainCatalogBase
61
85
Complement = false
62
86
} , data ) ;
63
87
64
- return ( trainFilter , testFilter ) ;
88
+ return new TrainTestData ( trainFilter , testFilter ) ;
89
+ }
90
+
91
+ /// <summary>
92
+ /// Results for specific cross-validation fold.
93
+ /// </summary>
94
+ protected internal struct CrossValidationResult
95
+ {
96
+ /// <summary>
97
+ /// Model trained during cross validation fold.
98
+ /// </summary>
99
+ public readonly ITransformer Model ;
100
+ /// <summary>
101
+ /// Scored test set with <see cref="Model"/> for this fold.
102
+ /// </summary>
103
+ public readonly IDataView Scores ;
104
+ /// <summary>
105
+ /// Fold number.
106
+ /// </summary>
107
+ public readonly int Fold ;
108
+
109
+ public CrossValidationResult ( ITransformer model , IDataView scores , int fold )
110
+ {
111
+ Model = model ;
112
+ Scores = scores ;
113
+ Fold = fold ;
114
+ }
115
+ }
116
+ /// <summary>
117
+ /// Results of running cross-validation.
118
+ /// </summary>
119
+ /// <typeparam name="T">Type of metric class.</typeparam>
120
+ public sealed class CrossValidationResult < T > where T : class
121
+ {
122
+ /// <summary>
123
+ /// Metrics for this cross-validation fold.
124
+ /// </summary>
125
+ public readonly T Metrics ;
126
+ /// <summary>
127
+ /// Model trained during cross-validation fold.
128
+ /// </summary>
129
+ public readonly ITransformer Model ;
130
+ /// <summary>
131
+ /// The scored hold-out set for this fold.
132
+ /// </summary>
133
+ public readonly IDataView ScoredHoldOutSet ;
134
+ /// <summary>
135
+ /// Fold number.
136
+ /// </summary>
137
+ public readonly int Fold ;
138
+
139
+ internal CrossValidationResult ( ITransformer model , T metrics , IDataView scores , int fold )
140
+ {
141
+ Model = model ;
142
+ Metrics = metrics ;
143
+ ScoredHoldOutSet = scores ;
144
+ Fold = fold ;
145
+ }
65
146
}
66
147
67
148
/// <summary>
68
149
/// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially.
69
150
/// Return each model and each scored test dataset.
70
151
/// </summary>
71
- protected internal ( IDataView scoredTestSet , ITransformer model ) [ ] CrossValidateTrain ( IDataView data , IEstimator < ITransformer > estimator ,
152
+ protected internal CrossValidationResult [ ] CrossValidateTrain ( IDataView data , IEstimator < ITransformer > estimator ,
72
153
int numFolds , string stratificationColumn , uint ? seed = null )
73
154
{
74
155
Host . CheckValue ( data , nameof ( data ) ) ;
@@ -78,7 +159,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
78
159
79
160
EnsureStratificationColumn ( ref data , ref stratificationColumn , seed ) ;
80
161
81
- Func < int , ( IDataView scores , ITransformer model ) > foldFunction =
162
+ Func < int , CrossValidationResult > foldFunction =
82
163
fold =>
83
164
{
84
165
var trainFilter = new RangeFilter ( Host , new RangeFilter . Options
@@ -98,17 +179,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
98
179
99
180
var model = estimator . Fit ( trainFilter ) ;
100
181
var scoredTest = model . Transform ( testFilter ) ;
101
- return ( scoredTest , model ) ;
182
+ return new CrossValidationResult ( model , scoredTest , fold ) ;
102
183
} ;
103
184
104
185
// Sequential per-fold training.
105
186
// REVIEW: we could have a parallel implementation here. We would need to
106
187
// spawn off a separate host per fold in that case.
107
- var result = new List < ( IDataView scores , ITransformer model ) > ( ) ;
188
+ var result = new CrossValidationResult [ numFolds ] ;
108
189
for ( int fold = 0 ; fold < numFolds ; fold ++ )
109
- result . Add ( foldFunction ( fold ) ) ;
190
+ result [ fold ] = foldFunction ( fold ) ;
110
191
111
- return result . ToArray ( ) ;
192
+ return result ;
112
193
}
113
194
114
195
protected internal TrainCatalogBase ( IHostEnvironment env , string registrationName )
@@ -263,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
263
344
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
264
345
/// And if it is not provided, the default value will be used.</param>
265
346
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
266
- public ( BinaryClassificationMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidateNonCalibrated (
347
+ public CrossValidationResult < BinaryClassificationMetrics > [ ] CrossValidateNonCalibrated (
267
348
IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
268
349
string stratificationColumn = null , uint ? seed = null )
269
350
{
270
351
Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
271
352
var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
272
- return result . Select ( x => ( EvaluateNonCalibrated ( x . scoredTestSet , labelColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
353
+ return result . Select ( x => new CrossValidationResult < BinaryClassificationMetrics > ( x . Model ,
354
+ EvaluateNonCalibrated ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
273
355
}
274
356
275
357
/// <summary>
@@ -287,13 +369,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
287
369
/// train to the test set.</remarks>
288
370
/// <param name="seed">If <paramref name="stratificationColumn"/> not present in dataset we will generate random filled column based on provided <paramref name="seed"/>.</param>
289
371
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
290
- public ( CalibratedBinaryClassificationMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidate (
372
+ public CrossValidationResult < CalibratedBinaryClassificationMetrics > [ ] CrossValidate (
291
373
IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
292
374
string stratificationColumn = null , uint ? seed = null )
293
375
{
294
376
Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
295
377
var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
296
- return result . Select ( x => ( Evaluate ( x . scoredTestSet , labelColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
378
+ return result . Select ( x => new CrossValidationResult < CalibratedBinaryClassificationMetrics > ( x . Model ,
379
+ Evaluate ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
297
380
}
298
381
}
299
382
@@ -369,12 +452,13 @@ public ClusteringMetrics Evaluate(IDataView data,
369
452
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
370
453
/// And if it is not provided, the default value will be used.</param>
371
454
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
372
- public ( ClusteringMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidate (
455
+ public CrossValidationResult < ClusteringMetrics > [ ] CrossValidate (
373
456
IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = null , string featuresColumn = null ,
374
457
string stratificationColumn = null , uint ? seed = null )
375
458
{
376
459
var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
377
- return result . Select ( x => ( Evaluate ( x . scoredTestSet , label : labelColumn , features : featuresColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
460
+ return result . Select ( x => new CrossValidationResult < ClusteringMetrics > ( x . Model ,
461
+ Evaluate ( x . Scores , label : labelColumn , features : featuresColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
378
462
}
379
463
}
380
464
@@ -444,13 +528,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau
444
528
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
445
529
/// And if it is not provided, the default value will be used.</param>
446
530
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
447
- public ( MultiClassClassifierMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidate (
531
+ public CrossValidationResult < MultiClassClassifierMetrics > [ ] CrossValidate (
448
532
IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
449
533
string stratificationColumn = null , uint ? seed = null )
450
534
{
451
535
Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
452
536
var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
453
- return result . Select ( x => ( Evaluate ( x . scoredTestSet , labelColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
537
+ return result . Select ( x => new CrossValidationResult < MultiClassClassifierMetrics > ( x . Model ,
538
+ Evaluate ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
454
539
}
455
540
}
456
541
@@ -511,13 +596,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
511
596
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
512
597
/// And if it is not provided, the default value will be used.</param>
513
598
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
514
- public ( RegressionMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidate (
599
+ public CrossValidationResult < RegressionMetrics > [ ] CrossValidate (
515
600
IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
516
601
string stratificationColumn = null , uint ? seed = null )
517
602
{
518
603
Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
519
604
var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
520
- return result . Select ( x => ( Evaluate ( x . scoredTestSet , labelColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
605
+ return result . Select ( x => new CrossValidationResult < RegressionMetrics > ( x . Model ,
606
+ Evaluate ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
521
607
}
522
608
}
523
609
0 commit comments