Skip to content

Commit 1fca6c0

Browse files
committed
Towards dotnet#1798 .
This PR addresses the estimators inside HalLearners: Two public extension methods, one for simple arguments and the other for advanced options Delete unecessary constructors Pass Options objects as arguments instead of Action delegate Rename Arguments to Options Rename Options objects as options (instead of args or advancedSettings used so far)
1 parent b89ce70 commit 1fca6c0

File tree

6 files changed

+102
-82
lines changed

6 files changed

+102
-82
lines changed

src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs

+47-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using Microsoft.ML.Data;
7+
using Microsoft.ML.EntryPoints;
78
using Microsoft.ML.Trainers.HalLearners;
89
using Microsoft.ML.Trainers.SymSgd;
910
using Microsoft.ML.Transforms.Projections;
@@ -22,16 +23,36 @@ public static class HalLearnersCatalog
2223
/// <param name="labelColumn">The labelColumn column.</param>
2324
/// <param name="featureColumn">The features column.</param>
2425
/// <param name="weights">The weights column.</param>
25-
/// <param name="advancedSettings">Algorithm advanced settings.</param>
2626
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx,
2727
string labelColumn = DefaultColumnNames.Label,
2828
string featureColumn = DefaultColumnNames.Features,
29-
string weights = null,
30-
Action<OlsLinearRegressionTrainer.Arguments> advancedSettings = null)
29+
string weights = null)
30+
{
31+
Contracts.CheckValue(ctx, nameof(ctx));
32+
var env = CatalogUtils.GetEnvironment(ctx);
33+
var options = new OlsLinearRegressionTrainer.Options
34+
{
35+
LabelColumn = labelColumn,
36+
FeatureColumn = featureColumn,
37+
WeightColumn = weights != null ? Optional<string>.Explicit(weights) : Optional<string>.Implicit(DefaultColumnNames.Weight)
38+
};
39+
40+
return new OlsLinearRegressionTrainer(env, options);
41+
}
42+
43+
/// <summary>
44+
/// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
45+
/// </summary>
46+
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
47+
/// <param name="options">Algorithm advanced options.</param>
48+
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx,
49+
OlsLinearRegressionTrainer.Options options)
3150
{
3251
Contracts.CheckValue(ctx, nameof(ctx));
52+
Contracts.CheckValue(options, nameof(options));
53+
3354
var env = CatalogUtils.GetEnvironment(ctx);
34-
return new OlsLinearRegressionTrainer(env, labelColumn, featureColumn, weights, advancedSettings);
55+
return new OlsLinearRegressionTrainer(env, options);
3556
}
3657

3758
/// <summary>
@@ -44,11 +65,31 @@ public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionCon
4465
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
4566
string labelColumn = DefaultColumnNames.Label,
4667
string featureColumn = DefaultColumnNames.Features,
47-
Action<SymSgdClassificationTrainer.Arguments> advancedSettings = null)
68+
Action<SymSgdClassificationTrainer.Options> advancedSettings = null)
69+
{
70+
Contracts.CheckValue(ctx, nameof(ctx));
71+
var env = CatalogUtils.GetEnvironment(ctx);
72+
var options = new SymSgdClassificationTrainer.Options
73+
{
74+
LabelColumn = labelColumn,
75+
FeatureColumn = featureColumn,
76+
};
77+
78+
return new SymSgdClassificationTrainer(env, options);
79+
}
80+
81+
/// <summary>
82+
/// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
83+
/// </summary>
84+
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
85+
/// <param name="options">Algorithm advanced options.</param>
86+
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
87+
SymSgdClassificationTrainer.Options options)
4888
{
4989
Contracts.CheckValue(ctx, nameof(ctx));
90+
Contracts.CheckValue(options, nameof(options));
5091
var env = CatalogUtils.GetEnvironment(ctx);
51-
return new SymSgdClassificationTrainer(env, labelColumn, featureColumn, advancedSettings);
92+
return new SymSgdClassificationTrainer(env, options);
5293
}
5394

5495
/// <summary>

src/Microsoft.ML.HalLearners/OlsLinearRegression.cs

+13-41
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
using Microsoft.ML.Trainers.HalLearners;
2020
using Microsoft.ML.Training;
2121

22-
[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Arguments),
22+
[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Options),
2323
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
2424
OlsLinearRegressionTrainer.UserNameValue,
2525
OlsLinearRegressionTrainer.LoadNameValue,
@@ -34,9 +34,10 @@
3434
namespace Microsoft.ML.Trainers.HalLearners
3535
{
3636
/// <include file='doc.xml' path='doc/members/member[@name="OLS"]/*' />
37+
[BestFriend]
3738
public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase<RegressionPredictionTransformer<OlsLinearRegressionModelParameters>, OlsLinearRegressionModelParameters>
3839
{
39-
public sealed class Arguments : LearnerInputBaseWithWeight
40+
public sealed class Options : LearnerInputBaseWithWeight
4041
{
4142
// Adding L2 regularization turns this into a form of ridge regression,
4243
// rather than, strictly speaking, ordinary least squares. But it is an
@@ -46,13 +47,16 @@ public sealed class Arguments : LearnerInputBaseWithWeight
4647
[TlcModule.SweepableDiscreteParamAttribute("L2Weight", new object[] { 1e-6f, 0.1f, 1f })]
4748
public float L2Weight = 1e-6f;
4849

50+
/// <summary>
51+
/// Whether to calculate per parameter significance statistics.
52+
/// </summary>
4953
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether to calculate per parameter significance statistics", ShortName = "sig")]
5054
public bool PerParameterSignificance = true;
5155
}
5256

53-
public const string LoadNameValue = "OLSLinearRegression";
54-
public const string UserNameValue = "Ordinary Least Squares (Regression)";
55-
public const string ShortName = "ols";
57+
internal const string LoadNameValue = "OLSLinearRegression";
58+
internal const string UserNameValue = "Ordinary Least Squares (Regression)";
59+
internal const string ShortName = "ols";
5660
internal const string Summary = "The ordinary least square regression fits the target function as a linear function of the numerical features "
5761
+ "that minimizes the square loss function.";
5862

@@ -68,24 +72,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
6872
/// <summary>
6973
/// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>
7074
/// </summary>
71-
/// <param name="env">The environment to use.</param>
72-
/// <param name="labelColumn">The name of the labelColumn column.</param>
73-
/// <param name="featureColumn">The name of the feature column.</param>
74-
/// <param name="weights">The name for the optional example weight column.</param>
75-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
76-
public OlsLinearRegressionTrainer(IHostEnvironment env,
77-
string labelColumn = DefaultColumnNames.Label,
78-
string featureColumn = DefaultColumnNames.Features,
79-
string weights = null,
80-
Action<Arguments> advancedSettings = null)
81-
: this(env, ArgsInit(featureColumn, labelColumn, weights, advancedSettings))
82-
{
83-
}
84-
85-
/// <summary>
86-
/// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>
87-
/// </summary>
88-
internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
75+
internal OlsLinearRegressionTrainer(IHostEnvironment env, Options args)
8976
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
9077
TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
9178
{
@@ -95,21 +82,6 @@ internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
9582
_perParameterSignificance = args.PerParameterSignificance;
9683
}
9784

98-
private static Arguments ArgsInit(string featureColumn,
99-
string labelColumn,
100-
string weightColumn,
101-
Action<Arguments> advancedSettings)
102-
{
103-
var args = new Arguments();
104-
105-
// Apply the advanced args, if the user supplied any.
106-
advancedSettings?.Invoke(args);
107-
args.FeatureColumn = featureColumn;
108-
args.LabelColumn = labelColumn;
109-
args.WeightColumn = weightColumn != null ? Optional<string>.Explicit(weightColumn) : Optional<string>.Implicit(DefaultColumnNames.Weight);
110-
return args;
111-
}
112-
11385
protected override RegressionPredictionTransformer<OlsLinearRegressionModelParameters> MakeTransformer(OlsLinearRegressionModelParameters model, Schema trainSchema)
11486
=> new RegressionPredictionTransformer<OlsLinearRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
11587

@@ -518,14 +490,14 @@ public static void Pptri(Layout layout, UpLo uplo, int n, Double[] ap)
518490
UserName = UserNameValue,
519491
ShortName = ShortName,
520492
XmlInclude = new[] { @"<include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name=""OLS""]/*' />" })]
521-
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Arguments input)
493+
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options input)
522494
{
523495
Contracts.CheckValue(env, nameof(env));
524496
var host = env.Register("TrainOLS");
525497
host.CheckValue(input, nameof(input));
526498
EntryPointUtils.CheckInputArgs(host, input);
527499

528-
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.RegressionOutput>(host, input,
500+
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.RegressionOutput>(host, input,
529501
() => new OlsLinearRegressionTrainer(host, input),
530502
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
531503
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
@@ -579,7 +551,7 @@ private static VersionInfo GetVersionInfo()
579551
/// are all null. A model may not have per parameter statistics because either
580552
/// there were not more examples than parameters in the model, or because they
581553
/// were explicitly suppressed in training by setting
582-
/// <see cref="OlsLinearRegressionTrainer.Arguments.PerParameterSignificance"/>
554+
/// <see cref="OlsLinearRegressionTrainer.Options.PerParameterSignificance"/>
583555
/// to false.
584556
/// </summary>
585557
public bool HasStatistics => _standardErrors != null;

src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using System.Runtime.CompilerServices;
66
using Microsoft.ML;
77

8+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
9+
810
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners.StaticPipe" + PublicKey.Value)]
911

1012
[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]

src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs

+36-31
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
using Microsoft.ML.Training;
2121
using Microsoft.ML.Transforms;
2222

23-
[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Arguments),
23+
[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Options),
2424
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
2525
SymSgdClassificationTrainer.UserNameValue,
2626
SymSgdClassificationTrainer.LoadNameValue,
@@ -33,48 +33,78 @@ namespace Microsoft.ML.Trainers.SymSgd
3333
using TPredictor = IPredictorWithFeatureWeights<float>;
3434

3535
/// <include file='doc.xml' path='doc/members/member[@name="SymSGD"]/*' />
36+
[BestFriend]
3637
public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<TPredictor>, TPredictor>
3738
{
3839
internal const string LoadNameValue = "SymbolicSGD";
3940
internal const string UserNameValue = "Symbolic SGD (binary)";
4041
internal const string ShortName = "SymSGD";
4142

42-
public sealed class Arguments : LearnerInputBaseWithLabel
43+
public sealed class Options : LearnerInputBaseWithLabel
4344
{
45+
/// <summary>
46+
/// Degree of lock-free parallelism. Determinism not guaranteed.
47+
/// Multi-threading is not supported currently.
48+
/// </summary>
4449
[Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Determinism not guaranteed. " +
4550
"Multi-threading is not supported currently.", ShortName = "nt")]
4651
public int? NumberOfThreads;
4752

53+
/// <summary>
54+
/// Number of passes over the data.
55+
/// </summary>
4856
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of passes over the data.", ShortName = "iter", SortOrder = 50)]
4957
[TGUI(SuggestedSweeps = "1,5,10,20,30,40,50")]
5058
[TlcModule.SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 })]
5159
public int NumberOfIterations = 50;
5260

61+
/// <summary>
62+
/// Tolerance for difference in average loss in consecutive passes.
63+
/// </summary>
5364
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance for difference in average loss in consecutive passes.", ShortName = "tol")]
5465
public float Tolerance = 1e-4f;
5566

67+
/// <summary>
68+
/// Learning rate.
69+
/// </summary>
5670
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", NullName = "<Auto>", SortOrder = 51)]
5771
[TGUI(SuggestedSweeps = "<Auto>,1e1,1e0,1e-1,1e-2,1e-3")]
5872
[TlcModule.SweepableDiscreteParam("LearningRate", new object[] { "<Auto>", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f })]
5973
public float? LearningRate;
6074

75+
/// <summary>
76+
/// L2 regularization.
77+
/// </summary>
6178
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization", ShortName = "l2", SortOrder = 52)]
6279
[TGUI(SuggestedSweeps = "0.0,1e-5,1e-5,1e-6,1e-7")]
6380
[TlcModule.SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f })]
6481
public float L2Regularization;
6582

83+
/// <summary>
84+
/// The number of iterations each thread learns a local model until combining it with the
85+
/// global model. Low value means more updated global model and high value means less cache traffic.
86+
/// </summary>
6687
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of iterations each thread learns a local model until combining it with the " +
6788
"global model. Low value means more updated global model and high value means less cache traffic.", ShortName = "freq", NullName = "<Auto>")]
6889
[TGUI(SuggestedSweeps = "<Auto>,5,20")]
6990
[TlcModule.SweepableDiscreteParam("UpdateFrequency", new object[] { "<Auto>", 5, 20 })]
7091
public int? UpdateFrequency;
7192

93+
/// <summary>
94+
/// The acceleration memory budget in MB.
95+
/// </summary>
7296
[Argument(ArgumentType.AtMostOnce, HelpText = "The acceleration memory budget in MB", ShortName = "accelMemBudget")]
7397
public long MemorySize = 1024;
7498

99+
/// <summary>
100+
/// Set to <see langword="true" /> causes the data to shuffle.
101+
/// </summary>
75102
[Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data?", ShortName = "shuf")]
76103
public bool Shuffle = true;
77104

105+
/// <summary>
106+
/// Apply weight to the positive class, for imbalanced data.
107+
/// </summary>
78108
[Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
79109
public float PositiveInstanceWeight = 1;
80110

@@ -88,7 +118,7 @@ public void Check(IExceptionContext ectx)
88118
}
89119

90120
public override TrainerInfo Info { get; }
91-
private readonly Arguments _args;
121+
private readonly Options _args;
92122

93123
/// <summary>
94124
/// This method ensures that the data meets the requirements of this trainer and its
@@ -152,32 +182,7 @@ private protected override TPredictor TrainModelCore(TrainContext context)
152182
/// <summary>
153183
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
154184
/// </summary>
155-
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
156-
/// <param name="labelColumn">The name of the label column.</param>
157-
/// <param name="featureColumn">The name of the feature column.</param>
158-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
159-
public SymSgdClassificationTrainer(IHostEnvironment env,
160-
string labelColumn = DefaultColumnNames.Label,
161-
string featureColumn = DefaultColumnNames.Features,
162-
Action<Arguments> advancedSettings = null)
163-
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn),
164-
TrainerUtils.MakeBoolScalarLabel(labelColumn))
165-
{
166-
_args = new Arguments();
167-
168-
// Apply the advanced args, if the user supplied any.
169-
_args.Check(Host);
170-
advancedSettings?.Invoke(_args);
171-
_args.FeatureColumn = featureColumn;
172-
_args.LabelColumn = labelColumn;
173-
174-
Info = new TrainerInfo(supportIncrementalTrain: true);
175-
}
176-
177-
/// <summary>
178-
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
179-
/// </summary>
180-
internal SymSgdClassificationTrainer(IHostEnvironment env, Arguments args)
185+
internal SymSgdClassificationTrainer(IHostEnvironment env, Options args)
181186
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
182187
TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
183188
{
@@ -218,14 +223,14 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
218223
UserName = SymSgdClassificationTrainer.UserNameValue,
219224
ShortName = SymSgdClassificationTrainer.ShortName,
220225
XmlInclude = new[] { @"<include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name=""SymSGD""]/*' />" })]
221-
public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Arguments input)
226+
public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Options input)
222227
{
223228
Contracts.CheckValue(env, nameof(env));
224229
var host = env.Register("TrainSymSGD");
225230
host.CheckValue(input, nameof(input));
226231
EntryPointUtils.CheckInputArgs(host, input);
227232

228-
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
233+
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
229234
() => new SymSgdClassificationTrainer(host, input),
230235
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
231236
}

0 commit comments

Comments
 (0)