-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Modify API for advanced settings. (SDCA) #2093
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
1a58429
60eb9d5
2a77e5e
f0b9565
192cfeb
0f9854e
fd1cdf6
72c7d81
86b7a35
0e10e0b
5e4377b
feaaa65
5fca29e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,15 +23,15 @@ | |
using Microsoft.ML.Training; | ||
using Microsoft.ML.Transforms; | ||
|
||
[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Arguments), | ||
[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Options), | ||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, | ||
SdcaBinaryTrainer.UserNameValue, | ||
SdcaBinaryTrainer.LoadNameValue, | ||
"LinearClassifier", | ||
"lc", | ||
"sasdca")] | ||
|
||
[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Arguments), | ||
[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Options), | ||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, | ||
StochasticGradientDescentClassificationTrainer.UserNameValue, | ||
StochasticGradientDescentClassificationTrainer.LoadNameValue, | ||
|
@@ -69,6 +69,13 @@ private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, | |
{ | ||
} | ||
|
||
private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, | ||
SchemaShape.Column weightColumn) | ||
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), | ||
labelColumn, weightColumn) | ||
{ | ||
} | ||
|
||
private protected override TModel TrainModelCore(TrainContext context) | ||
{ | ||
Host.CheckValue(context, nameof(context)); | ||
|
@@ -246,21 +253,19 @@ protected enum MetricKind | |
|
||
private const string RegisterName = nameof(SdcaTrainerBase<TArgs, TTransformer, TModel>); | ||
|
||
private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, Action<TArgs> advancedSettings = null) | ||
private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn) | ||
{ | ||
var args = new TArgs(); | ||
|
||
// Apply the advanced args, if the user supplied any. | ||
advancedSettings?.Invoke(args); | ||
args.FeatureColumn = featureColumn; | ||
args.LabelColumn = labelColumn.Name; | ||
return args; | ||
} | ||
|
||
internal SdcaTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
delete #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
SchemaShape.Column weight = default, Action<TArgs> advancedSettings = null, float? l2Const = null, | ||
SchemaShape.Column weight = default, float? l2Const = null, | ||
float? l1Threshold = null, int? maxIterations = null) | ||
: this(env, ArgsInit(featureColumn, labelColumn, advancedSettings), labelColumn, weight, l2Const, l1Threshold, maxIterations) | ||
: this(env, ArgsInit(featureColumn, labelColumn), labelColumn, weight, l2Const, l1Threshold, maxIterations) | ||
{ | ||
} | ||
|
||
|
@@ -1391,12 +1396,13 @@ public void Add(Double summand) | |
} | ||
} | ||
|
||
public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Arguments, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor> | ||
public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Options, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor> | ||
{ | ||
public const string LoadNameValue = "SDCA"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
All those strings should be internal. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a separate issue for it that I can link to this comment? I would like to keep the PRs for public API focused on API rather than fixing issues across the codebase :) In reply to: 248771027 [](ancestors = 248771027) |
||
|
||
internal const string UserNameValue = "Fast Linear (SA-SDCA)"; | ||
|
||
public sealed class Arguments : ArgumentsBase | ||
public sealed class Options : ArgumentsBase | ||
{ | ||
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] | ||
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); | ||
|
@@ -1441,21 +1447,16 @@ internal override void Check(IHostEnvironment env) | |
/// <param name="l2Const">The L2 regularization hyperparameter.</param> | ||
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param> | ||
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param> | ||
/// <param name="advancedSettings">A delegate to set more settings. | ||
/// The settings here will override the ones provided in the direct method signature, | ||
/// if both are present and have different values. | ||
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param> | ||
public SdcaBinaryTrainer(IHostEnvironment env, | ||
internal SdcaBinaryTrainer(IHostEnvironment env, | ||
string labelColumn = DefaultColumnNames.Label, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string weightColumn = null, | ||
ISupportSdcaClassificationLoss loss = null, | ||
float? l2Const = null, | ||
float? l1Threshold = null, | ||
int? maxIterations = null, | ||
Action<Arguments> advancedSettings = null) | ||
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, | ||
l2Const, l1Threshold, maxIterations) | ||
int? maxIterations = null) | ||
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), | ||
l2Const, l1Threshold, maxIterations) | ||
{ | ||
Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
|
@@ -1495,11 +1496,11 @@ public SdcaBinaryTrainer(IHostEnvironment env, | |
_outputColumns = outCols.ToArray(); | ||
} | ||
|
||
internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args, | ||
internal SdcaBinaryTrainer(IHostEnvironment env, Options options, | ||
string featureColumn, string labelColumn, string weightColumn = null) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Are those still needed? Can they be passed to options? The other ctor below can go away then. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
: base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) | ||
: base(env, options, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) | ||
{ | ||
_loss = args.LossFunction.CreateComponent(env); | ||
_loss = options.LossFunction.CreateComponent(env); | ||
Loss = _loss; | ||
Info = new TrainerInfo(calibration: !(_loss is LogLoss)); | ||
_positiveInstanceWeight = Args.PositiveInstanceWeight; | ||
|
@@ -1536,8 +1537,8 @@ internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args, | |
|
||
} | ||
|
||
public SdcaBinaryTrainer(IHostEnvironment env, Arguments args) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think this is needed for the SignatureTrainer. It doesn't have to be public. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The constructor above has this signature, so it's not really going away. In reply to: 248771919 [](ancestors = 248771919) |
||
: this(env, args, args.FeatureColumn, args.LabelColumn) | ||
internal SdcaBinaryTrainer(IHostEnvironment env, Options options) | ||
: this(env, options, options.FeatureColumn, options.LabelColumn) | ||
{ | ||
} | ||
|
||
|
@@ -1594,7 +1595,7 @@ public sealed class StochasticGradientDescentClassificationTrainer : | |
internal const string UserNameValue = "Hogwild SGD (binary)"; | ||
internal const string ShortName = "HogwildSGD"; | ||
|
||
public sealed class Arguments : LearnerInputBaseWithWeight | ||
public sealed class Options : LearnerInputBaseWithWeight | ||
{ | ||
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] | ||
public ISupportClassificationLossFactory LossFunction = new LogLossFactory(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
XML docs #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
@@ -1669,7 +1670,7 @@ internal static class Defaults | |
} | ||
|
||
private readonly IClassificationLoss _loss; | ||
private readonly Arguments _args; | ||
private readonly Options _args; | ||
|
||
protected override bool ShuffleData => _args.Shuffle; | ||
|
||
|
@@ -1688,29 +1689,24 @@ internal static class Defaults | |
/// <param name="initLearningRate">The initial learning rate used by SGD.</param> | ||
/// <param name="l2Weight">The L2 regularizer constant.</param> | ||
/// <param name="loss">The loss function to use.</param> | ||
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param> | ||
public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, | ||
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'd remove all ctors but the one with the (IHostEnvironment env, Arguments) signature, unless they are needed for inheritance. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
string labelColumn = DefaultColumnNames.Label, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string weightColumn = null, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this be set to Defaults.WeightColumn? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
int maxIterations = Arguments.Defaults.MaxIterations, | ||
double initLearningRate = Arguments.Defaults.InitLearningRate, | ||
float l2Weight = Arguments.Defaults.L2Weight, | ||
ISupportClassificationLossFactory loss = null, | ||
Action<Arguments> advancedSettings = null) | ||
int maxIterations = Options.Defaults.MaxIterations, | ||
double initLearningRate = Options.Defaults.InitLearningRate, | ||
float l2Weight = Options.Defaults.L2Weight, | ||
ISupportClassificationLossFactory loss = null) | ||
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn) | ||
{ | ||
Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
|
||
_args = new Arguments(); | ||
_args = new Options(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Could you make this _options? So that it is consistent with the general renaming? #Resolved |
||
_args.MaxIterations = maxIterations; | ||
_args.InitLearningRate = initLearningRate; | ||
_args.L2Weight = l2Weight; | ||
|
||
// Apply the advanced args, if the user supplied any. | ||
advancedSettings?.Invoke(_args); | ||
|
||
_args.FeatureColumn = featureColumn; | ||
_args.LabelColumn = labelColumn; | ||
_args.WeightColumn = weightColumn; | ||
|
@@ -1727,14 +1723,16 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, | |
/// <summary> | ||
/// Initializes a new instance of <see cref="StochasticGradientDescentClassificationTrainer"/> | ||
/// </summary> | ||
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args) | ||
: base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), args.WeightColumn) | ||
/// <param name="env">The environment to use.</param> | ||
/// <param name="options">Advanced arguments to the algorithm.</param> | ||
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options options) | ||
: base(env, options.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn, options.WeightColumn.IsExplicit)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think you can replace this with: options.WeightColumn.IsExplicit ? options.WeightColumn : null That way you won't need the other constructor for the base. This is a strange problem from wanting to keep the optional for the weightColumn which is used in Maml. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be a pretty simple fix, which reduces the number of constructors in the base class, I would check if it works :) In reply to: 248104112 [](ancestors = 248104112,248102195) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modified. Thanks for debugging ! In reply to: 248104779 [](ancestors = 248104779,248104112,248102195) |
||
{ | ||
args.Check(env); | ||
_loss = args.LossFunction.CreateComponent(env); | ||
options.Check(env); | ||
_loss = options.LossFunction.CreateComponent(env); | ||
Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true); | ||
NeedShuffle = args.Shuffle; | ||
_args = args; | ||
NeedShuffle = options.Shuffle; | ||
_args = options; | ||
} | ||
|
||
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) | ||
|
@@ -1947,14 +1945,14 @@ private protected override void CheckLabel(RoleMappedData examples, out int weig | |
} | ||
|
||
[TlcModule.EntryPoint(Name = "Trainers.StochasticGradientDescentBinaryClassifier", Desc = "Train an Hogwild SGD binary model.", UserName = UserNameValue, ShortName = ShortName)] | ||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input) | ||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainHogwildSGD"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
() => new StochasticGradientDescentClassificationTrainer(host, input), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn), | ||
|
@@ -1974,14 +1972,14 @@ public static partial class Sdca | |
ShortName = SdcaBinaryTrainer.LoadNameValue, | ||
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />", | ||
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentBinaryClassifier""]/*'/>" })] | ||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Arguments input) | ||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Options input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainSDCA"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
() => new SdcaBinaryTrainer(host, input), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), | ||
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you need this? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this to account for the corresponding changes in lines #1728-1730.
I saw some tests fail with a "Weights" column not found error. The fix was to specify
options.WeightColumn.IsExplicit
in the call below: TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn, options.WeightColumn.IsExplicit)
In reply to: 248006972 [](ancestors = 248006972)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got rid of this
In reply to: 248058147 [](ancestors = 248058147,248006972)