@@ -28,8 +28,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
28
28
[ Argument ( ArgumentType . Multiple , HelpText = "Trainer to use" , ShortName = "tr" ) ]
29
29
public SubComponent < ITrainer , SignatureTrainer > Trainer = new SubComponent < ITrainer , SignatureTrainer > ( "AveragedPerceptron" ) ;
30
30
31
- [ Argument ( ArgumentType . Multiple , HelpText = "Scorer to use" , NullName = "<Auto>" , SortOrder = 101 ) ]
32
- public SubComponent < IDataScorerTransform , SignatureDataScorer > Scorer ;
31
+ [ Argument ( ArgumentType . Multiple , HelpText = "Scorer to use" , NullName = "<Auto>" , SortOrder = 101 , SignatureType = typeof ( SignatureDataScorer ) ) ]
32
+ public IComponentFactory < IDataView , ISchemaBoundMapper , RoleMappedSchema , IDataScorerTransform > Scorer ;
33
33
34
34
[ Argument ( ArgumentType . Multiple , HelpText = "Evaluator to use" , ShortName = "eval" , NullName = "<Auto>" , SortOrder = 102 ) ]
35
35
public SubComponent < IMamlEvaluator , SignatureMamlEvaluator > Evaluator ;
@@ -76,8 +76,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
76
76
[ Argument ( ArgumentType . AtMostOnce , IsInputFileName = true , HelpText = "The validation data file" , ShortName = "valid" ) ]
77
77
public string ValidationFile ;
78
78
79
- [ Argument ( ArgumentType . Multiple , HelpText = "Output calibrator" , ShortName = "cali" , NullName = "<None>" ) ]
80
- public SubComponent < ICalibratorTrainer , SignatureCalibrator > Calibrator = new SubComponent < ICalibratorTrainer , SignatureCalibrator > ( "PlattCalibration" ) ;
79
+ [ Argument ( ArgumentType . Multiple , HelpText = "Output calibrator" , ShortName = "cali" , NullName = "<None>" , SignatureType = typeof ( SignatureCalibrator ) ) ]
80
+ public IComponentFactory < ICalibratorTrainer > Calibrator = new PlattCalibratorTrainerFactory ( ) ;
81
81
82
82
[ Argument ( ArgumentType . LastOccurenceWins , HelpText = "Number of instances to train the calibrator" , ShortName = "numcali" ) ]
83
83
public int MaxCalibrationExamples = 1000000000 ;
@@ -383,9 +383,9 @@ public FoldResult(Dictionary<string, IDataView> metrics, ISchema scoreSchema, Ro
383
383
private readonly string _splitColumn ;
384
384
private readonly int _numFolds ;
385
385
private readonly SubComponent < ITrainer , SignatureTrainer > _trainer ;
386
- private readonly SubComponent < IDataScorerTransform , SignatureDataScorer > _scorer ;
386
+ private readonly IComponentFactory < IDataView , ISchemaBoundMapper , RoleMappedSchema , IDataScorerTransform > _scorer ;
387
387
private readonly SubComponent < IMamlEvaluator , SignatureMamlEvaluator > _evaluator ;
388
- private readonly SubComponent < ICalibratorTrainer , SignatureCalibrator > _calibrator ;
388
+ private readonly IComponentFactory < ICalibratorTrainer > _calibrator ;
389
389
private readonly int _maxCalibrationExamples ;
390
390
private readonly bool _useThreads ;
391
391
private readonly bool ? _cacheData ;
@@ -423,7 +423,7 @@ public FoldHelper(
423
423
Arguments args ,
424
424
Func < IHostEnvironment , IChannel , IDataView , ITrainer , RoleMappedData > createExamples ,
425
425
Func < IHostEnvironment , IChannel , IDataView , RoleMappedData , IDataView , RoleMappedData > applyTransformsToTestData ,
426
- SubComponent < IDataScorerTransform , SignatureDataScorer > scorer ,
426
+ IComponentFactory < IDataView , ISchemaBoundMapper , RoleMappedSchema , IDataScorerTransform > scorer ,
427
427
SubComponent < IMamlEvaluator , SignatureMamlEvaluator > evaluator ,
428
428
Func < IDataView > getValidationDataView = null ,
429
429
Func < IHostEnvironment , IChannel , IDataView , RoleMappedData , IDataView , RoleMappedData > applyTransformsToValidationData = null ,
@@ -559,11 +559,12 @@ private FoldResult RunFold(int fold)
559
559
560
560
// Score.
561
561
ch . Trace ( "Scoring and evaluating" ) ;
562
- var bindable = ScoreUtils . GetSchemaBindableMapper ( host , predictor , _scorer ) ;
562
+ ch . Assert ( _scorer == null || _scorer is ICommandLineComponentFactory , "CrossValidationCommand should only be used from the command line." ) ;
563
+ var bindable = ScoreUtils . GetSchemaBindableMapper ( host , predictor , scorerFactorySettings : _scorer as ICommandLineComponentFactory ) ;
563
564
ch . AssertValue ( bindable ) ;
564
565
var mapper = bindable . Bind ( host , testData . Schema ) ;
565
- var scorerComp = _scorer . IsGood ( ) ? _scorer : ScoreUtils . GetScorerComponent ( mapper ) ;
566
- IDataScorerTransform scorePipe = scorerComp . CreateInstance ( host , testData . Data , mapper , trainData . Schema ) ;
566
+ var scorerComp = _scorer ?? ScoreUtils . GetScorerComponent ( mapper ) ;
567
+ IDataScorerTransform scorePipe = scorerComp . CreateComponent ( host , testData . Data , mapper , trainData . Schema ) ;
567
568
568
569
// Save per-fold model.
569
570
string modelFileName = ConstructPerFoldName ( _outputModelFile , fold ) ;
0 commit comments