Skip to content

Commit f0b9565

Browse files
committed
SdcaBinaryTrainer, SdcaMultiClassTrainer, SdcaRegressionTrainer
1 parent 2a77e5e commit f0b9565

31 files changed

+422
-181
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs

+8-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Linq;
33
using Microsoft.ML.Data;
4+
using Microsoft.ML.Trainers;
45

56
namespace Microsoft.ML.Samples.Dynamic
67
{
@@ -59,15 +60,13 @@ public static void SDCA_BinaryClassification()
5960
// If we wanted to specify more advanced parameters for the algorithm,
6061
// we could do so by tweaking the 'advancedSetting'.
6162
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
62-
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent
63-
(labelColumn: "Sentiment",
64-
featureColumn: "Features",
65-
advancedSettings: s=>
66-
{
67-
s.ConvergenceTolerance = 0.01f; // The learning rate for adjusting bias from being regularized
68-
s.NumThreads = 2; // Degree of lock-free parallelism
69-
})
70-
);
63+
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
64+
new SdcaBinaryTrainer.Options {
65+
LabelColumn = "Sentiment",
66+
FeatureColumn = "Features",
67+
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
68+
NumThreads = 2, // Degree of lock-free parallelism
69+
}));
7170

7271
// Run Cross-Validation on this second pipeline.
7372
var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3);

src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs

+24-31
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
using Microsoft.ML.Training;
2424
using Microsoft.ML.Transforms;
2525

26-
[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Arguments),
26+
[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Options),
2727
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
2828
SdcaBinaryTrainer.UserNameValue,
2929
SdcaBinaryTrainer.LoadNameValue,
@@ -253,21 +253,19 @@ protected enum MetricKind
253253

254254
private const string RegisterName = nameof(SdcaTrainerBase<TArgs, TTransformer, TModel>);
255255

256-
private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, Action<TArgs> advancedSettings = null)
256+
private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn)
257257
{
258258
var args = new TArgs();
259259

260-
// Apply the advanced args, if the user supplied any.
261-
advancedSettings?.Invoke(args);
262260
args.FeatureColumn = featureColumn;
263261
args.LabelColumn = labelColumn.Name;
264262
return args;
265263
}
266264

267265
internal SdcaTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
268-
SchemaShape.Column weight = default, Action<TArgs> advancedSettings = null, float? l2Const = null,
266+
SchemaShape.Column weight = default, float? l2Const = null,
269267
float? l1Threshold = null, int? maxIterations = null)
270-
: this(env, ArgsInit(featureColumn, labelColumn, advancedSettings), labelColumn, weight, l2Const, l1Threshold, maxIterations)
268+
: this(env, ArgsInit(featureColumn, labelColumn), labelColumn, weight, l2Const, l1Threshold, maxIterations)
271269
{
272270
}
273271

@@ -1398,13 +1396,13 @@ public void Add(Double summand)
13981396
}
13991397
}
14001398

1401-
public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Arguments, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
1399+
public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Options, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
14021400
{
14031401
public const string LoadNameValue = "SDCA";
14041402

14051403
internal const string UserNameValue = "Fast Linear (SA-SDCA)";
14061404

1407-
public sealed class Arguments : ArgumentsBase
1405+
public sealed class Options : ArgumentsBase
14081406
{
14091407
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
14101408
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
@@ -1449,21 +1447,16 @@ internal override void Check(IHostEnvironment env)
14491447
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
14501448
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
14511449
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
1452-
/// <param name="advancedSettings">A delegate to set more settings.
1453-
/// The settings here will override the ones provided in the direct method signature,
1454-
/// if both are present and have different values.
1455-
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
1456-
public SdcaBinaryTrainer(IHostEnvironment env,
1450+
internal SdcaBinaryTrainer(IHostEnvironment env,
14571451
string labelColumn = DefaultColumnNames.Label,
14581452
string featureColumn = DefaultColumnNames.Features,
14591453
string weightColumn = null,
14601454
ISupportSdcaClassificationLoss loss = null,
14611455
float? l2Const = null,
14621456
float? l1Threshold = null,
1463-
int? maxIterations = null,
1464-
Action<Arguments> advancedSettings = null)
1465-
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings,
1466-
l2Const, l1Threshold, maxIterations)
1457+
int? maxIterations = null)
1458+
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
1459+
l2Const, l1Threshold, maxIterations)
14671460
{
14681461
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
14691462
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
@@ -1503,11 +1496,11 @@ public SdcaBinaryTrainer(IHostEnvironment env,
15031496
_outputColumns = outCols.ToArray();
15041497
}
15051498

1506-
internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args,
1499+
internal SdcaBinaryTrainer(IHostEnvironment env, Options options,
15071500
string featureColumn, string labelColumn, string weightColumn = null)
1508-
: base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
1501+
: base(env, options, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
15091502
{
1510-
_loss = args.LossFunction.CreateComponent(env);
1503+
_loss = options.LossFunction.CreateComponent(env);
15111504
Loss = _loss;
15121505
Info = new TrainerInfo(calibration: !(_loss is LogLoss));
15131506
_positiveInstanceWeight = Args.PositiveInstanceWeight;
@@ -1544,8 +1537,8 @@ internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args,
15441537

15451538
}
15461539

1547-
public SdcaBinaryTrainer(IHostEnvironment env, Arguments args)
1548-
: this(env, args, args.FeatureColumn, args.LabelColumn)
1540+
internal SdcaBinaryTrainer(IHostEnvironment env, Options options)
1541+
: this(env, options, options.FeatureColumn, options.LabelColumn)
15491542
{
15501543
}
15511544

@@ -1731,15 +1724,15 @@ internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
17311724
/// Initializes a new instance of <see cref="StochasticGradientDescentClassificationTrainer"/>
17321725
/// </summary>
17331726
/// <param name="env">The environment to use.</param>
1734-
/// <param name="args">Advanced arguments to the algorithm.</param>
1735-
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options args)
1736-
: base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
1727+
/// <param name="options">Advanced arguments to the algorithm.</param>
1728+
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options options)
1729+
: base(env, options.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn, options.WeightColumn.IsExplicit))
17371730
{
1738-
args.Check(env);
1739-
_loss = args.LossFunction.CreateComponent(env);
1731+
options.Check(env);
1732+
_loss = options.LossFunction.CreateComponent(env);
17401733
Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true);
1741-
NeedShuffle = args.Shuffle;
1742-
_args = args;
1734+
NeedShuffle = options.Shuffle;
1735+
_args = options;
17431736
}
17441737

17451738
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
@@ -1979,14 +1972,14 @@ public static partial class Sdca
19791972
ShortName = SdcaBinaryTrainer.LoadNameValue,
19801973
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
19811974
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentBinaryClassifier""]/*'/>" })]
1982-
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Arguments input)
1975+
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Options input)
19831976
{
19841977
Contracts.CheckValue(env, nameof(env));
19851978
var host = env.Register("TrainSDCA");
19861979
host.CheckValue(input, nameof(input));
19871980
EntryPointUtils.CheckInputArgs(host, input);
19881981

1989-
return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
1982+
return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
19901983
() => new SdcaBinaryTrainer(host, input),
19911984
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
19921985
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);

src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs

+13-18
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
using Microsoft.ML.Training;
2020
using Float = System.Single;
2121

22-
[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Arguments),
22+
[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Options),
2323
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
2424
SdcaMultiClassTrainer.UserNameValue,
2525
SdcaMultiClassTrainer.LoadNameValue,
@@ -29,14 +29,14 @@ namespace Microsoft.ML.Trainers
2929
{
3030
// SDCA linear multiclass trainer.
3131
/// <include file='doc.xml' path='doc/members/member[@name="SDCA"]/*' />
32-
public class SdcaMultiClassTrainer : SdcaTrainerBase<SdcaMultiClassTrainer.Arguments, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
32+
public class SdcaMultiClassTrainer : SdcaTrainerBase<SdcaMultiClassTrainer.Options, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
3333
{
3434
public const string LoadNameValue = "SDCAMC";
3535
public const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)";
3636
public const string ShortName = "sasdcamc";
3737
internal const string Summary = "The SDCA linear multi-class classification trainer.";
3838

39-
public sealed class Arguments : ArgumentsBase
39+
public sealed class Options : ArgumentsBase
4040
{
4141
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
4242
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
@@ -57,41 +57,36 @@ public sealed class Arguments : ArgumentsBase
5757
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
5858
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
5959
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
60-
/// <param name="advancedSettings">A delegate to set more settings.
61-
/// The settings here will override the ones provided in the direct method signature,
62-
/// if both are present and have different values.
63-
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
6460
public SdcaMultiClassTrainer(IHostEnvironment env,
6561
string labelColumn = DefaultColumnNames.Label,
6662
string featureColumn = DefaultColumnNames.Features,
6763
string weights = null,
6864
ISupportSdcaClassificationLoss loss = null,
6965
float? l2Const = null,
7066
float? l1Threshold = null,
71-
int? maxIterations = null,
72-
Action<Arguments> advancedSettings = null)
73-
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), advancedSettings,
74-
l2Const, l1Threshold, maxIterations)
67+
int? maxIterations = null)
68+
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights),
69+
l2Const, l1Threshold, maxIterations)
7570
{
7671
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
7772
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
7873
_loss = loss ?? Args.LossFunction.CreateComponent(env);
7974
Loss = _loss;
8075
}
8176

82-
internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args,
77+
internal SdcaMultiClassTrainer(IHostEnvironment env, Options options,
8378
string featureColumn, string labelColumn, string weightColumn = null)
84-
: base(env, args, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
79+
: base(env, options, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
8580
{
8681
Host.CheckValue(labelColumn, nameof(labelColumn));
8782
Host.CheckValue(featureColumn, nameof(featureColumn));
8883

89-
_loss = args.LossFunction.CreateComponent(env);
84+
_loss = options.LossFunction.CreateComponent(env);
9085
Loss = _loss;
9186
}
9287

93-
internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
94-
: this(env, args, args.FeatureColumn, args.LabelColumn)
88+
internal SdcaMultiClassTrainer(IHostEnvironment env, Options options)
89+
: this(env, options, options.FeatureColumn, options.LabelColumn)
9590
{
9691
}
9792

@@ -455,14 +450,14 @@ public static partial class Sdca
455450
ShortName = SdcaMultiClassTrainer.ShortName,
456451
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
457452
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentClassifier""]/*' />" })]
458-
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments input)
453+
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Options input)
459454
{
460455
Contracts.CheckValue(env, nameof(env));
461456
var host = env.Register("TrainSDCA");
462457
host.CheckValue(input, nameof(input));
463458
EntryPointUtils.CheckInputArgs(host, input);
464459

465-
return LearnerEntryPointsUtils.Train<SdcaMultiClassTrainer.Arguments, CommonOutputs.MulticlassClassificationOutput>(host, input,
460+
return LearnerEntryPointsUtils.Train<SdcaMultiClassTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
466461
() => new SdcaMultiClassTrainer(host, input),
467462
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
468463
}

0 commit comments

Comments
 (0)