Skip to content

Commit bb92c06

Browse files
authored
Modify API for advanced settings (several learners) (#2163)
* LogisticRegression, MulticlassLogisticRegression, PoissonRegression * Options rename * weights updates * AveragedPerceptron * OnlineGradientDescent * LinearSvm * Options rename * review comments
1 parent bafd40c commit bb92c06

File tree

27 files changed

+573
-277
lines changed

27 files changed

+573
-277
lines changed

src/Microsoft.ML.StandardLearners/Properties/AssemblyInfo.cs

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Runtime.CompilerServices;
66
using Microsoft.ML;
77

8+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)]
89
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)]
910

1011
[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ internal LbfgsTrainerBase(IHostEnvironment env,
156156
string featureColumn,
157157
SchemaShape.Column labelColumn,
158158
string weightColumn,
159-
Action<TArgs> advancedSettings,
160159
float l1Weight,
161160
float l2Weight,
162161
float optimizationTolerance,
@@ -173,7 +172,7 @@ internal LbfgsTrainerBase(IHostEnvironment env,
173172
MemorySize = memorySize,
174173
EnforceNonNegativity = enforceNoNegativity
175174
},
176-
labelColumn, advancedSettings)
175+
labelColumn)
177176
{
178177
}
179178

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs

+16-18
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
using Microsoft.ML.Numeric;
1717
using Microsoft.ML.Training;
1818

19-
[assembly: LoadableClass(LogisticRegression.Summary, typeof(LogisticRegression), typeof(LogisticRegression.Arguments),
19+
[assembly: LoadableClass(LogisticRegression.Summary, typeof(LogisticRegression), typeof(LogisticRegression.Options),
2020
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
2121
LogisticRegression.UserNameValue,
2222
LogisticRegression.LoadNameValue,
@@ -30,15 +30,15 @@ namespace Microsoft.ML.Learners
3030

3131
/// <include file='doc.xml' path='doc/members/member[@name="LBFGS"]/*' />
3232
/// <include file='doc.xml' path='docs/members/example[@name="LogisticRegressionBinaryClassifier"]/*' />
33-
public sealed partial class LogisticRegression : LbfgsTrainerBase<LogisticRegression.Arguments, BinaryPredictionTransformer<ParameterMixingCalibratedPredictor>, ParameterMixingCalibratedPredictor>
33+
public sealed partial class LogisticRegression : LbfgsTrainerBase<LogisticRegression.Options, BinaryPredictionTransformer<ParameterMixingCalibratedPredictor>, ParameterMixingCalibratedPredictor>
3434
{
3535
public const string LoadNameValue = "LogisticRegression";
3636
internal const string UserNameValue = "Logistic Regression";
3737
internal const string ShortName = "lr";
3838
internal const string Summary = "Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can "
3939
+ "be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function.";
4040

41-
public sealed class Arguments : ArgumentsBase
41+
public sealed class Options : ArgumentsBase
4242
{
4343
/// <summary>
4444
/// If set to <value>true</value>training statistics will be generated at the end of training.
@@ -53,7 +53,7 @@ public sealed class Arguments : ArgumentsBase
5353
/// <summary>
5454
/// The instance of <see cref="ComputeLRTrainingStd"/> that computes the std of the training statistics, at the end of training.
5555
/// The calculations are not part of Microsoft.ML package, due to the size of MKL.
56-
/// If you need these calculations, add the Microsoft.ML.HalLearners package, and initialize <see cref="LogisticRegression.Arguments.StdComputer"/>.
56+
/// If you need these calculations, add the Microsoft.ML.HalLearners package, and initialize <see cref="LogisticRegression.Options.StdComputer"/>.
5757
/// to the <see cref="ComputeLRTrainingStd"/> implementation in the Microsoft.ML.HalLearners package.
5858
/// </summary>
5959
public ComputeLRTrainingStd StdComputer;
@@ -74,18 +74,16 @@ public sealed class Arguments : ArgumentsBase
7474
/// <param name="l2Weight">Weight of L2 regularizer term.</param>
7575
/// <param name="memorySize">Memory size for <see cref="LogisticRegression"/>. Low=faster, less accurate.</param>
7676
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
77-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
78-
public LogisticRegression(IHostEnvironment env,
77+
internal LogisticRegression(IHostEnvironment env,
7978
string labelColumn = DefaultColumnNames.Label,
8079
string featureColumn = DefaultColumnNames.Features,
8180
string weights = null,
82-
float l1Weight = Arguments.Defaults.L1Weight,
83-
float l2Weight = Arguments.Defaults.L2Weight,
84-
float optimizationTolerance = Arguments.Defaults.OptTol,
85-
int memorySize = Arguments.Defaults.MemorySize,
86-
bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
87-
Action<Arguments> advancedSettings = null)
88-
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weights, advancedSettings,
81+
float l1Weight = Options.Defaults.L1Weight,
82+
float l2Weight = Options.Defaults.L2Weight,
83+
float optimizationTolerance = Options.Defaults.OptTol,
84+
int memorySize = Options.Defaults.MemorySize,
85+
bool enforceNoNegativity = Options.Defaults.EnforceNonNegativity)
86+
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weights,
8987
l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
9088
{
9189
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
@@ -98,8 +96,8 @@ public LogisticRegression(IHostEnvironment env,
9896
/// <summary>
9997
/// Initializes a new instance of <see cref="LogisticRegression"/>
10098
/// </summary>
101-
internal LogisticRegression(IHostEnvironment env, Arguments args)
102-
: base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
99+
internal LogisticRegression(IHostEnvironment env, Options options)
100+
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
103101
{
104102
_posWeight = 0;
105103
ShowTrainingStats = Args.ShowTrainingStats;
@@ -410,14 +408,14 @@ protected override ParameterMixingCalibratedPredictor CreatePredictor()
410408
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/member[@name=""LBFGS""]/*' />",
411409
@"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/example[@name=""LogisticRegressionBinaryClassifier""]/*' />"})]
412410

413-
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input)
411+
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
414412
{
415413
Contracts.CheckValue(env, nameof(env));
416414
var host = env.Register("TrainLRBinary");
417415
host.CheckValue(input, nameof(input));
418416
EntryPointUtils.CheckInputArgs(host, input);
419417

420-
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
418+
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
421419
() => new LogisticRegression(host, input),
422420
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
423421
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
@@ -437,7 +435,7 @@ public abstract class ComputeLRTrainingStd
437435
/// Computes the standard deviation matrix of each of the non-zero training weights, needed to calculate further the standard deviation,
438436
/// p-value and z-Score.
439437
/// The calculations are not part of Microsoft.ML package, due to the size of MKL.
440-
/// If you need these calculations, add the Microsoft.ML.HalLearners package, and initialize <see cref="LogisticRegression.Arguments.StdComputer"/>
438+
/// If you need these calculations, add the Microsoft.ML.HalLearners package, and initialize <see cref="LogisticRegression.Options.StdComputer"/>
441439
/// to the <see cref="ComputeLRTrainingStd"/> implementation in the Microsoft.ML.HalLearners package.
442440
/// Due to the existence of regularization, an approximation is used to compute the variances of the trained linear coefficients.
443441
/// </summary>

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs

+14-17
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
using Microsoft.ML.Training;
2323
using Newtonsoft.Json.Linq;
2424

25-
[assembly: LoadableClass(typeof(MulticlassLogisticRegression), typeof(MulticlassLogisticRegression.Arguments),
25+
[assembly: LoadableClass(typeof(MulticlassLogisticRegression), typeof(MulticlassLogisticRegression.Options),
2626
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) },
2727
MulticlassLogisticRegression.UserNameValue,
2828
MulticlassLogisticRegression.LoadNameValue,
@@ -38,14 +38,14 @@ namespace Microsoft.ML.Learners
3838
{
3939
/// <include file = 'doc.xml' path='doc/members/member[@name="LBFGS"]/*' />
4040
/// <include file = 'doc.xml' path='docs/members/example[@name="LogisticRegressionClassifier"]/*' />
41-
public sealed class MulticlassLogisticRegression : LbfgsTrainerBase<MulticlassLogisticRegression.Arguments,
41+
public sealed class MulticlassLogisticRegression : LbfgsTrainerBase<MulticlassLogisticRegression.Options,
4242
MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
4343
{
4444
public const string LoadNameValue = "MultiClassLogisticRegression";
4545
internal const string UserNameValue = "Multi-class Logistic Regression";
4646
internal const string ShortName = "mlr";
4747

48-
public sealed class Arguments : ArgumentsBase
48+
public sealed class Options : ArgumentsBase
4949
{
5050
[Argument(ArgumentType.AtMostOnce, HelpText = "Show statistics of training examples.", ShortName = "stat", SortOrder = 50)]
5151
public bool ShowTrainingStats = false;
@@ -82,19 +82,16 @@ public sealed class Arguments : ArgumentsBase
8282
/// <param name="l2Weight">Weight of L2 regularizer term.</param>
8383
/// <param name="memorySize">Memory size for <see cref="LogisticRegression"/>. Low=faster, less accurate.</param>
8484
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
85-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
86-
public MulticlassLogisticRegression(IHostEnvironment env,
85+
internal MulticlassLogisticRegression(IHostEnvironment env,
8786
string labelColumn = DefaultColumnNames.Label,
8887
string featureColumn = DefaultColumnNames.Features,
8988
string weights = null,
90-
float l1Weight = Arguments.Defaults.L1Weight,
91-
float l2Weight = Arguments.Defaults.L2Weight,
92-
float optimizationTolerance = Arguments.Defaults.OptTol,
93-
int memorySize = Arguments.Defaults.MemorySize,
94-
bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
95-
Action<Arguments> advancedSettings = null)
96-
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weights, advancedSettings,
97-
l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
89+
float l1Weight = Options.Defaults.L1Weight,
90+
float l2Weight = Options.Defaults.L2Weight,
91+
float optimizationTolerance = Options.Defaults.OptTol,
92+
int memorySize = Options.Defaults.MemorySize,
93+
bool enforceNoNegativity = Options.Defaults.EnforceNonNegativity)
94+
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
9895
{
9996
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
10097
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
@@ -105,8 +102,8 @@ public MulticlassLogisticRegression(IHostEnvironment env,
105102
/// <summary>
106103
/// Initializes a new instance of <see cref="MulticlassLogisticRegression"/>
107104
/// </summary>
108-
internal MulticlassLogisticRegression(IHostEnvironment env, Arguments args)
109-
: base(env, args, TrainerUtils.MakeU4ScalarColumn(args.LabelColumn))
105+
internal MulticlassLogisticRegression(IHostEnvironment env, Options options)
106+
: base(env, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumn))
110107
{
111108
ShowTrainingStats = Args.ShowTrainingStats;
112109
}
@@ -1007,14 +1004,14 @@ public partial class LogisticRegression
10071004
ShortName = MulticlassLogisticRegression.ShortName,
10081005
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/member[@name=""LBFGS""]/*' />",
10091006
@"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/example[@name=""LogisticRegressionClassifier""]/*' />" })]
1010-
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, MulticlassLogisticRegression.Arguments input)
1007+
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, MulticlassLogisticRegression.Options input)
10111008
{
10121009
Contracts.CheckValue(env, nameof(env));
10131010
var host = env.Register("TrainLRMultiClass");
10141011
host.CheckValue(input, nameof(input));
10151012
EntryPointUtils.CheckInputArgs(host, input);
10161013

1017-
return LearnerEntryPointsUtils.Train<MulticlassLogisticRegression.Arguments, CommonOutputs.MulticlassClassificationOutput>(host, input,
1014+
return LearnerEntryPointsUtils.Train<MulticlassLogisticRegression.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
10181015
() => new MulticlassLogisticRegression(host, input),
10191016
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
10201017
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));

src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ private TScalarTrainer CreateTrainer()
8787
{
8888
return Args.PredictorType != null ?
8989
Args.PredictorType.CreateComponent(Host) :
90-
new LinearSvmTrainer(Host, new LinearSvmTrainer.Arguments());
90+
new LinearSvmTrainer(Host, new LinearSvmTrainer.Options());
9191
}
9292

9393
private protected IDataView MapLabelsCore<T>(ColumnType type, InPredicate<T> equalsTarget, RoleMappedData data)

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs

+15-17
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
using Microsoft.ML.Trainers.Online;
1616
using Microsoft.ML.Training;
1717

18-
[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments),
18+
[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Options),
1919
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
2020
AveragedPerceptronTrainer.UserNameValue,
2121
AveragedPerceptronTrainer.LoadNameValue, "avgper", AveragedPerceptronTrainer.ShortName)]
@@ -37,9 +37,9 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPred
3737
internal const string ShortName = "ap";
3838
internal const string Summary = "Averaged Perceptron Binary Classifier.";
3939

40-
private readonly Arguments _args;
40+
private readonly Options _args;
4141

42-
public sealed class Arguments : AveragedLinearArguments
42+
public sealed class Options : AveragedLinearArguments
4343
{
4444
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
4545
public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments();
@@ -83,10 +83,10 @@ public override LinearBinaryModelParameters CreatePredictor()
8383
}
8484
}
8585

86-
internal AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
87-
: base(args, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
86+
internal AveragedPerceptronTrainer(IHostEnvironment env, Options options)
87+
: base(options, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
8888
{
89-
_args = args;
89+
_args = options;
9090
LossFunction = _args.LossFunction.CreateComponent(env);
9191
}
9292

@@ -103,18 +103,16 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
103103
/// <param name="decreaseLearningRate">Wheather to decrease learning rate as iterations progress.</param>
104104
/// <param name="l2RegularizerWeight">L2 Regularization Weight.</param>
105105
/// <param name="numIterations">The number of training iteraitons.</param>
106-
/// <param name="advancedSettings">A delegate to supply more advanced arguments to the algorithm.</param>
107-
public AveragedPerceptronTrainer(IHostEnvironment env,
106+
internal AveragedPerceptronTrainer(IHostEnvironment env,
108107
string labelColumn = DefaultColumnNames.Label,
109108
string featureColumn = DefaultColumnNames.Features,
110109
string weights = null,
111110
IClassificationLoss lossFunction = null,
112-
float learningRate = Arguments.AveragedDefaultArgs.LearningRate,
113-
bool decreaseLearningRate = Arguments.AveragedDefaultArgs.DecreaseLearningRate,
114-
float l2RegularizerWeight = Arguments.AveragedDefaultArgs.L2RegularizerWeight,
115-
int numIterations = Arguments.AveragedDefaultArgs.NumIterations,
116-
Action<Arguments> advancedSettings = null)
117-
: this(env, InvokeAdvanced(advancedSettings, new Arguments
111+
float learningRate = Options.AveragedDefaultArgs.LearningRate,
112+
bool decreaseLearningRate = Options.AveragedDefaultArgs.DecreaseLearningRate,
113+
float l2RegularizerWeight = Options.AveragedDefaultArgs.L2RegularizerWeight,
114+
int numIterations = Options.AveragedDefaultArgs.NumIterations)
115+
: this(env, new Options
118116
{
119117
LabelColumn = labelColumn,
120118
FeatureColumn = featureColumn,
@@ -124,7 +122,7 @@ public AveragedPerceptronTrainer(IHostEnvironment env,
124122
L2RegularizerWeight = l2RegularizerWeight,
125123
NumIterations = numIterations,
126124
LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss())
127-
}))
125+
})
128126
{
129127
}
130128

@@ -191,14 +189,14 @@ public BinaryPredictionTransformer<LinearBinaryModelParameters> Train(IDataView
191189
ShortName = ShortName,
192190
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/Online/doc.xml' path='doc/members/member[@name=""AP""]/*' />",
193191
@"<include file='../Microsoft.ML.StandardLearners/Standard/Online/doc.xml' path='doc/members/example[@name=""AP""]/*' />"})]
194-
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input)
192+
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
195193
{
196194
Contracts.CheckValue(env, nameof(env));
197195
var host = env.Register("TrainAP");
198196
host.CheckValue(input, nameof(input));
199197
EntryPointUtils.CheckInputArgs(host, input);
200198

201-
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
199+
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
202200
() => new AveragedPerceptronTrainer(host, input),
203201
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
204202
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);

0 commit comments

Comments
 (0)