@@ -26,29 +26,29 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
26
26
[ Argument ( ArgumentType . AtMostOnce , HelpText = "L2 regularization weight" , ShortName = "l2" , SortOrder = 50 ) ]
27
27
[ TGUI ( Label = "L2 Weight" , Description = "Weight of L2 regularizer term" , SuggestedSweeps = "0,0.1,1" ) ]
28
28
[ TlcModule . SweepableFloatParamAttribute ( 0.0f , 1.0f , numSteps : 4 ) ]
29
- public float L2Weight = 1 ;
29
+ public float L2Weight = Defaults . L2Weight ;
30
30
31
31
[ Argument ( ArgumentType . AtMostOnce , HelpText = "L1 regularization weight" , ShortName = "l1" , SortOrder = 50 ) ]
32
32
[ TGUI ( Label = "L1 Weight" , Description = "Weight of L1 regularizer term" , SuggestedSweeps = "0,0.1,1" ) ]
33
33
[ TlcModule . SweepableFloatParamAttribute ( 0.0f , 1.0f , numSteps : 4 ) ]
34
- public float L1Weight = 1 ;
34
+ public float L1Weight = Defaults . L1Weight ;
35
35
36
36
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Tolerance parameter for optimization convergence. Lower = slower, more accurate" ,
37
37
ShortName = "ot" , SortOrder = 50 ) ]
38
38
[ TGUI ( Label = "Optimization Tolerance" , Description = "Threshold for optimizer convergence" , SuggestedSweeps = "1e-4,1e-7" ) ]
39
39
[ TlcModule . SweepableDiscreteParamAttribute ( new object [ ] { 1e-4f , 1e-7f } ) ]
40
- public float OptTol = 1e-7f ;
40
+ public float OptTol = Defaults . OptTol ;
41
41
42
42
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Memory size for L-BFGS. Lower=faster, less accurate" ,
43
43
ShortName = "m" , SortOrder = 50 ) ]
44
44
[ TGUI ( Description = "Memory size for L-BFGS" , SuggestedSweeps = "5,20,50" ) ]
45
45
[ TlcModule . SweepableDiscreteParamAttribute ( "MemorySize" , new object [ ] { 5 , 20 , 50 } ) ]
46
- public int MemorySize = 20 ;
46
+ public int MemorySize = Defaults . MemorySize ;
47
47
48
48
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Maximum iterations." , ShortName = "maxiter" ) ]
49
49
[ TGUI ( Label = "Max Number of Iterations" ) ]
50
50
[ TlcModule . SweepableLongParamAttribute ( "MaxIterations" , 1 , int . MaxValue ) ]
51
- public int MaxIterations = int . MaxValue ;
51
+ public int MaxIterations = Defaults . MaxIterations ;
52
52
53
53
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Run SGD to initialize LR weights, converging to this tolerance" ,
54
54
ShortName = "sgd" ) ]
@@ -90,7 +90,17 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
90
90
public bool DenseOptimizer = false ;
91
91
92
92
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Enforce non-negative weights" , ShortName = "nn" , SortOrder = 90 ) ]
93
- public bool EnforceNonNegativity = false ;
93
+ public bool EnforceNonNegativity = Defaults . EnforceNonNegativity ;
94
+
95
+ internal static class Defaults
96
+ {
97
+ internal const float L2Weight = 1 ;
98
+ internal const float L1Weight = 1 ;
99
+ internal const float OptTol = 1e-7f ;
100
+ internal const int MemorySize = 20 ;
101
+ internal const int MaxIterations = int . MaxValue ;
102
+ internal const bool EnforceNonNegativity = false ;
103
+ }
94
104
}
95
105
96
106
private const string RegisterName = nameof ( LbfgsTrainerBase < TArgs , TTransformer , TModel > ) ;
@@ -142,40 +152,56 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
142
152
public override TrainerInfo Info => _info ;
143
153
144
154
internal LbfgsTrainerBase ( IHostEnvironment env , string featureColumn , SchemaShape . Column labelColumn ,
145
- string weightColumn = null , Action < TArgs > advancedSettings = null )
146
- : this ( env , ArgsInit ( featureColumn , labelColumn , weightColumn , advancedSettings ) , labelColumn )
155
+ string weightColumn , Action < TArgs > advancedSettings , float l1Weight ,
156
+ float l2Weight ,
157
+ float optimizationTolerance ,
158
+ int memorySize ,
159
+ bool enforceNoNegativity )
160
+ : this ( env , ArgsInit ( featureColumn , labelColumn , weightColumn , advancedSettings ) , labelColumn ,
161
+ l1Weight , l2Weight , optimizationTolerance , memorySize , enforceNoNegativity )
147
162
{
148
163
}
149
164
150
- internal LbfgsTrainerBase ( IHostEnvironment env , TArgs args , SchemaShape . Column labelColumn )
165
+ internal LbfgsTrainerBase ( IHostEnvironment env , TArgs args , SchemaShape . Column labelColumn ,
166
+ float ? l1Weight = null ,
167
+ float ? l2Weight = null ,
168
+ float ? optimizationTolerance = null ,
169
+ int ? memorySize = null ,
170
+ bool ? enforceNoNegativity = null )
151
171
: base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) ,
152
172
labelColumn , TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn ) )
153
173
{
154
174
Host . CheckValue ( args , nameof ( args ) ) ;
155
175
Args = args ;
156
176
157
- Contracts . CheckUserArg ( ! Args . UseThreads || Args . NumThreads > 0 || Args . NumThreads == null ,
177
+ Host . CheckUserArg ( ! Args . UseThreads || Args . NumThreads > 0 || Args . NumThreads == null ,
158
178
nameof ( Args . NumThreads ) , "numThreads must be positive (or empty for default)" ) ;
159
- Contracts . CheckUserArg ( Args . L2Weight >= 0 , nameof ( Args . L2Weight ) , "Must be non-negative" ) ;
160
- Contracts . CheckUserArg ( Args . L1Weight >= 0 , nameof ( Args . L1Weight ) , "Must be non-negative" ) ;
161
- Contracts . CheckUserArg ( Args . OptTol > 0 , nameof ( Args . OptTol ) , "Must be positive" ) ;
162
- Contracts . CheckUserArg ( Args . MemorySize > 0 , nameof ( Args . MemorySize ) , "Must be positive" ) ;
163
- Contracts . CheckUserArg ( Args . MaxIterations > 0 , nameof ( Args . MaxIterations ) , "Must be positive" ) ;
164
- Contracts . CheckUserArg ( Args . SgdInitializationTolerance >= 0 , nameof ( Args . SgdInitializationTolerance ) , "Must be non-negative" ) ;
165
- Contracts . CheckUserArg ( Args . NumThreads == null || Args . NumThreads . Value >= 0 , nameof ( Args . NumThreads ) , "Must be non-negative" ) ;
166
-
167
- L2Weight = Args . L2Weight ;
168
- L1Weight = Args . L1Weight ;
169
- OptTol = Args . OptTol ;
170
- MemorySize = Args . MemorySize ;
179
+ Host . CheckUserArg ( Args . L2Weight >= 0 , nameof ( Args . L2Weight ) , "Must be non-negative" ) ;
180
+ Host . CheckUserArg ( Args . L1Weight >= 0 , nameof ( Args . L1Weight ) , "Must be non-negative" ) ;
181
+ Host . CheckUserArg ( Args . OptTol > 0 , nameof ( Args . OptTol ) , "Must be positive" ) ;
182
+ Host . CheckUserArg ( Args . MemorySize > 0 , nameof ( Args . MemorySize ) , "Must be positive" ) ;
183
+ Host . CheckUserArg ( Args . MaxIterations > 0 , nameof ( Args . MaxIterations ) , "Must be positive" ) ;
184
+ Host . CheckUserArg ( Args . SgdInitializationTolerance >= 0 , nameof ( Args . SgdInitializationTolerance ) , "Must be non-negative" ) ;
185
+ Host . CheckUserArg ( Args . NumThreads == null || Args . NumThreads . Value >= 0 , nameof ( Args . NumThreads ) , "Must be non-negative" ) ;
186
+
187
+ Host . CheckParam ( ! ( l2Weight < 0 ) , nameof ( l2Weight ) , "Must be non-negative, if provided." ) ;
188
+ Host . CheckParam ( ! ( l1Weight < 0 ) , nameof ( l1Weight ) , "Must be non-negative, if provided" ) ;
189
+ Host . CheckParam ( ! ( optimizationTolerance <= 0 ) , nameof ( optimizationTolerance ) , "Must be positive, if provided." ) ;
190
+ Host . CheckParam ( ! ( memorySize <= 0 ) , nameof ( memorySize ) , "Must be positive, if provided." ) ;
191
+
192
+ // Review: Warn about the overriding behavior
193
+ L2Weight = l2Weight ?? Args . L2Weight ;
194
+ L1Weight = l1Weight ?? Args . L1Weight ;
195
+ OptTol = optimizationTolerance ?? Args . OptTol ;
196
+ MemorySize = memorySize ?? Args . MemorySize ;
171
197
MaxIterations = Args . MaxIterations ;
172
198
SgdInitializationTolerance = Args . SgdInitializationTolerance ;
173
199
Quiet = Args . Quiet ;
174
200
InitWtsDiameter = Args . InitWtsDiameter ;
175
201
UseThreads = Args . UseThreads ;
176
202
NumThreads = Args . NumThreads ;
177
203
DenseOptimizer = Args . DenseOptimizer ;
178
- EnforceNonNegativity = Args . EnforceNonNegativity ;
204
+ EnforceNonNegativity = enforceNoNegativity ?? Args . EnforceNonNegativity ;
179
205
180
206
if ( EnforceNonNegativity && ShowTrainingStats )
181
207
{
0 commit comments