
Commit e023ab8

Scrubbing online learners (#2892)
1 parent acc4ac0 commit e023ab8

File tree

11 files changed: +85 -78 lines


docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs

+1 -1
@@ -19,7 +19,7 @@ public static void Example()
     var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);

     // Create data training pipeline.
-    var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numIterations: 10);
+    var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numberOfIterations: 10);

     // Fit this pipeline to the training data.
     var model = pipeline.Fit(trainTestData.TrainSet);
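The renamed parameter above is the pattern for the whole commit: numIterations becomes numberOfIterations on the AveragedPerceptron convenience overload. A minimal end-to-end sketch of the updated sample, assuming an in-memory collection of labeled examples named 'samples' (hypothetical here):

    var mlContext = new MLContext(seed: 0);
    var data = mlContext.Data.LoadFromEnumerable(samples);   // 'samples' is a hypothetical labeled collection
    var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);
    // New parameter name after this commit: numberOfIterations (was numIterations).
    var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numberOfIterations: 10);
    var model = pipeline.Fit(trainTestData.TrainSet);
    var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(model.Transform(trainTestData.TestSet));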

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs

+1 -1
@@ -25,7 +25,7 @@ public static void Example()
     {
         LossFunction = new SmoothedHingeLoss(),
         LearningRate = 0.1f,
-        DoLazyUpdates = false,
+        LazyUpdate = false,
         RecencyGain = 0.1f,
         NumberOfIterations = 10
     };
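For option-object callers, the field renames line up as DoLazyUpdates → LazyUpdate, L2RegularizerWeight → L2Regularization, and RecencyGainMulti → RecencyGainMultiplicative (see AveragedLinear.cs below). A sketch of the full options block with the scrubbed names, assuming an existing MLContext named mlContext and illustrative values:

    var options = new AveragedPerceptronTrainer.Options
    {
        LossFunction = new SmoothedHingeLoss(),
        LearningRate = 0.1f,
        LazyUpdate = false,                 // was DoLazyUpdates
        RecencyGain = 0.1f,
        RecencyGainMultiplicative = false,  // was RecencyGainMulti
        L2Regularization = 0f,              // was L2RegularizerWeight
        NumberOfIterations = 10
    };
    var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);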

src/Microsoft.ML.StandardTrainers/Standard/Online/AveragedLinear.cs

+15 -15
@@ -57,16 +57,16 @@ public abstract class AveragedLinearOptions : OnlineLinearOptions
         /// <see langword="false" /> to update averaged weights on every example.
         /// Default is <see langword="true" />.
         /// </value>
-        [Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy")]
-        public bool DoLazyUpdates = true;
+        [Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy,DoLazyUpdates")]
+        public bool LazyUpdate = true;

         /// <summary>
         /// The L2 weight for <a href='tmpurl_regularization'>regularization</a>.
         /// </summary>
-        [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)]
+        [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg,L2RegularizerWeight", SortOrder = 50)]
         [TGUI(Label = "L2 Regularization Weight")]
         [TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)]
-        public float L2RegularizerWeight = AveragedDefault.L2RegularizerWeight;
+        public float L2Regularization = AveragedDefault.L2Regularization;

         /// <summary>
         /// Extra weight given to more recent updates.
@@ -85,8 +85,8 @@ public abstract class AveragedLinearOptions : OnlineLinearOptions
         /// <see langword="false" /> means <see cref="RecencyGain"/> is additive.
         /// Default is <see langword="false" />.
         /// </value>
-        [Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm")]
-        public bool RecencyGainMulti = false;
+        [Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm,RecencyGainMulti")]
+        public bool RecencyGainMultiplicative = false;

         /// <summary>
         /// Determines whether to do averaging or not.
@@ -109,7 +109,7 @@ internal class AveragedDefault : OnlineLinearOptions.OnlineDefault
         {
             public const float LearningRate = 1;
             public const bool DecreaseLearningRate = false;
-            public const float L2RegularizerWeight = 0;
+            public const float L2Regularization = 0;
         }

         internal abstract IComponentFactory<IScalarLoss> LossFunctionFactory { get; }
@@ -186,7 +186,7 @@ public override void FinishIteration(IChannel ch)
             // Finalize things
             if (Averaged)
             {
-                if (_args.DoLazyUpdates && NumNoUpdates > 0)
+                if (_args.LazyUpdate && NumNoUpdates > 0)
                 {
                     // Update the total weights to include the final loss=0 updates
                     VectorUtils.AddMult(in Weights, NumNoUpdates * WeightsScale, ref TotalWeights);
@@ -221,10 +221,10 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl
             // REVIEW: Should this be biasUpdate != 0?
             // This loss does not incorporate L2 if present, but the chance of that addition to the loss
             // exactly cancelling out loss is remote.
-            if (loss != 0 || _args.L2RegularizerWeight > 0)
+            if (loss != 0 || _args.L2Regularization > 0)
             {
                 // If doing lazy weights, we need to update the totalWeights and totalBias before updating weights/bias
-                if (_args.DoLazyUpdates && _args.Averaged && NumNoUpdates > 0 && TotalMultipliers * _args.AveragedTolerance <= PendingMultipliers)
+                if (_args.LazyUpdate && _args.Averaged && NumNoUpdates > 0 && TotalMultipliers * _args.AveragedTolerance <= PendingMultipliers)
                 {
                     VectorUtils.AddMult(in Weights, NumNoUpdates * WeightsScale, ref TotalWeights);
                     TotalBias += Bias * NumNoUpdates * WeightsScale;
@@ -242,7 +242,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl

                 // Perform the update to weights and bias.
                 VectorUtils.AddMult(in feat, biasUpdate / WeightsScale, ref Weights);
-                WeightsScale *= 1 - 2 * _args.L2RegularizerWeight; // L2 regularization.
+                WeightsScale *= 1 - 2 * _args.L2Regularization; // L2 regularization.
                 ScaleWeightsIfNeeded();
                 Bias += biasUpdate;
                 PendingMultipliers += Math.Abs(biasUpdate);
@@ -251,7 +251,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl
             // Add to averaged weights and increment the count.
             if (Averaged)
             {
-                if (!_args.DoLazyUpdates)
+                if (!_args.LazyUpdate)
                     IncrementAverageNonLazy();
                 else
                     NumNoUpdates++;
@@ -282,7 +282,7 @@ private void IncrementAverageNonLazy()
             VectorUtils.AddMult(in Weights, Gain * WeightsScale, ref TotalWeights);
             TotalBias += Gain * Bias;
             NumWeightUpdates += Gain;
-            Gain = (_args.RecencyGainMulti ? Gain * _args.RecencyGain : Gain + _args.RecencyGain);
+            Gain = (_args.RecencyGainMultiplicative ? Gain * _args.RecencyGain : Gain + _args.RecencyGain);

             // If gains got too big, rescale!
             if (Gain > 1000)
@@ -303,11 +303,11 @@ private protected AveragedLinearTrainer(AveragedLinearOptions options, IHostEnvi
             Contracts.CheckUserArg(!options.ResetWeightsAfterXExamples.HasValue || options.ResetWeightsAfterXExamples > 0, nameof(options.ResetWeightsAfterXExamples), UserErrorPositive);

             // Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
-            Contracts.CheckUserArg(0 <= options.L2RegularizerWeight && options.L2RegularizerWeight < 0.5, nameof(options.L2RegularizerWeight), "must be in range [0, 0.5)");
+            Contracts.CheckUserArg(0 <= options.L2Regularization && options.L2Regularization < 0.5, nameof(options.L2Regularization), "must be in range [0, 0.5)");
             Contracts.CheckUserArg(options.RecencyGain >= 0, nameof(options.RecencyGain), UserErrorNonNegative);
             Contracts.CheckUserArg(options.AveragedTolerance >= 0, nameof(options.AveragedTolerance), UserErrorNonNegative);
             // Verify user didn't specify parameters that conflict
-            Contracts.Check(!options.DoLazyUpdates || !options.RecencyGainMulti && options.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");
+            Contracts.Check(!options.LazyUpdate || !options.RecencyGainMultiplicative && options.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");

             AveragedLinearTrainerOptions = options;
         }
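To see what the renamed RecencyGainMultiplicative flag controls, here is a small standalone sketch of the averaging arithmetic in IncrementAverageNonLazy, with local variables standing in for the trainer's Gain, TotalBias, and NumWeightUpdates state (illustrative only, not the trainer's actual code):

    // Each update contributes with weight 'gain'; recencyGain makes later
    // updates count more, either additively or multiplicatively.
    float gain = 1f, recencyGain = 0.1f;
    bool recencyGainMultiplicative = false;   // the renamed flag
    float bias = 0.5f, totalBias = 0f, numWeightUpdates = 0f;
    for (int i = 0; i < 5; i++)
    {
        totalBias += gain * bias;
        numWeightUpdates += gain;
        gain = recencyGainMultiplicative ? gain * recencyGain : gain + recencyGain;
    }
    float averagedBias = totalBias / numWeightUpdates;   // the averaged model divides by the accumulated gain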

src/Microsoft.ML.StandardTrainers/Standard/Online/AveragedPerceptron.cs

+6 -6
@@ -131,24 +131,24 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options)
         /// <param name="featureColumnName">The name of the feature column.</param>
         /// <param name="learningRate">The learning rate. </param>
         /// <param name="decreaseLearningRate">Whether to decrease learning rate as iterations progress.</param>
-        /// <param name="l2RegularizerWeight">L2 Regularization Weight.</param>
-        /// <param name="numIterations">The number of training iterations.</param>
+        /// <param name="l2Regularization">Weight of L2 regularization term.</param>
+        /// <param name="numberOfIterations">The number of training iterations.</param>
         internal AveragedPerceptronTrainer(IHostEnvironment env,
             string labelColumnName = DefaultColumnNames.Label,
             string featureColumnName = DefaultColumnNames.Features,
             IClassificationLoss lossFunction = null,
             float learningRate = Options.AveragedDefault.LearningRate,
             bool decreaseLearningRate = Options.AveragedDefault.DecreaseLearningRate,
-            float l2RegularizerWeight = Options.AveragedDefault.L2RegularizerWeight,
-            int numIterations = Options.AveragedDefault.NumIterations)
+            float l2Regularization = Options.AveragedDefault.L2Regularization,
+            int numberOfIterations = Options.AveragedDefault.NumberOfIterations)
             : this(env, new Options
             {
                 LabelColumnName = labelColumnName,
                 FeatureColumnName = featureColumnName,
                 LearningRate = learningRate,
                 DecreaseLearningRate = decreaseLearningRate,
-                L2RegularizerWeight = l2RegularizerWeight,
-                NumberOfIterations = numIterations,
+                L2Regularization = l2Regularization,
+                NumberOfIterations = numberOfIterations,
                 LossFunction = lossFunction ?? new HingeLoss()
             })
         {
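Since this internal constructor backs the public AveragedPerceptron extension, the scrubbed parameter names are what API callers see. A sketch spelling out each renamed parameter with illustrative values, assuming an existing MLContext named mlContext:

    var trainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
        labelColumnName: "Label",
        featureColumnName: "Features",
        learningRate: 1f,
        decreaseLearningRate: false,
        l2Regularization: 0f,        // was l2RegularizerWeight
        numberOfIterations: 10);     // was numIterations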

src/Microsoft.ML.StandardTrainers/Standard/Online/LinearSvm.cs

+8 -8
@@ -69,8 +69,8 @@ public sealed class Options : OnlineLinearOptions
         /// <summary>
         /// Column to use for example weight.
         /// </summary>
-        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
-        public string WeightColumn = null;
+        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight,WeightColumn", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+        public string ExampleWeightColumnName = null;
     }

     private sealed class TrainState : TrainStateBase
@@ -232,20 +232,20 @@ public override LinearBinaryModelParameters CreatePredictor()
         /// <param name="env">The environment to use.</param>
         /// <param name="labelColumn">The name of the label column. </param>
         /// <param name="featureColumn">The name of the feature column.</param>
-        /// <param name="weightColumn">The optional name of the weight column.</param>
-        /// <param name="numIterations">The number of training iteraitons.</param>
+        /// <param name="exampleWeightColumnName">The name of the example weight column (optional).</param>
+        /// <param name="numberOfIterations">The number of training iterations.</param>
         [BestFriend]
         internal LinearSvmTrainer(IHostEnvironment env,
             string labelColumn = DefaultColumnNames.Label,
             string featureColumn = DefaultColumnNames.Features,
-            string weightColumn = null,
-            int numIterations = Options.OnlineDefault.NumIterations)
+            string exampleWeightColumnName = null,
+            int numberOfIterations = Options.OnlineDefault.NumberOfIterations)
             : this(env, new Options
             {
                 LabelColumnName = labelColumn,
                 FeatureColumnName = featureColumn,
-                WeightColumn = weightColumn,
-                NumberOfIterations = numIterations,
+                ExampleWeightColumnName = exampleWeightColumnName,
+                NumberOfIterations = numberOfIterations,
             })
         {
         }
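The same pattern applies to LinearSvm's convenience overload: weightColumn becomes exampleWeightColumnName and numIterations becomes numberOfIterations. A sketch of a call using the new names, where "Weight" is a hypothetical column in the input data (passing null keeps all examples equally weighted):

    var trainer = mlContext.BinaryClassification.Trainers.LinearSvm(
        labelColumnName: "Label",
        featureColumnName: "Features",
        exampleWeightColumnName: "Weight",   // was weightColumn
        numberOfIterations: 10);             // was numIterations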

src/Microsoft.ML.StandardTrainers/Standard/Online/OnlineGradientDescent.cs

+6 -6
@@ -97,23 +97,23 @@ public override LinearRegressionModelParameters CreatePredictor()
         /// <param name="featureColumn">Name of the feature column.</param>
         /// <param name="learningRate">The learning Rate.</param>
         /// <param name="decreaseLearningRate">Decrease learning rate as iterations progress.</param>
-        /// <param name="l2RegularizerWeight">L2 Regularization Weight.</param>
-        /// <param name="numIterations">Number of training iterations through the data.</param>
+        /// <param name="l2Regularization">Weight of L2 regularization term.</param>
+        /// <param name="numberOfIterations">Number of training iterations through the data.</param>
         /// <param name="lossFunction">The custom loss functions. Defaults to <see cref="SquaredLoss"/> if not provided.</param>
         internal OnlineGradientDescentTrainer(IHostEnvironment env,
             string labelColumn = DefaultColumnNames.Label,
             string featureColumn = DefaultColumnNames.Features,
             float learningRate = Options.OgdDefaultArgs.LearningRate,
             bool decreaseLearningRate = Options.OgdDefaultArgs.DecreaseLearningRate,
-            float l2RegularizerWeight = Options.OgdDefaultArgs.L2RegularizerWeight,
-            int numIterations = Options.OgdDefaultArgs.NumIterations,
+            float l2Regularization = Options.OgdDefaultArgs.L2Regularization,
+            int numberOfIterations = Options.OgdDefaultArgs.NumberOfIterations,
             IRegressionLoss lossFunction = null)
             : this(env, new Options
             {
                 LearningRate = learningRate,
                 DecreaseLearningRate = decreaseLearningRate,
-                L2RegularizerWeight = l2RegularizerWeight,
-                NumberOfIterations = numIterations,
+                L2Regularization = l2Regularization,
+                NumberOfIterations = numberOfIterations,
                 LabelColumnName = labelColumn,
                 FeatureColumnName = featureColumn,
                 LossFunction = lossFunction ?? new SquaredLoss()
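OnlineGradientDescent gets the matching renames on its regression overload. A sketch with the scrubbed names and illustrative values, again assuming an existing MLContext named mlContext:

    var trainer = mlContext.Regression.Trainers.OnlineGradientDescent(
        labelColumnName: "Label",
        featureColumnName: "Features",
        learningRate: 0.1f,
        decreaseLearningRate: true,
        l2Regularization: 0f,       // was l2RegularizerWeight
        numberOfIterations: 1);     // was numIterations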

src/Microsoft.ML.StandardTrainers/Standard/Online/OnlineLinear.cs

+2 -2
@@ -27,7 +27,7 @@ public abstract class OnlineLinearOptions : TrainerInputBaseWithLabel
         [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter,numIterations", SortOrder = 50)]
         [TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")]
         [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)]
-        public int NumberOfIterations = OnlineDefault.NumIterations;
+        public int NumberOfIterations = OnlineDefault.NumberOfIterations;

         /// <summary>
         /// Initial weights and bias, comma-separated.
@@ -62,7 +62,7 @@ public abstract class OnlineLinearOptions : TrainerInputBaseWithLabel
         [BestFriend]
         internal class OnlineDefault
         {
-            public const int NumIterations = 1;
+            public const int NumberOfIterations = 1;
         }
     }
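Because every rename keeps the old name as a ShortName alias ("iter,numIterations" here, "lazy,DoLazyUpdates" and "reg,L2RegularizerWeight" in AveragedLinear.cs, "weight,WeightColumn" in LinearSvm.cs), existing command-line settings strings should continue to parse. A hypothetical maml.exe invocation using only the old short names, assuming the usual trainer-settings syntax:

    maml.exe train data=train.txt tr=AveragedPerceptron{iter=10 lazy=- reg=0.1}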
