Skip to content

Commit c9fbaf8

Browse files
committed
No more superclass referencing subclass
1 parent 0e5988a commit c9fbaf8

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ public abstract class FastTreeTrainerBase<TArgs, TPredictor> :
8585

8686
public bool HasCategoricalFeatures => Utils.Size(CategoricalFeatures) > 0;
8787

88+
private protected virtual bool NeedCalibration => false;
89+
8890
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
8991
: base(env, RegisterName)
9092
{
@@ -93,7 +95,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
9395
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
9496
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
9597
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
96-
Info = new TrainerInfo(normalization: false, caching: false, calibration: this is FastForestClassification);
98+
Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration);
9799
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
98100
if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
99101
{

src/Microsoft.ML.FastTree/GamTrainer.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ public sealed class Arguments : ArgumentsBase
107107
internal const string ShortName = "gam";
108108

109109
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
110+
private protected override bool NeedCalibration => true;
110111

111112
public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args)
112113
: base(env, args) { }
@@ -225,6 +226,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
225226
protected int[] FeatureMap;
226227

227228
public override TrainerInfo Info { get; }
229+
private protected virtual bool NeedCalibration => false;
228230

229231
private protected GamTrainerBase(IHostEnvironment env, TArgs args)
230232
: base(env, RegisterName)
@@ -240,7 +242,7 @@ private protected GamTrainerBase(IHostEnvironment env, TArgs args)
240242
Host.CheckParam(0 < args.NumIterations, nameof(args.NumIterations), "Must be positive.");
241243

242244
Args = args;
243-
Info = new TrainerInfo(normalization: false, calibration: this is BinaryClassificationGamTrainer, caching: false);
245+
Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false);
244246
_gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2);
245247
_entropyCoefficient = Args.EntropyCoefficient * 1e-6;
246248
int numThreads = args.NumThreads ?? Environment.ProcessorCount;

src/Microsoft.ML.FastTree/RandomForestClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,11 @@ public sealed class Arguments : FastForestArgumentsBase
130130
private bool[] _trainSetLabels;
131131

132132
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
133+
private protected override bool NeedCalibration => true;
133134

134135
public FastForestClassification(IHostEnvironment env, Arguments args)
135136
: base(env, args)
136137
{
137-
138138
}
139139

140140
public override IPredictorWithFeatureWeights<Float> Train(TrainContext context)

0 commit comments

Comments
 (0)