@@ -107,6 +107,7 @@ public sealed class Arguments : ArgumentsBase
107
107
internal const string ShortName = "gam" ;
108
108
109
109
public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
110
+ private protected override bool NeedCalibration => true ;
110
111
111
112
public BinaryClassificationGamTrainer ( IHostEnvironment env , Arguments args )
112
113
: base ( env , args ) { }
@@ -225,6 +226,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
225
226
protected int [ ] FeatureMap ;
226
227
227
228
public override TrainerInfo Info { get ; }
229
+ private protected virtual bool NeedCalibration => false ;
228
230
229
231
private protected GamTrainerBase ( IHostEnvironment env , TArgs args )
230
232
: base ( env , RegisterName )
@@ -240,7 +242,7 @@ private protected GamTrainerBase(IHostEnvironment env, TArgs args)
240
242
Host . CheckParam ( 0 < args . NumIterations , nameof ( args . NumIterations ) , "Must be positive." ) ;
241
243
242
244
Args = args ;
243
- Info = new TrainerInfo ( normalization : false , calibration : this is BinaryClassificationGamTrainer , caching : false ) ;
245
+ Info = new TrainerInfo ( normalization : false , calibration : NeedCalibration , caching : false ) ;
244
246
_gainConfidenceInSquaredStandardDeviations = Math . Pow ( ProbabilityFunctions . Probit ( 1 - ( 1 - Args . GainConfidenceLevel ) * 0.5 ) , 2 ) ;
245
247
_entropyCoefficient = Args . EntropyCoefficient * 1e-6 ;
246
248
int numThreads = args . NumThreads ?? Environment . ProcessorCount ;
0 commit comments