
Commit 2a77e5e

StochasticGradientDescentClassificationTrainer
1 parent 60eb9d5 commit 2a77e5e

7 files changed, +89 -38 lines changed


src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs (+22 -18)

@@ -31,7 +31,7 @@
     "lc",
     "sasdca")]

-[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Arguments),
+[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Options),
     new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
     StochasticGradientDescentClassificationTrainer.UserNameValue,
     StochasticGradientDescentClassificationTrainer.LoadNameValue,
@@ -69,6 +69,13 @@ private protected LinearTrainerBase(IHostEnvironment env, string featureColumn,
         {
         }

+        private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
+            SchemaShape.Column weightColumn)
+            : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn),
+                  labelColumn, weightColumn)
+        {
+        }
+
         private protected override TModel TrainModelCore(TrainContext context)
         {
             Host.CheckValue(context, nameof(context));
@@ -1595,7 +1602,7 @@ public sealed class StochasticGradientDescentClassificationTrainer :
         internal const string UserNameValue = "Hogwild SGD (binary)";
         internal const string ShortName = "HogwildSGD";

-        public sealed class Arguments : LearnerInputBaseWithWeight
+        public sealed class Options : LearnerInputBaseWithWeight
         {
             [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
             public ISupportClassificationLossFactory LossFunction = new LogLossFactory();
@@ -1670,7 +1677,7 @@ internal static class Defaults
         }

         private readonly IClassificationLoss _loss;
-        private readonly Arguments _args;
+        private readonly Options _args;

         protected override bool ShuffleData => _args.Shuffle;

@@ -1689,29 +1696,24 @@ internal static class Defaults
         /// <param name="initLearningRate">The initial learning rate used by SGD.</param>
         /// <param name="l2Weight">The L2 regularizer constant.</param>
         /// <param name="loss">The loss function to use.</param>
-        /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
-        public StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
+        internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
             string labelColumn = DefaultColumnNames.Label,
             string featureColumn = DefaultColumnNames.Features,
             string weightColumn = null,
-            int maxIterations = Arguments.Defaults.MaxIterations,
-            double initLearningRate = Arguments.Defaults.InitLearningRate,
-            float l2Weight = Arguments.Defaults.L2Weight,
-            ISupportClassificationLossFactory loss = null,
-            Action<Arguments> advancedSettings = null)
+            int maxIterations = Options.Defaults.MaxIterations,
+            double initLearningRate = Options.Defaults.InitLearningRate,
+            float l2Weight = Options.Defaults.L2Weight,
+            ISupportClassificationLossFactory loss = null)
             : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn)
         {
             Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
             Host.CheckNonEmpty(labelColumn, nameof(labelColumn));

-            _args = new Arguments();
+            _args = new Options();
             _args.MaxIterations = maxIterations;
             _args.InitLearningRate = initLearningRate;
             _args.L2Weight = l2Weight;

-            // Apply the advanced args, if the user supplied any.
-            advancedSettings?.Invoke(_args);
-
             _args.FeatureColumn = featureColumn;
             _args.LabelColumn = labelColumn;
             _args.WeightColumn = weightColumn;
@@ -1728,8 +1730,10 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
         /// <summary>
         /// Initializes a new instance of <see cref="StochasticGradientDescentClassificationTrainer"/>
         /// </summary>
-        internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args)
-            : base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), args.WeightColumn)
+        /// <param name="env">The environment to use.</param>
+        /// <param name="args">Advanced arguments to the algorithm.</param>
+        internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options args)
+            : base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
         {
             args.Check(env);
             _loss = args.LossFunction.CreateComponent(env);
@@ -1948,14 +1952,14 @@ private protected override void CheckLabel(RoleMappedData examples, out int weig
         }

         [TlcModule.EntryPoint(Name = "Trainers.StochasticGradientDescentBinaryClassifier", Desc = "Train an Hogwild SGD binary model.", UserName = UserNameValue, ShortName = ShortName)]
-        public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input)
+        public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
         {
             Contracts.CheckValue(env, nameof(env));
             var host = env.Register("TrainHogwildSGD");
             host.CheckValue(input, nameof(input));
             EntryPointUtils.CheckInputArgs(host, input);

-            return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
+            return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
                 () => new StochasticGradientDescentClassificationTrainer(host, input),
                 () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
                 () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn),
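
In short, the public Arguments class becomes Options in this file, and the Action<Arguments> advancedSettings delegate disappears from the convenience constructor. A minimal sketch of the replacement pattern, using only fields that appear elsewhere in this commit (MaxIterations, InitLearningRate, L2Weight, NumThreads, LossFunction); the concrete values are illustrative:

    // Old pattern, removed by this commit: mutate the settings inside a delegate.
    //     advancedSettings: args => { args.NumThreads = 1; }

    // New pattern: populate the Options object up front and hand it to the trainer.
    var options = new StochasticGradientDescentClassificationTrainer.Options
    {
        MaxIterations = 20,                  // illustrative values
        InitLearningRate = 0.01,
        L2Weight = 1e-6f,
        NumThreads = 1,                      // the same setting the updated tests use
        LossFunction = new LogLossFactory()
    };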

src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs (+20 -8)

@@ -13,7 +13,7 @@
 namespace Microsoft.ML
 {
     using LRArguments = LogisticRegression.Arguments;
-    using SgdArguments = StochasticGradientDescentClassificationTrainer.Arguments;
+    using SgdOptions = StochasticGradientDescentClassificationTrainer.Options;

     /// <summary>
     /// TrainerEstimator extension methods.
@@ -31,20 +31,32 @@ public static class StandardLearnersCatalog
         /// <param name="initLearningRate">The initial learning rate used by SGD.</param>
         /// <param name="l2Weight">The L2 regularization constant.</param>
         /// <param name="loss">The loss function to use.</param>
-        /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
         public static StochasticGradientDescentClassificationTrainer StochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
             string labelColumn = DefaultColumnNames.Label,
             string featureColumn = DefaultColumnNames.Features,
             string weights = null,
-            int maxIterations = SgdArguments.Defaults.MaxIterations,
-            double initLearningRate = SgdArguments.Defaults.InitLearningRate,
-            float l2Weight = SgdArguments.Defaults.L2Weight,
-            ISupportClassificationLossFactory loss = null,
-            Action<SgdArguments> advancedSettings = null)
+            int maxIterations = SgdOptions.Defaults.MaxIterations,
+            double initLearningRate = SgdOptions.Defaults.InitLearningRate,
+            float l2Weight = SgdOptions.Defaults.L2Weight,
+            ISupportClassificationLossFactory loss = null)
+        {
+            Contracts.CheckValue(ctx, nameof(ctx));
+            var env = CatalogUtils.GetEnvironment(ctx);
+            return new StochasticGradientDescentClassificationTrainer(env, labelColumn, featureColumn, weights, maxIterations, initLearningRate, l2Weight, loss);
+        }
+
+        /// <summary>
+        /// Predict a target using a linear binary classification model trained with the <see cref="StochasticGradientDescentClassificationTrainer"/> trainer.
+        /// </summary>
+        /// <param name="ctx">The binary classificaiton context trainer object.</param>
+        /// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
+        public static StochasticGradientDescentClassificationTrainer StochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
+            SgdOptions advancedSettings)
         {
             Contracts.CheckValue(ctx, nameof(ctx));
             var env = CatalogUtils.GetEnvironment(ctx);
-            return new StochasticGradientDescentClassificationTrainer(env, labelColumn, featureColumn, weights, maxIterations, initLearningRate, l2Weight, loss, advancedSettings);
+
+            return new StochasticGradientDescentClassificationTrainer(env, advancedSettings);
         }

         /// <summary>
src/Microsoft.ML.StaticPipe/SgdStatic.cs (+40 -7)

@@ -9,7 +9,7 @@

 namespace Microsoft.ML.StaticPipe
 {
-    using Arguments = StochasticGradientDescentClassificationTrainer.Arguments;
+    using Options = StochasticGradientDescentClassificationTrainer.Options;

     /// <summary>
     /// Binary Classification trainer estimators.
@@ -27,7 +27,6 @@ public static class SgdStaticExtensions
         /// <param name="initLearningRate">The initial learning rate used by SGD.</param>
         /// <param name="l2Weight">The L2 regularization constant.</param>
         /// <param name="loss">The loss function to use.</param>
-        /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
         /// <param name="onFit">A delegate that is called every time the
         /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
         /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
@@ -38,17 +37,51 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
             Scalar<bool> label,
             Vector<float> features,
             Scalar<float> weights = null,
-            int maxIterations = Arguments.Defaults.MaxIterations,
-            double initLearningRate = Arguments.Defaults.InitLearningRate,
-            float l2Weight = Arguments.Defaults.L2Weight,
+            int maxIterations = Options.Defaults.MaxIterations,
+            double initLearningRate = Options.Defaults.InitLearningRate,
+            float l2Weight = Options.Defaults.L2Weight,
             ISupportClassificationLossFactory loss = null,
-            Action<Arguments> advancedSettings = null,
             Action<IPredictorWithFeatureWeights<float>> onFit = null)
         {
             var rec = new TrainerEstimatorReconciler.BinaryClassifier(
                 (env, labelName, featuresName, weightsName) =>
                 {
-                    var trainer = new StochasticGradientDescentClassificationTrainer(env, labelName, featuresName, weightsName, maxIterations, initLearningRate, l2Weight, loss, advancedSettings);
+                    var trainer = new StochasticGradientDescentClassificationTrainer(env, labelName, featuresName, weightsName, maxIterations, initLearningRate, l2Weight, loss);
+
+                    if (onFit != null)
+                        return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+                    return trainer;
+
+                }, label, features, weights);
+
+            return rec.Output;
+        }
+
+        /// <summary>
+        /// Predict a target using a linear binary classification model trained with the <see cref="Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer"/> trainer.
+        /// </summary>
+        /// <param name="ctx">The binary classificaiton context trainer object.</param>
+        /// <param name="label">The name of the label column.</param>
+        /// <param name="features">The name of the feature column.</param>
+        /// <param name="weights">The name for the example weight column.</param>
+        /// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
+        /// <param name="onFit">A delegate that is called every time the
+        /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
+        /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
+        /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to
+        /// be informed about what was learnt.</param>
+        /// <returns>The predicted output.</returns>
+        public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) StochasticGradientDescentClassificationTrainer(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
+            Scalar<bool> label,
+            Vector<float> features,
+            Scalar<float> weights,
+            Options advancedSettings,
+            Action<IPredictorWithFeatureWeights<float>> onFit = null)
+        {
+            var rec = new TrainerEstimatorReconciler.BinaryClassifier(
+                (env, labelName, featuresName, weightsName) =>
+                {
+                    var trainer = new StochasticGradientDescentClassificationTrainer(env, advancedSettings);

                     if (onFit != null)
                         return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
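
A sketch of calling the new static-pipeline overload, shaped after the HogwildSGDBinaryClassification test updated below; reader, ctx, and pred are assumed to come from the surrounding test setup, and Options refers to the alias introduced in this file:

    var est = reader.MakeNewEstimator()
        .Append(r => (r.label, preds: ctx.Trainers.StochasticGradientDescentClassificationTrainer(
            r.label, r.features,
            weights: null,
            advancedSettings: new Options { NumThreads = 1 },
            onFit: (p) => { pred = p; })));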

test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv (+1 -1)

@@ -68,7 +68,7 @@ Trainers.PoissonRegressor Train an Poisson regression model. Microsoft.ML.Traine
 Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary model. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.SdcaBinaryTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
 Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMultiClass Microsoft.ML.Trainers.SdcaMultiClassTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
 Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Trainers.Sdca TrainRegression Microsoft.ML.Trainers.SdcaRegressionTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
-Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
+Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
 Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Trainers.SymSgd.SymSgdClassificationTrainer TrainSymSgd Microsoft.ML.Trainers.SymSgd.SymSgdClassificationTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
 Transforms.ApproximateBootstrapSampler Approximate bootstrap sampling. Microsoft.ML.Transforms.BootstrapSample GetSample Microsoft.ML.Transforms.BootstrapSamplingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
 Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class. Microsoft.ML.EntryPoints.ScoreModel RenameBinaryPredictionScoreColumns Microsoft.ML.EntryPoints.ScoreModel+RenameBinaryPredictionScoreColumnsInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput

test/Microsoft.ML.StaticPipelineTesting/Training.cs (+1 -2)

@@ -940,8 +940,7 @@ public void HogwildSGDBinaryClassification()
             var est = reader.MakeNewEstimator()
                 .Append(r => (r.label, preds: ctx.Trainers.StochasticGradientDescentClassificationTrainer(r.label, r.features,
                     l2Weight: 0,
-                    onFit: (p) => { pred = p; },
-                    advancedSettings: s => s.NumThreads = 1)));
+                    onFit: (p) => { pred = p; })));

             var pipe = reader.Append(est);
test/Microsoft.ML.Tests/FeatureContributionTests.cs (+4 -1)

@@ -9,6 +9,7 @@
 using Microsoft.ML.Internal.Internallearn;
 using Microsoft.ML.Internal.Utilities;
 using Microsoft.ML.RunTests;
+using Microsoft.ML.Trainers;
 using Microsoft.ML.Training;
 using Microsoft.ML.Transforms;
 using Xunit;
@@ -152,7 +153,9 @@ public void TestSDCABinary()
         [Fact]
         public void TestSGDBinary()
         {
-            TestFeatureContribution(ML.BinaryClassification.Trainers.StochasticGradientDescent(advancedSettings: args => { args.NumThreads = 1; }), GetSparseDataset(TaskType.BinaryClassification, 100), "SGDBinary");
+            TestFeatureContribution(ML.BinaryClassification.Trainers.StochasticGradientDescent(
+                new StochasticGradientDescentClassificationTrainer.Options { NumThreads = 1}),
+                GetSparseDataset(TaskType.BinaryClassification, 100), "SGDBinary");
         }

         [Fact]

test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs (+1 -1)

@@ -86,7 +86,7 @@ public void KMeansEstimator()
         public void TestEstimatorHogwildSGD()
         {
             (IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline();
-            var trainer = new StochasticGradientDescentClassificationTrainer(Env, "Label", "Features");
+            var trainer = ML.BinaryClassification.Trainers.StochasticGradientDescent();
             var pipeWithTrainer = pipe.Append(trainer);
             TestEstimatorCore(pipeWithTrainer, dataView);