Skip to content

Commit 887aad2

Browse files
authored
Polish early stop rules in fast tree (#2851)
1 parent c21a6d6 commit 887aad2

File tree

9 files changed

+308
-104
lines changed

9 files changed

+308
-104
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+10-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ private protected override void CheckOptions(IChannel ch)
4949
if (FastTreeTrainerOptions.EnablePruning && !HasValidSet)
5050
throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid).");
5151

52-
if (FastTreeTrainerOptions.EarlyStoppingRule != null && !HasValidSet)
52+
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null;
53+
if (doEarlyStop && !HasValidSet)
5354
throw ch.Except("Cannot perform early stopping without a validation set (valid).");
5455

5556
if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet))
@@ -113,9 +114,9 @@ private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
113114
return new BestStepRegressionGradientWrapper();
114115
}
115116

116-
private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration)
117+
private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration)
117118
{
118-
if (FastTreeTrainerOptions.EarlyStoppingRule == null)
119+
if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null)
119120
return false;
120121

121122
ch.AssertValue(ValidTest);
@@ -128,13 +129,16 @@ private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriter
128129
var trainingResult = TrainTest.ComputeTests().First();
129130
ch.Assert(trainingResult.FinalValue >= 0);
130131

131-
// Create early stopping rule.
132+
// Create early stopping rule if it's null.
132133
if (earlyStoppingRule == null)
133134
{
134-
earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter);
135-
ch.Assert(earlyStoppingRule != null);
135+
if (FastTreeTrainerOptions.EarlyStoppingRuleFactory != null)
136+
earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter);
136137
}
137138

139+
// Early stopping rule cannot be null!
140+
ch.Assert(earlyStoppingRule != null);
141+
138142
bool isBestCandidate;
139143
bool shouldStop = earlyStoppingRule.CheckScore((float)validationResult.FinalValue,
140144
(float)trainingResult.FinalValue, out isBestCandidate);

src/Microsoft.ML.FastTree/FastTree.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ private protected void TrainCore(IChannel ch)
245245
}
246246
}
247247

248-
private protected virtual bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStopping, ref int bestIteration)
248+
private protected virtual bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStopping, ref int bestIteration)
249249
{
250250
bestIteration = Ensemble.NumTrees;
251251
return false;
@@ -650,7 +650,7 @@ private protected virtual void Train(IChannel ch)
650650
#endif
651651
#endif
652652

653-
IEarlyStoppingCriterion earlyStoppingRule = null;
653+
EarlyStoppingRuleBase earlyStoppingRule = null;
654654
int bestIteration = 0;
655655
int emptyTrees = 0;
656656
using (var pch = Host.StartProgressChannel("FastTree training"))

src/Microsoft.ML.FastTree/FastTreeArguments.cs

+22-2
Original file line numberDiff line numberDiff line change
@@ -621,9 +621,29 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc
621621
/// <summary>
622622
/// Early stopping rule. (Validation set (/valid) is required).
623623
/// </summary>
624-
[Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", ShortName = "esr", NullName = "<Disable>")]
624+
[BestFriend]
625+
[Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", Name = "EarlyStoppingRule", ShortName = "esr", NullName = "<Disable>")]
625626
[TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")]
626-
public IEarlyStoppingCriterionFactory EarlyStoppingRule;
627+
internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory;
628+
629+
/// <summary>
630+
/// The underlying state of <see cref="EarlyStoppingRuleFactory"/> and <see cref="EarlyStoppingRule"/>.
631+
/// </summary>
632+
private EarlyStoppingRuleBase _earlyStoppingRuleBase;
633+
634+
/// <summary>
/// Early stopping rule used to terminate training process once meeting a specified criterion. Possible choices are
/// <see cref="EarlyStoppingRuleBase"/>'s implementations such as <see cref="TolerantEarlyStoppingRule"/> and <see cref="GeneralityLossRule"/>.
/// Assigning <see langword="null"/> clears <see cref="EarlyStoppingRuleFactory"/> as well, which the trainer
/// treats as early stopping being disabled.
/// </summary>
public EarlyStoppingRuleBase EarlyStoppingRule
{
    get { return _earlyStoppingRuleBase; }
    set
    {
        _earlyStoppingRuleBase = value;
        // Keep the factory in sync with the rule. A null rule must map to a null factory
        // (the trainer checks the factory against null to decide whether to early-stop)
        // instead of dereferencing null and throwing a NullReferenceException.
        EarlyStoppingRuleFactory = value?.BuildFactory();
    }
}
627647

628648
/// <summary>
629649
/// Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3).

src/Microsoft.ML.FastTree/FastTreeRanking.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,12 @@ private protected override void CheckOptions(IChannel ch)
156156
Dataset.DatasetSkeleton.LabelGainMap = gains;
157157
}
158158

159-
ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
160-
"earlyStoppingMetrics should be 1 or 3.");
159+
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
160+
FastTreeTrainerOptions.EnablePruning;
161+
162+
if (doEarlyStop)
163+
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3,
164+
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 or 3.");
161165

162166
base.CheckOptions(ch);
163167
}

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,12 @@ private protected override void CheckOptions(IChannel ch)
105105

106106
base.CheckOptions(ch);
107107

108-
ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
109-
"earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
108+
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
109+
FastTreeTrainerOptions.EnablePruning;
110+
111+
if (doEarlyStop)
112+
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2,
113+
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
110114
}
111115

112116
private static SchemaShape.Column MakeLabelColumn(string labelColumn)

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+7-3
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,16 @@ private protected override void CheckOptions(IChannel ch)
112112
// REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just
113113
// a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now
114114
// we just leave the existing regression checks, though with a warning.
115-
116115
if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0)
117116
ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution.");
118117

119-
ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
120-
"earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
118+
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
119+
FastTreeTrainerOptions.EnablePruning;
120+
121+
// Please do not remove it! See comment above.
122+
if (doEarlyStop)
123+
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 2,
124+
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 (L1-norm) or 2 (L2-norm).");
121125
}
122126

123127
private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)

0 commit comments

Comments
 (0)