20
20
using Microsoft . ML . Training ;
21
21
using Microsoft . ML . Transforms ;
22
22
23
- [ assembly: LoadableClass ( typeof ( SymSgdClassificationTrainer ) , typeof ( SymSgdClassificationTrainer . Arguments ) ,
23
+ [ assembly: LoadableClass ( typeof ( SymSgdClassificationTrainer ) , typeof ( SymSgdClassificationTrainer . Options ) ,
24
24
new [ ] { typeof ( SignatureBinaryClassifierTrainer ) , typeof ( SignatureTrainer ) , typeof ( SignatureFeatureScorerTrainer ) } ,
25
25
SymSgdClassificationTrainer . UserNameValue ,
26
26
SymSgdClassificationTrainer . LoadNameValue ,
@@ -33,48 +33,78 @@ namespace Microsoft.ML.Trainers.SymSgd
33
33
using TPredictor = IPredictorWithFeatureWeights < float > ;
34
34
35
35
/// <include file='doc.xml' path='doc/members/member[@name="SymSGD"]/*' />
36
+ [ BestFriend ]
36
37
public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase < BinaryPredictionTransformer < TPredictor > , TPredictor >
37
38
{
38
39
internal const string LoadNameValue = "SymbolicSGD" ;
39
40
internal const string UserNameValue = "Symbolic SGD (binary)" ;
40
41
internal const string ShortName = "SymSGD" ;
41
42
42
- public sealed class Arguments : LearnerInputBaseWithLabel
43
+ public sealed class Options : LearnerInputBaseWithLabel
43
44
{
45
+ /// <summary>
46
+ /// Degree of lock-free parallelism. Determinism not guaranteed.
47
+ /// Multi-threading is not supported currently.
48
+ /// </summary>
44
49
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Degree of lock-free parallelism. Determinism not guaranteed. " +
45
50
"Multi-threading is not supported currently." , ShortName = "nt" ) ]
46
51
public int ? NumberOfThreads ;
47
52
53
+ /// <summary>
54
+ /// Number of passes over the data.
55
+ /// </summary>
48
56
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Number of passes over the data." , ShortName = "iter" , SortOrder = 50 ) ]
49
57
[ TGUI ( SuggestedSweeps = "1,5,10,20,30,40,50" ) ]
50
58
[ TlcModule . SweepableDiscreteParam ( "NumberOfIterations" , new object [ ] { 1 , 5 , 10 , 20 , 30 , 40 , 50 } ) ]
51
59
public int NumberOfIterations = 50 ;
52
60
61
+ /// <summary>
62
+ /// Tolerance for difference in average loss in consecutive passes.
63
+ /// </summary>
53
64
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Tolerance for difference in average loss in consecutive passes." , ShortName = "tol" ) ]
54
65
public float Tolerance = 1e-4f ;
55
66
67
+ /// <summary>
68
+ /// Learning rate.
69
+ /// </summary>
56
70
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Learning rate" , ShortName = "lr" , NullName = "<Auto>" , SortOrder = 51 ) ]
57
71
[ TGUI ( SuggestedSweeps = "<Auto>,1e1,1e0,1e-1,1e-2,1e-3" ) ]
58
72
[ TlcModule . SweepableDiscreteParam ( "LearningRate" , new object [ ] { "<Auto>" , 1e1f , 1e0f , 1e-1f , 1e-2f , 1e-3f } ) ]
59
73
public float ? LearningRate ;
60
74
75
+ /// <summary>
76
+ /// L2 regularization.
77
+ /// </summary>
61
78
[ Argument ( ArgumentType . AtMostOnce , HelpText = "L2 regularization" , ShortName = "l2" , SortOrder = 52 ) ]
62
79
[ TGUI ( SuggestedSweeps = "0.0,1e-5,1e-5,1e-6,1e-7" ) ]
63
80
[ TlcModule . SweepableDiscreteParam ( "L2Regularization" , new object [ ] { 0.0f , 1e-5f , 1e-5f , 1e-6f , 1e-7f } ) ]
64
81
public float L2Regularization ;
65
82
83
+ /// <summary>
84
+ /// The number of iterations each thread learns a local model until combining it with the
85
+ /// global model. Low value means more updated global model and high value means less cache traffic.
86
+ /// </summary>
66
87
[ Argument ( ArgumentType . AtMostOnce , HelpText = "The number of iterations each thread learns a local model until combining it with the " +
67
88
"global model. Low value means more updated global model and high value means less cache traffic." , ShortName = "freq" , NullName = "<Auto>" ) ]
68
89
[ TGUI ( SuggestedSweeps = "<Auto>,5,20" ) ]
69
90
[ TlcModule . SweepableDiscreteParam ( "UpdateFrequency" , new object [ ] { "<Auto>" , 5 , 20 } ) ]
70
91
public int ? UpdateFrequency ;
71
92
93
+ /// <summary>
94
+ /// The acceleration memory budget in MB.
95
+ /// </summary>
72
96
[ Argument ( ArgumentType . AtMostOnce , HelpText = "The acceleration memory budget in MB" , ShortName = "accelMemBudget" ) ]
73
97
public long MemorySize = 1024 ;
74
98
99
+ /// <summary>
100
+ /// Set to <see langword="true" /> causes the data to shuffle.
101
+ /// </summary>
75
102
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Shuffle data?" , ShortName = "shuf" ) ]
76
103
public bool Shuffle = true ;
77
104
105
+ /// <summary>
106
+ /// Apply weight to the positive class, for imbalanced data.
107
+ /// </summary>
78
108
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Apply weight to the positive class, for imbalanced data" , ShortName = "piw" ) ]
79
109
public float PositiveInstanceWeight = 1 ;
80
110
@@ -88,7 +118,7 @@ public void Check(IExceptionContext ectx)
88
118
}
89
119
90
120
public override TrainerInfo Info { get ; }
91
- private readonly Arguments _args ;
121
+ private readonly Options _args ;
92
122
93
123
/// <summary>
94
124
/// This method ensures that the data meets the requirements of this trainer and its
@@ -152,32 +182,7 @@ private protected override TPredictor TrainModelCore(TrainContext context)
152
182
/// <summary>
153
183
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
154
184
/// </summary>
155
- /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
156
- /// <param name="labelColumn">The name of the label column.</param>
157
- /// <param name="featureColumn">The name of the feature column.</param>
158
- /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
159
- public SymSgdClassificationTrainer ( IHostEnvironment env ,
160
- string labelColumn = DefaultColumnNames . Label ,
161
- string featureColumn = DefaultColumnNames . Features ,
162
- Action < Arguments > advancedSettings = null )
163
- : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( featureColumn ) ,
164
- TrainerUtils . MakeBoolScalarLabel ( labelColumn ) )
165
- {
166
- _args = new Arguments ( ) ;
167
-
168
- // Apply the advanced args, if the user supplied any.
169
- _args . Check ( Host ) ;
170
- advancedSettings ? . Invoke ( _args ) ;
171
- _args . FeatureColumn = featureColumn ;
172
- _args . LabelColumn = labelColumn ;
173
-
174
- Info = new TrainerInfo ( supportIncrementalTrain : true ) ;
175
- }
176
-
177
- /// <summary>
178
- /// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
179
- /// </summary>
180
- internal SymSgdClassificationTrainer ( IHostEnvironment env , Arguments args )
185
+ internal SymSgdClassificationTrainer ( IHostEnvironment env , Options args )
181
186
: base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) ,
182
187
TrainerUtils . MakeBoolScalarLabel ( args . LabelColumn ) )
183
188
{
@@ -218,14 +223,14 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
218
223
UserName = SymSgdClassificationTrainer . UserNameValue ,
219
224
ShortName = SymSgdClassificationTrainer . ShortName ,
220
225
XmlInclude = new [ ] { @"<include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name=""SymSGD""]/*' />" } ) ]
221
- public static CommonOutputs . BinaryClassificationOutput TrainSymSgd ( IHostEnvironment env , Arguments input )
226
+ public static CommonOutputs . BinaryClassificationOutput TrainSymSgd ( IHostEnvironment env , Options input )
222
227
{
223
228
Contracts . CheckValue ( env , nameof ( env ) ) ;
224
229
var host = env . Register ( "TrainSymSGD" ) ;
225
230
host . CheckValue ( input , nameof ( input ) ) ;
226
231
EntryPointUtils . CheckInputArgs ( host , input ) ;
227
232
228
- return LearnerEntryPointsUtils . Train < Arguments , CommonOutputs . BinaryClassificationOutput > ( host , input ,
233
+ return LearnerEntryPointsUtils . Train < Options , CommonOutputs . BinaryClassificationOutput > ( host , input ,
229
234
( ) => new SymSgdClassificationTrainer ( host , input ) ,
230
235
( ) => LearnerEntryPointsUtils . FindColumn ( host , input . TrainingData . Schema , input . LabelColumn ) ) ;
231
236
}
0 commit comments