Skip to content

Commit 5a6c9c6

Browse files
committed
First version of new early stopping rule.
Generate missing entry points
1 parent 3c61add commit 5a6c9c6

File tree

10 files changed

+190
-87
lines changed

10 files changed

+190
-87
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+14-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ private protected override void CheckOptions(IChannel ch)
4949
if (FastTreeTrainerOptions.EnablePruning && !HasValidSet)
5050
throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid).");
5151

52-
if (FastTreeTrainerOptions.EarlyStoppingRule != null && !HasValidSet)
52+
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || FastTreeTrainerOptions.EarlyStoppingRule != null;
53+
if (doEarlyStop && !HasValidSet)
5354
throw ch.Except("Cannot perform early stopping without a validation set (valid).");
5455

5556
if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet))
@@ -113,9 +114,9 @@ private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
113114
return new BestStepRegressionGradientWrapper();
114115
}
115116

116-
private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration)
117+
private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration)
117118
{
118-
if (FastTreeTrainerOptions.EarlyStoppingRule == null)
119+
if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null && FastTreeTrainerOptions.EarlyStoppingRule == null)
119120
return false;
120121

121122
ch.AssertValue(ValidTest);
@@ -128,13 +129,20 @@ private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriter
128129
var trainingResult = TrainTest.ComputeTests().First();
129130
ch.Assert(trainingResult.FinalValue >= 0);
130131

131-
// Create early stopping rule.
132+
// Create early stopping rule if it's null.
132133
if (earlyStoppingRule == null)
133134
{
134-
earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter);
135-
ch.Assert(earlyStoppingRule != null);
135+
// There are two possible sources of stopping rules. One is the classical IComponentFactory and
136+
// the other one is the rule passed in directly by the user.
137+
if (FastTreeTrainerOptions.EarlyStoppingRuleFactory != null)
138+
earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter);
139+
else if (FastTreeTrainerOptions.EarlyStoppingRule != null)
140+
earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule;
136141
}
137142

143+
// Early stopping rule cannot be null!
144+
ch.Assert(earlyStoppingRule != null);
145+
138146
bool isBestCandidate;
139147
bool shouldStop = earlyStoppingRule.CheckScore((float)validationResult.FinalValue,
140148
(float)trainingResult.FinalValue, out isBestCandidate);

src/Microsoft.ML.FastTree/FastTree.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ private protected void TrainCore(IChannel ch)
245245
}
246246
}
247247

248-
private protected virtual bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStopping, ref int bestIteration)
248+
private protected virtual bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStopping, ref int bestIteration)
249249
{
250250
bestIteration = Ensemble.NumTrees;
251251
return false;
@@ -650,7 +650,7 @@ private protected virtual void Train(IChannel ch)
650650
#endif
651651
#endif
652652

653-
IEarlyStoppingCriterion earlyStoppingRule = null;
653+
EarlyStoppingRuleBase earlyStoppingRule = null;
654654
int bestIteration = 0;
655655
int emptyTrees = 0;
656656
using (var pch = Host.StartProgressChannel("FastTree training"))

src/Microsoft.ML.FastTree/FastTreeArguments.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -482,9 +482,12 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc
482482
/// <summary>
483483
/// Early stopping rule. (Validation set (/valid) is required).
484484
/// </summary>
485+
[BestFriend]
485486
[Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", ShortName = "esr", NullName = "<Disable>")]
486487
[TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")]
487-
public IEarlyStoppingCriterionFactory EarlyStoppingRule;
488+
internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory;
489+
490+
public EarlyStoppingRuleBase EarlyStoppingRule;
488491

489492
/// <summary>
490493
/// Early stopping metrics. (For regression, 1: L1, 2: L2; for ranking, 1: NDCG@1, 3: NDCG@3).

src/Microsoft.ML.FastTree/FastTreeRanking.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,13 @@ private protected override void CheckOptions(IChannel ch)
166166
Dataset.DatasetSkeleton.LabelGainMap = gain;
167167
}
168168

169-
ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
170-
"earlyStoppingMetrics should be 1 or 3.");
169+
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
170+
FastTreeTrainerOptions.EarlyStoppingRule != null ||
171+
FastTreeTrainerOptions.EnablePruning;
172+
173+
if (doEarlyStop)
174+
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3,
175+
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 or 3.");
171176

172177
base.CheckOptions(ch);
173178
}

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,13 @@ private protected override void CheckOptions(IChannel ch)
106106

107107
base.CheckOptions(ch);
108108

109-
ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
110-
"earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
109+
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
110+
FastTreeTrainerOptions.EarlyStoppingRule != null ||
111+
FastTreeTrainerOptions.EnablePruning;
112+
113+
if (doEarlyStop)
114+
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2,
115+
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
111116
}
112117

113118
private static SchemaShape.Column MakeLabelColumn(string labelColumn)

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+8-3
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,17 @@ private protected override void CheckOptions(IChannel ch)
113113
// REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just
114114
// a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now
115115
// we just leave the existing regression checks, though with a warning.
116-
117116
if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0)
118117
ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution.");
119118

120-
ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
121-
"earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
119+
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
120+
FastTreeTrainerOptions.EarlyStoppingRule != null ||
121+
FastTreeTrainerOptions.EnablePruning;
122+
123+
// Please do not remove this check! See the REVIEW comment above.
124+
if (doEarlyStop)
125+
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 2,
126+
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 (L1-norm) or 2 (L2-norm).");
122127
}
123128

124129
private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)

0 commit comments

Comments
 (0)