@@ -49,7 +49,8 @@ private protected override void CheckOptions(IChannel ch)
49
49
if ( FastTreeTrainerOptions . EnablePruning && ! HasValidSet )
50
50
throw ch . Except ( "Cannot perform pruning (pruning) without a validation set (valid)." ) ;
51
51
52
- if ( FastTreeTrainerOptions . EarlyStoppingRule != null && ! HasValidSet )
52
+ bool doEarlyStop = FastTreeTrainerOptions . EarlyStoppingRuleFactory != null || FastTreeTrainerOptions . EarlyStoppingRule != null ;
53
+ if ( doEarlyStop && ! HasValidSet )
53
54
throw ch . Except ( "Cannot perform early stopping without a validation set (valid)." ) ;
54
55
55
56
if ( FastTreeTrainerOptions . UseTolerantPruning && ( ! FastTreeTrainerOptions . EnablePruning || ! HasValidSet ) )
@@ -113,9 +114,9 @@ private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
113
114
return new BestStepRegressionGradientWrapper ( ) ;
114
115
}
115
116
116
- private protected override bool ShouldStop ( IChannel ch , ref IEarlyStoppingCriterion earlyStoppingRule , ref int bestIteration )
117
+ private protected override bool ShouldStop ( IChannel ch , ref EarlyStoppingRuleBase earlyStoppingRule , ref int bestIteration )
117
118
{
118
- if ( FastTreeTrainerOptions . EarlyStoppingRule == null )
119
+ if ( FastTreeTrainerOptions . EarlyStoppingRuleFactory == null && FastTreeTrainerOptions . EarlyStoppingRule == null )
119
120
return false ;
120
121
121
122
ch . AssertValue ( ValidTest ) ;
@@ -128,13 +129,20 @@ private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriter
128
129
var trainingResult = TrainTest . ComputeTests ( ) . First ( ) ;
129
130
ch . Assert ( trainingResult . FinalValue >= 0 ) ;
130
131
131
- // Create early stopping rule.
132
+ // Create early stopping rule if it's null .
132
133
if ( earlyStoppingRule == null )
133
134
{
134
- earlyStoppingRule = FastTreeTrainerOptions . EarlyStoppingRule . CreateComponent ( Host , lowerIsBetter ) ;
135
- ch . Assert ( earlyStoppingRule != null ) ;
135
+ // There are two possible sources of stopping rules. One is the classical IComponentFactory and
136
+ // the other one is the rule passed in directly by user.
137
+ if ( FastTreeTrainerOptions . EarlyStoppingRuleFactory != null )
138
+ earlyStoppingRule = FastTreeTrainerOptions . EarlyStoppingRuleFactory . CreateComponent ( Host , lowerIsBetter ) ;
139
+ else if ( FastTreeTrainerOptions . EarlyStoppingRule != null )
140
+ earlyStoppingRule = FastTreeTrainerOptions . EarlyStoppingRule ;
136
141
}
137
142
143
+ // Early stopping rule cannot be null!
144
+ ch . Assert ( earlyStoppingRule != null ) ;
145
+
138
146
bool isBestCandidate ;
139
147
bool shouldStop = earlyStoppingRule . CheckScore ( ( float ) validationResult . FinalValue ,
140
148
( float ) trainingResult . FinalValue , out isBestCandidate ) ;
0 commit comments