@@ -43,14 +43,10 @@ public interface IEarlyStoppingCriterionFactory : IComponentFactory<bool, IEarly
43
43
new IEarlyStoppingCriterion CreateComponent ( IHostEnvironment env , bool lowerIsBetter ) ;
44
44
}
45
45
46
- public abstract class EarlyStoppingCriterion < TOptions > : IEarlyStoppingCriterion
47
- where TOptions : EarlyStoppingCriterion < TOptions > . OptionsBase
46
+ public abstract class EarlyStoppingCriterion : IEarlyStoppingCriterion
48
47
{
49
- public abstract class OptionsBase { }
50
-
51
48
private float _bestScore ;
52
49
53
- protected readonly TOptions EarlyStoppingCriterionOptions ;
54
50
protected readonly bool LowerIsBetter ;
55
51
protected float BestScore {
56
52
get { return _bestScore ; }
@@ -61,9 +57,8 @@ protected float BestScore {
61
57
}
62
58
}
63
59
64
- internal EarlyStoppingCriterion ( TOptions options , bool lowerIsBetter )
60
+ internal EarlyStoppingCriterion ( bool lowerIsBetter )
65
61
{
66
- EarlyStoppingCriterionOptions = options ;
67
62
LowerIsBetter = lowerIsBetter ;
68
63
_bestScore = LowerIsBetter ? float . PositiveInfinity : float . NegativeInfinity ;
69
64
}
@@ -83,25 +78,34 @@ protected bool CheckBestScore(float score)
83
78
}
84
79
}
85
80
86
- public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion < TolerantEarlyStoppingCriterion . Options >
81
+ public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion
87
82
{
88
83
[ TlcModule . Component ( FriendlyName = "Tolerant (TR)" , Name = "TR" , Desc = "Stop if validation score exceeds threshold value." ) ]
89
- public sealed class Options : OptionsBase , IEarlyStoppingCriterionFactory
84
+ public sealed class Options : IEarlyStoppingCriterionFactory
90
85
{
91
86
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Tolerance threshold. (Non negative value)" , ShortName = "th" ) ]
92
87
[ TlcModule . Range ( Min = 0.0f ) ]
93
88
public float Threshold = 0.01f ;
94
89
95
90
public IEarlyStoppingCriterion CreateComponent ( IHostEnvironment env , bool lowerIsBetter )
96
91
{
97
- return new TolerantEarlyStoppingCriterion ( this , lowerIsBetter ) ;
92
+ return new TolerantEarlyStoppingCriterion ( Threshold , lowerIsBetter ) ;
98
93
}
99
94
}
100
95
101
- public TolerantEarlyStoppingCriterion ( Options options , bool lowerIsBetter )
102
- : base ( options , lowerIsBetter )
96
+ public float Threshold { get ; }
97
+
98
+ public TolerantEarlyStoppingCriterion ( float threshold , bool lowerIsBetter = true )
99
+ : base ( lowerIsBetter )
100
+ {
101
+ Contracts . CheckUserArg ( threshold >= 0 , nameof ( threshold ) , "Must be non-negative." ) ;
102
+ Threshold = threshold ;
103
+ }
104
+
105
+ [ BestFriend ]
106
+ internal TolerantEarlyStoppingCriterion ( Options options , bool lowerIsBetter = true )
107
+ : this ( options . Threshold , lowerIsBetter )
103
108
{
104
- Contracts . CheckUserArg ( EarlyStoppingCriterionOptions . Threshold >= 0 , nameof ( options . Threshold ) , "Must be non-negative." ) ;
105
109
}
106
110
107
111
public override bool CheckScore ( float validationScore , float trainingScore , out bool isBestCandidate )
@@ -111,19 +115,19 @@ public override bool CheckScore(float validationScore, float trainingScore, out
111
115
isBestCandidate = CheckBestScore ( validationScore ) ;
112
116
113
117
if ( LowerIsBetter )
114
- return ( validationScore - BestScore > EarlyStoppingCriterionOptions . Threshold ) ;
118
+ return ( validationScore - BestScore > Threshold ) ;
115
119
else
116
- return ( BestScore - validationScore > EarlyStoppingCriterionOptions . Threshold ) ;
120
+ return ( BestScore - validationScore > Threshold ) ;
117
121
}
118
122
}
119
123
120
124
// For the detail of the following rules, see the following paper.
121
125
// Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons."
122
126
// Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009.
123
127
124
- public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion < MovingWindowEarlyStoppingCriterion . Options >
128
+ public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion
125
129
{
126
- public class Options : OptionsBase
130
+ public class Options
127
131
{
128
132
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Threshold in range [0,1]." , ShortName = "th" ) ]
129
133
[ TlcModule . Range ( Min = 0.0f , Max = 1.0f ) ]
@@ -134,15 +138,20 @@ public class Options : OptionsBase
134
138
public int WindowSize = 5 ;
135
139
}
136
140
141
+ public float Threshold { get ; }
142
+ public int WindowSize { get ; }
143
+
137
144
protected Queue < float > PastScores ;
138
145
139
- private protected MovingWindowEarlyStoppingCriterion ( Options args , bool lowerIsBetter )
140
- : base ( args , lowerIsBetter )
146
+ private protected MovingWindowEarlyStoppingCriterion ( bool lowerIsBetter , float threshold = 0.01f , int windowSize = 5 )
147
+ : base ( lowerIsBetter )
141
148
{
142
- Contracts . CheckUserArg ( 0 <= EarlyStoppingCriterionOptions . Threshold && args . Threshold <= 1 , nameof ( args . Threshold ) , "Must be in range [0,1]." ) ;
143
- Contracts . CheckUserArg ( EarlyStoppingCriterionOptions . WindowSize > 0 , nameof ( args . WindowSize ) , "Must be positive." ) ;
149
+ Contracts . CheckUserArg ( 0 <= threshold && threshold <= 1 , nameof ( threshold ) , "Must be in range [0,1]." ) ;
150
+ Contracts . CheckUserArg ( windowSize > 0 , nameof ( windowSize ) , "Must be positive." ) ;
144
151
145
- PastScores = new Queue < float > ( EarlyStoppingCriterionOptions . WindowSize ) ;
152
+ Threshold = threshold ;
153
+ WindowSize = windowSize ;
154
+ PastScores = new Queue < float > ( windowSize ) ;
146
155
}
147
156
148
157
/// <summary>
@@ -200,26 +209,35 @@ protected bool CheckRecentScores(float score, int windowSize, out float recentBe
200
209
/// <summary>
201
210
/// Loss of Generality (GL).
202
211
/// </summary>
203
- public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion < GLEarlyStoppingCriterion . Options >
212
+ public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion
204
213
{
205
214
[ TlcModule . Component ( FriendlyName = "Loss of Generality (GL)" , Name = "GL" ,
206
215
Desc = "Stop in case of loss of generality." ) ]
207
- public sealed class Options : OptionsBase , IEarlyStoppingCriterionFactory
216
+ public sealed class Options : IEarlyStoppingCriterionFactory
208
217
{
209
218
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Threshold in range [0,1]." , ShortName = "th" ) ]
210
219
[ TlcModule . Range ( Min = 0.0f , Max = 1.0f ) ]
211
220
public float Threshold = 0.01f ;
212
221
213
222
public IEarlyStoppingCriterion CreateComponent ( IHostEnvironment env , bool lowerIsBetter )
214
223
{
215
- return new GLEarlyStoppingCriterion ( this , lowerIsBetter ) ;
224
+ return new GLEarlyStoppingCriterion ( lowerIsBetter , Threshold ) ;
216
225
}
217
226
}
218
227
219
- public GLEarlyStoppingCriterion ( Options options , bool lowerIsBetter )
220
- : base ( options , lowerIsBetter )
228
+ public float Threshold { get ; }
229
+
230
+ public GLEarlyStoppingCriterion ( bool lowerIsBetter = true , float threshold = 0.01f ) :
231
+ base ( lowerIsBetter )
232
+ {
233
+ Contracts . CheckUserArg ( 0 <= threshold && threshold <= 1 , nameof ( threshold ) , "Must be in range [0,1]." ) ;
234
+ Threshold = threshold ;
235
+ }
236
+
237
+ [ BestFriend ]
238
+ internal GLEarlyStoppingCriterion ( Options options , bool lowerIsBetter = true )
239
+ : this ( lowerIsBetter , options . Threshold )
221
240
{
222
- Contracts . CheckUserArg ( 0 <= EarlyStoppingCriterionOptions . Threshold && options . Threshold <= 1 , nameof ( options . Threshold ) , "Must be in range [0,1]." ) ;
223
241
}
224
242
225
243
public override bool CheckScore ( float validationScore , float trainingScore , out bool isBestCandidate )
@@ -229,9 +247,9 @@ public override bool CheckScore(float validationScore, float trainingScore, out
229
247
isBestCandidate = CheckBestScore ( validationScore ) ;
230
248
231
249
if ( LowerIsBetter )
232
- return ( validationScore > ( 1 + EarlyStoppingCriterionOptions . Threshold ) * BestScore ) ;
250
+ return ( validationScore > ( 1 + Threshold ) * BestScore ) ;
233
251
else
234
- return ( validationScore < ( 1 - EarlyStoppingCriterionOptions . Threshold ) * BestScore ) ;
252
+ return ( validationScore < ( 1 - Threshold ) * BestScore ) ;
235
253
}
236
254
}
237
255
@@ -246,12 +264,20 @@ public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterio
246
264
{
247
265
public IEarlyStoppingCriterion CreateComponent ( IHostEnvironment env , bool lowerIsBetter )
248
266
{
249
- return new LPEarlyStoppingCriterion ( this , lowerIsBetter ) ;
267
+ return new LPEarlyStoppingCriterion ( lowerIsBetter , Threshold , WindowSize ) ;
250
268
}
251
269
}
252
270
253
- public LPEarlyStoppingCriterion ( Options options , bool lowerIsBetter )
254
- : base ( options , lowerIsBetter ) { }
271
+ public LPEarlyStoppingCriterion ( bool lowerIsBetter , float threshold = 0.01f , int windowSize = 5 )
272
+ : base ( lowerIsBetter , threshold , windowSize )
273
+ {
274
+ }
275
+
276
+ [ BestFriend ]
277
+ internal LPEarlyStoppingCriterion ( Options options , bool lowerIsBetter = true )
278
+ : this ( lowerIsBetter , options . Threshold , options . WindowSize )
279
+ {
280
+ }
255
281
256
282
public override bool CheckScore ( float validationScore , float trainingScore , out bool isBestCandidate )
257
283
{
@@ -262,12 +288,12 @@ public override bool CheckScore(float validationScore, float trainingScore, out
262
288
263
289
float recentBest ;
264
290
float recentAverage ;
265
- if ( CheckRecentScores ( trainingScore , EarlyStoppingCriterionOptions . WindowSize , out recentBest , out recentAverage ) )
291
+ if ( CheckRecentScores ( trainingScore , WindowSize , out recentBest , out recentAverage ) )
266
292
{
267
293
if ( LowerIsBetter )
268
- return ( recentAverage <= ( 1 + EarlyStoppingCriterionOptions . Threshold ) * recentBest ) ;
294
+ return ( recentAverage <= ( 1 + Threshold ) * recentBest ) ;
269
295
else
270
- return ( recentAverage >= ( 1 - EarlyStoppingCriterionOptions . Threshold ) * recentBest ) ;
296
+ return ( recentAverage >= ( 1 - Threshold ) * recentBest ) ;
271
297
}
272
298
273
299
return false ;
@@ -284,12 +310,20 @@ public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterio
284
310
{
285
311
public IEarlyStoppingCriterion CreateComponent ( IHostEnvironment env , bool lowerIsBetter )
286
312
{
287
- return new PQEarlyStoppingCriterion ( this , lowerIsBetter ) ;
313
+ return new PQEarlyStoppingCriterion ( lowerIsBetter , Threshold , WindowSize ) ;
288
314
}
289
315
}
290
316
291
- public PQEarlyStoppingCriterion ( Options options , bool lowerIsBetter )
292
- : base ( options , lowerIsBetter ) { }
317
+ public PQEarlyStoppingCriterion ( bool lowerIsBetter , float threshold = 0.01f , int windowSize = 5 )
318
+ : base ( lowerIsBetter , threshold , windowSize )
319
+ {
320
+ }
321
+
322
+ [ BestFriend ]
323
+ internal PQEarlyStoppingCriterion ( Options options , bool lowerIsBetter = true )
324
+ : this ( lowerIsBetter , options . Threshold , options . WindowSize )
325
+ {
326
+ }
293
327
294
328
public override bool CheckScore ( float validationScore , float trainingScore , out bool isBestCandidate )
295
329
{
@@ -300,12 +334,12 @@ public override bool CheckScore(float validationScore, float trainingScore, out
300
334
301
335
float recentBest ;
302
336
float recentAverage ;
303
- if ( CheckRecentScores ( trainingScore , EarlyStoppingCriterionOptions . WindowSize , out recentBest , out recentAverage ) )
337
+ if ( CheckRecentScores ( trainingScore , WindowSize , out recentBest , out recentAverage ) )
304
338
{
305
339
if ( LowerIsBetter )
306
- return ( validationScore * recentBest >= ( 1 + EarlyStoppingCriterionOptions . Threshold ) * BestScore * recentAverage ) ;
340
+ return ( validationScore * recentBest >= ( 1 + Threshold ) * BestScore * recentAverage ) ;
307
341
else
308
- return ( validationScore * recentBest <= ( 1 - EarlyStoppingCriterionOptions . Threshold ) * BestScore * recentAverage ) ;
342
+ return ( validationScore * recentBest <= ( 1 - Threshold ) * BestScore * recentAverage ) ;
309
343
}
310
344
311
345
return false ;
@@ -315,33 +349,40 @@ public override bool CheckScore(float validationScore, float trainingScore, out
315
349
/// <summary>
316
350
/// Consecutive Loss in Generality (UP).
317
351
/// </summary>
318
- public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion < UPEarlyStoppingCriterion . Options >
352
+ public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion
319
353
{
320
354
[ TlcModule . Component ( FriendlyName = "Consecutive Loss in Generality (UP)" , Name = "UP" ,
321
355
Desc = "Stops in case of consecutive loss in generality." ) ]
322
- public sealed class Options : OptionsBase , IEarlyStoppingCriterionFactory
356
+ public sealed class Options : IEarlyStoppingCriterionFactory
323
357
{
324
358
[ Argument ( ArgumentType . AtMostOnce , HelpText = "The window size." , ShortName = "w" ) ]
325
359
[ TlcModule . Range ( Inf = 0 ) ]
326
360
public int WindowSize = 5 ;
327
361
328
362
public IEarlyStoppingCriterion CreateComponent ( IHostEnvironment env , bool lowerIsBetter )
329
363
{
330
- return new UPEarlyStoppingCriterion ( this , lowerIsBetter ) ;
364
+ return new UPEarlyStoppingCriterion ( lowerIsBetter , WindowSize ) ;
331
365
}
332
366
}
333
367
368
+ public int WindowSize { get ; }
334
369
private int _count ;
335
370
private float _prevScore ;
336
371
337
- public UPEarlyStoppingCriterion ( Options options , bool lowerIsBetter )
338
- : base ( options , lowerIsBetter )
372
+ public UPEarlyStoppingCriterion ( bool lowerIsBetter , int windowSize = 5 )
373
+ : base ( lowerIsBetter )
339
374
{
340
- Contracts . CheckUserArg ( EarlyStoppingCriterionOptions . WindowSize > 0 , nameof ( options . WindowSize ) , "Must be positive" ) ;
341
-
375
+ Contracts . CheckUserArg ( windowSize > 0 , nameof ( windowSize ) , "Must be positive" ) ;
376
+ WindowSize = windowSize ;
342
377
_prevScore = LowerIsBetter ? float . PositiveInfinity : float . NegativeInfinity ;
343
378
}
344
379
380
+ [ BestFriend ]
381
+ internal UPEarlyStoppingCriterion ( Options options , bool lowerIsBetter = true )
382
+ : this ( lowerIsBetter , options . WindowSize )
383
+ {
384
+ }
385
+
345
386
public override bool CheckScore ( float validationScore , float trainingScore , out bool isBestCandidate )
346
387
{
347
388
Contracts . Assert ( validationScore >= 0 ) ;
@@ -351,7 +392,7 @@ public override bool CheckScore(float validationScore, float trainingScore, out
351
392
_count = ( ( validationScore < _prevScore ) != LowerIsBetter ) ? _count + 1 : 0 ;
352
393
_prevScore = validationScore ;
353
394
354
- return ( _count >= EarlyStoppingCriterionOptions . WindowSize ) ;
395
+ return ( _count >= WindowSize ) ;
355
396
}
356
397
}
357
398
}
0 commit comments