Skip to content

Commit 3c61add

Browse files
committed
Make some changes.
1 parent 768e3bf commit 3c61add

File tree

2 files changed

+92
-51
lines changed

2 files changed

+92
-51
lines changed

src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs

+91-50
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,10 @@ public interface IEarlyStoppingCriterionFactory : IComponentFactory<bool, IEarly
4343
new IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter);
4444
}
4545

46-
public abstract class EarlyStoppingCriterion<TOptions> : IEarlyStoppingCriterion
47-
where TOptions : EarlyStoppingCriterion<TOptions>.OptionsBase
46+
public abstract class EarlyStoppingCriterion : IEarlyStoppingCriterion
4847
{
49-
public abstract class OptionsBase { }
50-
5148
private float _bestScore;
5249

53-
protected readonly TOptions EarlyStoppingCriterionOptions;
5450
protected readonly bool LowerIsBetter;
5551
protected float BestScore {
5652
get { return _bestScore; }
@@ -61,9 +57,8 @@ protected float BestScore {
6157
}
6258
}
6359

64-
internal EarlyStoppingCriterion(TOptions options, bool lowerIsBetter)
60+
internal EarlyStoppingCriterion(bool lowerIsBetter)
6561
{
66-
EarlyStoppingCriterionOptions = options;
6762
LowerIsBetter = lowerIsBetter;
6863
_bestScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity;
6964
}
@@ -83,25 +78,34 @@ protected bool CheckBestScore(float score)
8378
}
8479
}
8580

86-
public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion<TolerantEarlyStoppingCriterion.Options>
81+
public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion
8782
{
8883
[TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")]
89-
public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
84+
public sealed class Options : IEarlyStoppingCriterionFactory
9085
{
9186
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")]
9287
[TlcModule.Range(Min = 0.0f)]
9388
public float Threshold = 0.01f;
9489

9590
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
9691
{
97-
return new TolerantEarlyStoppingCriterion(this, lowerIsBetter);
92+
return new TolerantEarlyStoppingCriterion(Threshold, lowerIsBetter);
9893
}
9994
}
10095

101-
public TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter)
102-
: base(options, lowerIsBetter)
96+
public float Threshold { get; }
97+
98+
public TolerantEarlyStoppingCriterion(float threshold, bool lowerIsBetter = true)
99+
: base(lowerIsBetter)
100+
{
101+
Contracts.CheckUserArg(threshold >= 0, nameof(threshold), "Must be non-negative.");
102+
Threshold = threshold;
103+
}
104+
105+
[BestFriend]
106+
internal TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter = true)
107+
: this(options.Threshold, lowerIsBetter)
103108
{
104-
Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(options.Threshold), "Must be non-negative.");
105109
}
106110

107111
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
@@ -111,19 +115,19 @@ public override bool CheckScore(float validationScore, float trainingScore, out
111115
isBestCandidate = CheckBestScore(validationScore);
112116

113117
if (LowerIsBetter)
114-
return (validationScore - BestScore > EarlyStoppingCriterionOptions.Threshold);
118+
return (validationScore - BestScore > Threshold);
115119
else
116-
return (BestScore - validationScore > EarlyStoppingCriterionOptions.Threshold);
120+
return (BestScore - validationScore > Threshold);
117121
}
118122
}
119123

120124
// For the detail of the following rules, see the following paper.
121125
// Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons."
122126
// Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009.
123127

124-
public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion<MovingWindowEarlyStoppingCriterion.Options>
128+
public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion
125129
{
126-
public class Options : OptionsBase
130+
public class Options
127131
{
128132
[Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")]
129133
[TlcModule.Range(Min = 0.0f, Max = 1.0f)]
@@ -134,15 +138,20 @@ public class Options : OptionsBase
134138
public int WindowSize = 5;
135139
}
136140

141+
public float Threshold { get; }
142+
public int WindowSize { get; }
143+
137144
protected Queue<float> PastScores;
138145

139-
private protected MovingWindowEarlyStoppingCriterion(Options args, bool lowerIsBetter)
140-
: base(args, lowerIsBetter)
146+
private protected MovingWindowEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5)
147+
: base(lowerIsBetter)
141148
{
142-
Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
143-
Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(args.WindowSize), "Must be positive.");
149+
Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1].");
150+
Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive.");
144151

145-
PastScores = new Queue<float>(EarlyStoppingCriterionOptions.WindowSize);
152+
Threshold = threshold;
153+
WindowSize = windowSize;
154+
PastScores = new Queue<float>(windowSize);
146155
}
147156

148157
/// <summary>
@@ -200,26 +209,35 @@ protected bool CheckRecentScores(float score, int windowSize, out float recentBe
200209
/// <summary>
201210
/// Loss of Generality (GL).
202211
/// </summary>
203-
public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion<GLEarlyStoppingCriterion.Options>
212+
public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion
204213
{
205214
[TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL",
206215
Desc = "Stop in case of loss of generality.")]
207-
public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
216+
public sealed class Options : IEarlyStoppingCriterionFactory
208217
{
209218
[Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")]
210219
[TlcModule.Range(Min = 0.0f, Max = 1.0f)]
211220
public float Threshold = 0.01f;
212221

213222
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
214223
{
215-
return new GLEarlyStoppingCriterion(this, lowerIsBetter);
224+
return new GLEarlyStoppingCriterion(lowerIsBetter, Threshold);
216225
}
217226
}
218227

219-
public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter)
220-
: base(options, lowerIsBetter)
228+
public float Threshold { get; }
229+
230+
public GLEarlyStoppingCriterion(bool lowerIsBetter = true, float threshold = 0.01f) :
231+
base(lowerIsBetter)
232+
{
233+
Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1].");
234+
Threshold = threshold;
235+
}
236+
237+
[BestFriend]
238+
internal GLEarlyStoppingCriterion(Options options, bool lowerIsBetter = true)
239+
: this(lowerIsBetter, options.Threshold)
221240
{
222-
Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1].");
223241
}
224242

225243
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
@@ -229,9 +247,9 @@ public override bool CheckScore(float validationScore, float trainingScore, out
229247
isBestCandidate = CheckBestScore(validationScore);
230248

231249
if (LowerIsBetter)
232-
return (validationScore > (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore);
250+
return (validationScore > (1 + Threshold) * BestScore);
233251
else
234-
return (validationScore < (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore);
252+
return (validationScore < (1 - Threshold) * BestScore);
235253
}
236254
}
237255

@@ -246,12 +264,20 @@ public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterio
246264
{
247265
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
248266
{
249-
return new LPEarlyStoppingCriterion(this, lowerIsBetter);
267+
return new LPEarlyStoppingCriterion(lowerIsBetter, Threshold, WindowSize);
250268
}
251269
}
252270

253-
public LPEarlyStoppingCriterion(Options options, bool lowerIsBetter)
254-
: base(options, lowerIsBetter) { }
271+
public LPEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5)
272+
: base(lowerIsBetter, threshold, windowSize)
273+
{
274+
}
275+
276+
[BestFriend]
277+
internal LPEarlyStoppingCriterion(Options options, bool lowerIsBetter = true)
278+
: this(lowerIsBetter, options.Threshold, options.WindowSize)
279+
{
280+
}
255281

256282
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
257283
{
@@ -262,12 +288,12 @@ public override bool CheckScore(float validationScore, float trainingScore, out
262288

263289
float recentBest;
264290
float recentAverage;
265-
if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage))
291+
if (CheckRecentScores(trainingScore, WindowSize, out recentBest, out recentAverage))
266292
{
267293
if (LowerIsBetter)
268-
return (recentAverage <= (1 + EarlyStoppingCriterionOptions.Threshold) * recentBest);
294+
return (recentAverage <= (1 + Threshold) * recentBest);
269295
else
270-
return (recentAverage >= (1 - EarlyStoppingCriterionOptions.Threshold) * recentBest);
296+
return (recentAverage >= (1 - Threshold) * recentBest);
271297
}
272298

273299
return false;
@@ -284,12 +310,20 @@ public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterio
284310
{
285311
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
286312
{
287-
return new PQEarlyStoppingCriterion(this, lowerIsBetter);
313+
return new PQEarlyStoppingCriterion(lowerIsBetter, Threshold, WindowSize);
288314
}
289315
}
290316

291-
public PQEarlyStoppingCriterion(Options options, bool lowerIsBetter)
292-
: base(options, lowerIsBetter) { }
317+
public PQEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5)
318+
: base(lowerIsBetter, threshold, windowSize)
319+
{
320+
}
321+
322+
[BestFriend]
323+
internal PQEarlyStoppingCriterion(Options options, bool lowerIsBetter = true)
324+
: this(lowerIsBetter, options.Threshold, options.WindowSize)
325+
{
326+
}
293327

294328
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
295329
{
@@ -300,12 +334,12 @@ public override bool CheckScore(float validationScore, float trainingScore, out
300334

301335
float recentBest;
302336
float recentAverage;
303-
if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage))
337+
if (CheckRecentScores(trainingScore, WindowSize, out recentBest, out recentAverage))
304338
{
305339
if (LowerIsBetter)
306-
return (validationScore * recentBest >= (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage);
340+
return (validationScore * recentBest >= (1 + Threshold) * BestScore * recentAverage);
307341
else
308-
return (validationScore * recentBest <= (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage);
342+
return (validationScore * recentBest <= (1 - Threshold) * BestScore * recentAverage);
309343
}
310344

311345
return false;
@@ -315,33 +349,40 @@ public override bool CheckScore(float validationScore, float trainingScore, out
315349
/// <summary>
316350
/// Consecutive Loss in Generality (UP).
317351
/// </summary>
318-
public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion<UPEarlyStoppingCriterion.Options>
352+
public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion
319353
{
320354
[TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP",
321355
Desc = "Stops in case of consecutive loss in generality.")]
322-
public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
356+
public sealed class Options : IEarlyStoppingCriterionFactory
323357
{
324358
[Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")]
325359
[TlcModule.Range(Inf = 0)]
326360
public int WindowSize = 5;
327361

328362
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
329363
{
330-
return new UPEarlyStoppingCriterion(this, lowerIsBetter);
364+
return new UPEarlyStoppingCriterion(lowerIsBetter, WindowSize);
331365
}
332366
}
333367

368+
public int WindowSize { get; }
334369
private int _count;
335370
private float _prevScore;
336371

337-
public UPEarlyStoppingCriterion(Options options, bool lowerIsBetter)
338-
: base(options, lowerIsBetter)
372+
public UPEarlyStoppingCriterion(bool lowerIsBetter, int windowSize = 5)
373+
: base(lowerIsBetter)
339374
{
340-
Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(options.WindowSize), "Must be positive");
341-
375+
Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive");
376+
WindowSize = windowSize;
342377
_prevScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity;
343378
}
344379

380+
[BestFriend]
381+
internal UPEarlyStoppingCriterion(Options options, bool lowerIsBetter = true)
382+
: this(lowerIsBetter, options.WindowSize)
383+
{
384+
}
385+
345386
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
346387
{
347388
Contracts.Assert(validationScore >= 0);
@@ -351,7 +392,7 @@ public override bool CheckScore(float validationScore, float trainingScore, out
351392
_count = ((validationScore < _prevScore) != LowerIsBetter) ? _count + 1 : 0;
352393
_prevScore = validationScore;
353394

354-
return (_count >= EarlyStoppingCriterionOptions.WindowSize);
395+
return (_count >= WindowSize);
355396
}
356397
}
357398
}

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -4974,7 +4974,7 @@ public void TestCrossValidationMacroWithNonDefaultNames()
49744974
'NumPostBracketSteps': 0,
49754975
'MinStepSize': 0.0,
49764976
'OptimizationAlgorithm': 'GradientDescent',
4977-
'EarlyStoppingRule': null,
4977+
'EarlyStoppingRuleFactory': null,
49784978
'EarlyStoppingMetrics': 1,
49794979
'EnablePruning': false,
49804980
'UseTolerantPruning': false,

0 commit comments

Comments
 (0)