@@ -163,7 +163,7 @@ private void RunCore(IChannel ch, string cmd)
163
163
RoleMappedData validData = null ;
164
164
if ( ! string . IsNullOrWhiteSpace ( Args . ValidationFile ) )
165
165
{
166
- if ( ! TrainUtils . CanUseValidationData ( trainer ) )
166
+ if ( ! trainer . Info . SupportsValidation )
167
167
{
168
168
ch . Warning ( "Ignoring validationFile: Trainer does not accept validation dataset." ) ;
169
169
}
@@ -242,39 +242,32 @@ public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData
242
242
}
243
243
244
244
private static IPredictor TrainCore ( IHostEnvironment env , IChannel ch , RoleMappedData data , ITrainer trainer , string name , RoleMappedData validData ,
245
- ICalibratorTrainer calibrator , int maxCalibrationExamples , bool ? cacheData , IPredictor inpPredictor = null )
245
+ ICalibratorTrainer calibrator , int maxCalibrationExamples , bool ? cacheData , IPredictor inputPredictor = null )
246
246
{
247
247
Contracts . CheckValue ( env , nameof ( env ) ) ;
248
248
env . CheckValue ( ch , nameof ( ch ) ) ;
249
249
ch . CheckValue ( data , nameof ( data ) ) ;
250
250
ch . CheckValue ( trainer , nameof ( trainer ) ) ;
251
251
ch . CheckNonEmpty ( name , nameof ( name ) ) ;
252
252
ch . CheckValueOrNull ( validData ) ;
253
- ch . CheckValueOrNull ( inpPredictor ) ;
253
+ ch . CheckValueOrNull ( inputPredictor ) ;
254
254
255
255
AddCacheIfWanted ( env , ch , trainer , ref data , cacheData ) ;
256
256
ch . Trace ( "Training" ) ;
257
257
if ( validData != null )
258
258
AddCacheIfWanted ( env , ch , trainer , ref validData , cacheData ) ;
259
259
260
- var trainerEx = trainer as ITrainerEx ;
261
- if ( inpPredictor != null && trainerEx ? . SupportsIncrementalTraining != true )
260
+ if ( inputPredictor != null && ! trainer . Info . SupportsIncrementalTraining )
262
261
{
263
262
ch . Warning ( "Ignoring " + nameof ( TrainCommand . Arguments . InputModelFile ) +
264
263
": Trainer does not support incremental training." ) ;
265
- inpPredictor = null ;
264
+ inputPredictor = null ;
266
265
}
267
- ch . Assert ( validData == null || CanUseValidationData ( trainer ) ) ;
268
- var predictor = trainer . Train ( new TrainContext ( data , validData , inpPredictor ) ) ;
266
+ ch . Assert ( validData == null || trainer . Info . SupportsValidation ) ;
267
+ var predictor = trainer . Train ( new TrainContext ( data , validData , inputPredictor ) ) ;
269
268
return CalibratorUtils . TrainCalibratorIfNeeded ( env , ch , calibrator , maxCalibrationExamples , trainer , predictor , data ) ;
270
269
}
271
270
272
- public static bool CanUseValidationData ( ITrainer trainer )
273
- {
274
- Contracts . CheckValue ( trainer , nameof ( trainer ) ) ;
275
- return ( trainer as ITrainerEx ) ? . SupportsValidation ?? false ;
276
- }
277
-
278
271
public static bool TryLoadPredictor ( IChannel ch , IHostEnvironment env , string inputModelFile , out IPredictor inputPredictor )
279
272
{
280
273
Contracts . AssertValue ( env ) ;
@@ -388,9 +381,8 @@ public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositor
388
381
IDataView pipeStart ;
389
382
var xfs = BacktrackPipe ( dataPipe , out pipeStart ) ;
390
383
391
- IDataLoader loader ;
392
384
Action < ModelSaveContext > saveAction ;
393
- if ( ! blankLoader && ( loader = pipeStart as IDataLoader ) != null )
385
+ if ( ! blankLoader && pipeStart is IDataLoader loader )
394
386
saveAction = loader . Save ;
395
387
else
396
388
{
@@ -460,7 +452,7 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
460
452
if ( autoNorm != NormalizeOption . Yes )
461
453
{
462
454
DvBool isNormalized = DvBool . False ;
463
- if ( trainer . NeedNormalization ( ) != true || schema . IsNormalized ( featCol ) )
455
+ if ( ! trainer . Info . NeedNormalization || schema . IsNormalized ( featCol ) )
464
456
{
465
457
ch . Info ( "Not adding a normalizer." ) ;
466
458
return false ;
@@ -491,8 +483,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer
491
483
ch . AssertValue ( trainer , nameof ( trainer ) ) ;
492
484
ch . AssertValue ( data , nameof ( data ) ) ;
493
485
494
- ITrainerEx trainerEx = trainer as ITrainerEx ;
495
- bool shouldCache = cacheData ?? ( ! ( data . Data is BinaryLoader ) && ( trainerEx == null || trainerEx . WantCaching ) ) ;
486
+ bool shouldCache = cacheData ?? ! ( data . Data is BinaryLoader ) && trainer . Info . WantCaching ;
496
487
497
488
if ( shouldCache )
498
489
{
0 commit comments