@@ -187,11 +187,11 @@ public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel
187
187
public TModel InnerModel { get ; }
188
188
}
189
189
190
- public class BinaryScorerWrapper < TModel > : ScorerWrapper < TModel >
191
- where TModel : IPredictor
190
+ public class BinaryScorerWrapper < TModel > : ScorerWrapper < TModel >
191
+ where TModel : IPredictor
192
192
{
193
193
public BinaryScorerWrapper ( IHostEnvironment env , TModel model , ISchema inputSchema , string featureColumn , BinaryClassifierScorer . Arguments args )
194
- : base ( env , MakeScorer ( env , inputSchema , featureColumn , model , args ) , model , featureColumn )
194
+ : base ( env , MakeScorer ( env , inputSchema , featureColumn , model , args ) , model , featureColumn )
195
195
{
196
196
}
197
197
@@ -235,7 +235,7 @@ public SchemaShape GetOutputSchema()
235
235
}
236
236
237
237
public abstract class TrainerBase < TTransformer , TModel > : IEstimator < TTransformer >
238
- where TTransformer : ScorerWrapper < TModel >
238
+ where TTransformer : ScorerWrapper < TModel >
239
239
where TModel : IPredictor
240
240
{
241
241
protected readonly IHostEnvironment _env ;
@@ -414,7 +414,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
414
414
}
415
415
}
416
416
417
- public sealed class MySdca : TrainerBase < BinaryScorerWrapper < IPredictor > , IPredictor >
417
+ public sealed class MySdca : TrainerBase < BinaryScorerWrapper < IPredictor > , IPredictor >
418
418
{
419
419
private readonly LinearClassificationTrainer . Arguments _args ;
420
420
@@ -428,7 +428,7 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
428
428
429
429
public ITransformer Train ( IDataView trainData , IDataView validationData = null ) => TrainTransformer ( trainData , validationData ) ;
430
430
431
- protected override BinaryScorerWrapper < IPredictor > MakeScorer ( IPredictor predictor , RoleMappedData data )
431
+ protected override BinaryScorerWrapper < IPredictor > MakeScorer ( IPredictor predictor , RoleMappedData data )
432
432
=> new BinaryScorerWrapper < IPredictor > ( _env , predictor , data . Data . Schema , _featureCol , new BinaryClassifierScorer . Arguments ( ) ) ;
433
433
}
434
434
@@ -512,68 +512,92 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string labelColumn,
512
512
}
513
513
}
514
514
515
- public static class MyHelperExtensions
515
+ public static class MyCrossValidation
516
516
{
517
- public static void SaveAsBinary ( this IDataView data , IHostEnvironment env , Stream stream )
517
+ public sealed class BinaryCrossValidationMetrics
518
518
{
519
- var saver = new BinarySaver ( env , new BinarySaver . Arguments ( ) ) ;
520
- using ( var ch = env . Start ( "SaveData" ) )
521
- DataSaverUtils . SaveDataView ( ch , saver , data , stream ) ;
519
+ public readonly ITransformer [ ] FoldModels ;
520
+ public readonly BinaryClassificationMetrics [ ] FoldMetrics ;
521
+
522
+ public BinaryCrossValidationMetrics ( ITransformer [ ] models , BinaryClassificationMetrics [ ] metrics )
523
+ {
524
+ FoldModels = models ;
525
+ FoldMetrics = metrics ;
526
+ }
522
527
}
523
528
524
- public static IDataView FitAndTransform ( this IEstimator < ITransformer > est , IDataView data ) => est . Fit ( data ) . Transform ( data ) ;
529
+ public sealed class BinaryCrossValidator
530
+ {
531
+ private readonly IHostEnvironment _env ;
525
532
526
- public static IDataView FitAndRead < TSource > ( this IDataReaderEstimator < TSource , IDataReader < TSource > > est , TSource source )
527
- => est . Fit ( source ) . Read ( source ) ;
533
+ public int NumFolds { get ; set ; } = 2 ;
528
534
529
- public static ( ITransformer [ ] Models , BinaryClassificationMetrics [ ] Metrics ) CrossValidateBinary ( IHostEnvironment env , IDataView trainData , IEstimator < ITransformer > estimator ,
530
- string labelColumn ,
531
- int numFolds = 2 ,
532
- string stratificationColumn = null ,
533
- bool cache = false )
534
- {
535
- var models = new ITransformer [ numFolds ] ;
536
- var metrics = new BinaryClassificationMetrics [ numFolds ] ;
535
+ public string StratificationColumn { get ; set ; }
537
536
538
- if ( stratificationColumn == null )
537
+ public string LabelColumn { get ; set ; } = DefaultColumnNames . Label ;
538
+
539
+ public BinaryCrossValidator ( IHostEnvironment env )
539
540
{
540
- stratificationColumn = "StratificationColumn" ;
541
- var random = new GenerateNumberTransform ( env , trainData , stratificationColumn ) ;
542
- trainData = random ;
541
+ _env = env ;
543
542
}
544
- else
545
- throw new NotImplementedException ( ) ;
546
-
547
- IDataView cachedTrain = trainData ;
548
- if ( cache )
549
- cachedTrain = new CacheDataView ( env , trainData , prefetch : null ) ;
550
543
551
- var evaluator = new MyBinaryClassifierEvaluator ( env , new BinaryClassifierEvaluator . Arguments ( ) { } ) ;
552
-
553
- for ( int fold = 0 ; fold < numFolds ; fold ++ )
544
+ public BinaryCrossValidationMetrics CrossValidate ( IDataView trainData , IEstimator < ITransformer > estimator )
554
545
{
555
- var trainFilter = new RangeFilter ( env , new RangeFilter . Arguments ( )
546
+ var models = new ITransformer [ NumFolds ] ;
547
+ var metrics = new BinaryClassificationMetrics [ NumFolds ] ;
548
+
549
+ if ( StratificationColumn == null )
556
550
{
557
- Column = stratificationColumn ,
558
- Min = ( Double ) fold / numFolds ,
559
- Max = ( Double ) ( fold + 1 ) / numFolds ,
560
- Complement = true
561
- } , cachedTrain ) ;
562
- var testFilter = new RangeFilter ( env , new RangeFilter . Arguments ( )
551
+ StratificationColumn = "StratificationColumn" ;
552
+ var random = new GenerateNumberTransform ( _env , trainData , StratificationColumn ) ;
553
+ trainData = random ;
554
+ }
555
+ else
556
+ throw new NotImplementedException ( ) ;
557
+
558
+ var evaluator = new MyBinaryClassifierEvaluator ( _env , new BinaryClassifierEvaluator . Arguments ( ) { } ) ;
559
+
560
+ for ( int fold = 0 ; fold < NumFolds ; fold ++ )
563
561
{
564
- Column = stratificationColumn ,
565
- Min = ( Double ) fold / numFolds ,
566
- Max = ( Double ) ( fold + 1 ) / numFolds ,
567
- Complement = false
568
- } , cachedTrain ) ;
569
-
570
- models [ fold ] = estimator . Fit ( trainFilter ) ;
571
- var scoredTest = models [ fold ] . Transform ( testFilter ) ;
572
- metrics [ fold ] = evaluator . Evaluate ( scoredTest , labelColumn : labelColumn , probabilityColumn : "Probability" ) ;
562
+ var trainFilter = new RangeFilter ( _env , new RangeFilter . Arguments ( )
563
+ {
564
+ Column = StratificationColumn ,
565
+ Min = ( Double ) fold / NumFolds ,
566
+ Max = ( Double ) ( fold + 1 ) / NumFolds ,
567
+ Complement = true
568
+ } , trainData ) ;
569
+ var testFilter = new RangeFilter ( _env , new RangeFilter . Arguments ( )
570
+ {
571
+ Column = StratificationColumn ,
572
+ Min = ( Double ) fold / NumFolds ,
573
+ Max = ( Double ) ( fold + 1 ) / NumFolds ,
574
+ Complement = false
575
+ } , trainData ) ;
576
+
577
+ models [ fold ] = estimator . Fit ( trainFilter ) ;
578
+ var scoredTest = models [ fold ] . Transform ( testFilter ) ;
579
+ metrics [ fold ] = evaluator . Evaluate ( scoredTest , labelColumn : LabelColumn , probabilityColumn : "Probability" ) ;
580
+ }
581
+
582
+ return new BinaryCrossValidationMetrics ( models , metrics ) ;
583
+
573
584
}
585
+ }
586
+ }
574
587
575
- return ( models , metrics ) ;
588
+
589
+ public static class MyHelperExtensions
590
+ {
591
+ public static void SaveAsBinary ( this IDataView data , IHostEnvironment env , Stream stream )
592
+ {
593
+ var saver = new BinarySaver ( env , new BinarySaver . Arguments ( ) ) ;
594
+ using ( var ch = env . Start ( "SaveData" ) )
595
+ DataSaverUtils . SaveDataView ( ch , saver , data , stream ) ;
576
596
}
577
597
598
+ public static IDataView FitAndTransform ( this IEstimator < ITransformer > est , IDataView data ) => est . Fit ( data ) . Transform ( data ) ;
599
+
600
+ public static IDataView FitAndRead < TSource > ( this IDataReaderEstimator < TSource , IDataReader < TSource > > est , TSource source )
601
+ => est . Fit ( source ) . Read ( source ) ;
578
602
}
579
603
}
0 commit comments