Skip to content

Commit 0b9ca00

Browse files
authored
More pigstensions (#1084)
* AP extensions * LBFGS-derived classes take more arguments in their public ctors * adding pigstensions for LR, multiclass LR, Poisson * OGD static extensions. * namespace change for pigstensions
1 parent f10212c commit 0b9ca00

File tree

18 files changed

+781
-104
lines changed

18 files changed

+781
-104
lines changed

src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public Arguments()
5151
BasePredictors = new[]
5252
{
5353
ComponentFactoryUtils.CreateFromFunction(
54-
env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments()))
54+
env => new OnlineGradientDescentTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features))
5555
};
5656
}
5757
}

src/Microsoft.ML.KMeansClustering/KMeansStatic.cs

+17-11
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Runtime;
56
using Microsoft.ML.Runtime.Data;
67
using Microsoft.ML.Runtime.KMeans;
7-
using Microsoft.ML.StaticPipe;
88
using Microsoft.ML.StaticPipe.Runtime;
99
using System;
1010

11-
namespace Microsoft.ML.Trainers
11+
namespace Microsoft.ML.StaticPipe
1212
{
1313
/// <summary>
1414
/// The trainer context extensions for the <see cref="KMeansPlusPlusTrainer"/>.
@@ -35,16 +35,22 @@ public static (Vector<float> score, Key<uint> predictedLabel) KMeans(this Cluste
3535
Action<KMeansPlusPlusTrainer.Arguments> advancedSettings = null,
3636
Action<KMeansPredictor> onFit = null)
3737
{
38-
var rec = new TrainerEstimatorReconciler.Clustering(
39-
(env, featuresName, weightsName) =>
40-
{
41-
var trainer = new KMeansPlusPlusTrainer(env, featuresName, clustersCount, weightsName, advancedSettings);
38+
Contracts.CheckValue(features, nameof(features));
39+
Contracts.CheckValueOrNull(weights);
40+
Contracts.CheckParam(clustersCount > 1, nameof(clustersCount), "If provided, must be greater than 1.");
41+
Contracts.CheckValueOrNull(onFit);
42+
Contracts.CheckValueOrNull(advancedSettings);
4243

43-
if (onFit != null)
44-
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
45-
else
46-
return trainer;
47-
}, features, weights);
44+
var rec = new TrainerEstimatorReconciler.Clustering(
45+
(env, featuresName, weightsName) =>
46+
{
47+
var trainer = new KMeansPlusPlusTrainer(env, featuresName, clustersCount, weightsName, advancedSettings);
48+
49+
if (onFit != null)
50+
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
51+
else
52+
return trainer;
53+
}, features, weights);
4854

4955
return rec.Output;
5056
}

src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1359,7 +1359,7 @@ public void Add(Double summand)
13591359
public sealed class LinearClassificationTrainer : SdcaTrainerBase<BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
13601360
{
13611361
public const string LoadNameValue = "SDCA";
1362-
public const string UserNameValue = "Fast Linear (SA-SDCA)";
1362+
internal const string UserNameValue = "Fast Linear (SA-SDCA)";
13631363

13641364
public sealed class Arguments : ArgumentsBase
13651365
{

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs

+49-23
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,29 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
2626
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2", SortOrder = 50)]
2727
[TGUI(Label = "L2 Weight", Description = "Weight of L2 regularizer term", SuggestedSweeps = "0,0.1,1")]
2828
[TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)]
29-
public float L2Weight = 1;
29+
public float L2Weight = Defaults.L2Weight;
3030

3131
[Argument(ArgumentType.AtMostOnce, HelpText = "L1 regularization weight", ShortName = "l1", SortOrder = 50)]
3232
[TGUI(Label = "L1 Weight", Description = "Weight of L1 regularizer term", SuggestedSweeps = "0,0.1,1")]
3333
[TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)]
34-
public float L1Weight = 1;
34+
public float L1Weight = Defaults.L1Weight;
3535

3636
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for optimization convergence. Lower = slower, more accurate",
3737
ShortName = "ot", SortOrder = 50)]
3838
[TGUI(Label = "Optimization Tolerance", Description = "Threshold for optimizer convergence", SuggestedSweeps = "1e-4,1e-7")]
3939
[TlcModule.SweepableDiscreteParamAttribute(new object[] { 1e-4f, 1e-7f })]
40-
public float OptTol = 1e-7f;
40+
public float OptTol = Defaults.OptTol;
4141

4242
[Argument(ArgumentType.AtMostOnce, HelpText = "Memory size for L-BFGS. Lower=faster, less accurate",
4343
ShortName = "m", SortOrder = 50)]
4444
[TGUI(Description = "Memory size for L-BFGS", SuggestedSweeps = "5,20,50")]
4545
[TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[] { 5, 20, 50 })]
46-
public int MemorySize = 20;
46+
public int MemorySize = Defaults.MemorySize;
4747

4848
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum iterations.", ShortName = "maxiter")]
4949
[TGUI(Label = "Max Number of Iterations")]
5050
[TlcModule.SweepableLongParamAttribute("MaxIterations", 1, int.MaxValue)]
51-
public int MaxIterations = int.MaxValue;
51+
public int MaxIterations = Defaults.MaxIterations;
5252

5353
[Argument(ArgumentType.AtMostOnce, HelpText = "Run SGD to initialize LR weights, converging to this tolerance",
5454
ShortName = "sgd")]
@@ -90,7 +90,17 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
9090
public bool DenseOptimizer = false;
9191

9292
[Argument(ArgumentType.AtMostOnce, HelpText = "Enforce non-negative weights", ShortName = "nn", SortOrder = 90)]
93-
public bool EnforceNonNegativity = false;
93+
public bool EnforceNonNegativity = Defaults.EnforceNonNegativity;
94+
95+
internal static class Defaults
96+
{
97+
internal const float L2Weight = 1;
98+
internal const float L1Weight = 1;
99+
internal const float OptTol = 1e-7f;
100+
internal const int MemorySize = 20;
101+
internal const int MaxIterations = int.MaxValue;
102+
internal const bool EnforceNonNegativity = false;
103+
}
94104
}
95105

96106
private const string RegisterName = nameof(LbfgsTrainerBase<TArgs, TTransformer, TModel>);
@@ -142,40 +152,56 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
142152
public override TrainerInfo Info => _info;
143153

144154
internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
145-
string weightColumn = null, Action<TArgs> advancedSettings = null)
146-
: this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn)
155+
string weightColumn, Action<TArgs> advancedSettings, float l1Weight,
156+
float l2Weight,
157+
float optimizationTolerance,
158+
int memorySize,
159+
bool enforceNoNegativity)
160+
: this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn,
161+
l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
147162
{
148163
}
149164

150-
internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn)
165+
internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn,
166+
float? l1Weight = null,
167+
float? l2Weight = null,
168+
float? optimizationTolerance = null,
169+
int? memorySize = null,
170+
bool? enforceNoNegativity = null)
151171
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
152172
labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
153173
{
154174
Host.CheckValue(args, nameof(args));
155175
Args = args;
156176

157-
Contracts.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
177+
Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
158178
nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
159-
Contracts.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
160-
Contracts.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
161-
Contracts.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
162-
Contracts.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
163-
Contracts.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
164-
Contracts.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
165-
Contracts.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");
166-
167-
L2Weight = Args.L2Weight;
168-
L1Weight = Args.L1Weight;
169-
OptTol = Args.OptTol;
170-
MemorySize = Args.MemorySize;
179+
Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
180+
Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
181+
Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
182+
Host.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
183+
Host.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
184+
Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
185+
Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");
186+
187+
Host.CheckParam(!(l2Weight < 0), nameof(l2Weight), "Must be non-negative, if provided.");
188+
Host.CheckParam(!(l1Weight < 0), nameof(l1Weight), "Must be non-negative, if provided");
189+
Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided.");
190+
Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided.");
191+
192+
// Review: warn when an explicitly provided constructor argument overrides the corresponding value in Args.
193+
L2Weight = l2Weight ?? Args.L2Weight;
194+
L1Weight = l1Weight ?? Args.L1Weight;
195+
OptTol = optimizationTolerance ?? Args.OptTol;
196+
MemorySize = memorySize ?? Args.MemorySize;
171197
MaxIterations = Args.MaxIterations;
172198
SgdInitializationTolerance = Args.SgdInitializationTolerance;
173199
Quiet = Args.Quiet;
174200
InitWtsDiameter = Args.InitWtsDiameter;
175201
UseThreads = Args.UseThreads;
176202
NumThreads = Args.NumThreads;
177203
DenseOptimizer = Args.DenseOptimizer;
178-
EnforceNonNegativity = Args.EnforceNonNegativity;
204+
EnforceNonNegativity = enforceNoNegativity ?? Args.EnforceNonNegativity;
179205

180206
if (EnforceNonNegativity && ShowTrainingStats)
181207
{

0 commit comments

Comments
 (0)