-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Modify API for advanced settings. (SDCA) #2093
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
1a58429
60eb9d5
2a77e5e
f0b9565
192cfeb
0f9854e
fd1cdf6
72c7d81
86b7a35
0e10e0b
5e4377b
feaaa65
5fca29e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,7 @@ | |
using Microsoft.ML.Training; | ||
using Float = System.Single; | ||
|
||
[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Arguments), | ||
[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Options), | ||
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, | ||
SdcaMultiClassTrainer.UserNameValue, | ||
SdcaMultiClassTrainer.LoadNameValue, | ||
|
@@ -29,14 +29,14 @@ namespace Microsoft.ML.Trainers | |
{ | ||
// SDCA linear multiclass trainer. | ||
/// <include file='doc.xml' path='doc/members/member[@name="SDCA"]/*' /> | ||
public class SdcaMultiClassTrainer : SdcaTrainerBase<SdcaMultiClassTrainer.Arguments, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters> | ||
public class SdcaMultiClassTrainer : SdcaTrainerBase<SdcaMultiClassTrainer.Options, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters> | ||
{ | ||
public const string LoadNameValue = "SDCAMC"; | ||
public const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)"; | ||
public const string ShortName = "sasdcamc"; | ||
internal const string Summary = "The SDCA linear multi-class classification trainer."; | ||
|
||
public sealed class Arguments : ArgumentsBase | ||
public sealed class Options : ArgumentsBase | ||
{ | ||
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] | ||
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); | ||
|
@@ -57,41 +57,36 @@ public sealed class Arguments : ArgumentsBase | |
/// <param name="l2Const">The L2 regularization hyperparameter.</param> | ||
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param> | ||
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param> | ||
/// <param name="advancedSettings">A delegate to set more settings. | ||
/// The settings here will override the ones provided in the direct method signature, | ||
/// if both are present and have different values. | ||
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param> | ||
public SdcaMultiClassTrainer(IHostEnvironment env, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
same #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
string labelColumn = DefaultColumnNames.Label, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string weights = null, | ||
ISupportSdcaClassificationLoss loss = null, | ||
float? l2Const = null, | ||
float? l1Threshold = null, | ||
int? maxIterations = null, | ||
Action<Arguments> advancedSettings = null) | ||
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), advancedSettings, | ||
l2Const, l1Threshold, maxIterations) | ||
int? maxIterations = null) | ||
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), | ||
l2Const, l1Threshold, maxIterations) | ||
{ | ||
Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
_loss = loss ?? Args.LossFunction.CreateComponent(env); | ||
Loss = _loss; | ||
} | ||
|
||
internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, | ||
internal SdcaMultiClassTrainer(IHostEnvironment env, Options options, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
delete, maybe #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
string featureColumn, string labelColumn, string weightColumn = null) | ||
: base(env, args, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) | ||
: base(env, options, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) | ||
{ | ||
Host.CheckValue(labelColumn, nameof(labelColumn)); | ||
Host.CheckValue(featureColumn, nameof(featureColumn)); | ||
|
||
_loss = args.LossFunction.CreateComponent(env); | ||
_loss = options.LossFunction.CreateComponent(env); | ||
Loss = _loss; | ||
} | ||
|
||
internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args) | ||
: this(env, args, args.FeatureColumn, args.LabelColumn) | ||
internal SdcaMultiClassTrainer(IHostEnvironment env, Options options) | ||
: this(env, options, options.FeatureColumn, options.LabelColumn) | ||
{ | ||
} | ||
|
||
|
@@ -455,14 +450,14 @@ public static partial class Sdca | |
ShortName = SdcaMultiClassTrainer.ShortName, | ||
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />", | ||
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentClassifier""]/*' />" })] | ||
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments input) | ||
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Options input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainSDCA"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<SdcaMultiClassTrainer.Arguments, CommonOutputs.MulticlassClassificationOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<SdcaMultiClassTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input, | ||
() => new SdcaMultiClassTrainer(host, input), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
using Microsoft.ML.Trainers; | ||
using Microsoft.ML.Training; | ||
|
||
[assembly: LoadableClass(SdcaRegressionTrainer.Summary, typeof(SdcaRegressionTrainer), typeof(SdcaRegressionTrainer.Arguments), | ||
[assembly: LoadableClass(SdcaRegressionTrainer.Summary, typeof(SdcaRegressionTrainer), typeof(SdcaRegressionTrainer.Options), | ||
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, | ||
SdcaRegressionTrainer.UserNameValue, | ||
SdcaRegressionTrainer.LoadNameValue, | ||
|
@@ -24,19 +24,19 @@ | |
namespace Microsoft.ML.Trainers | ||
{ | ||
/// <include file='doc.xml' path='doc/members/member[@name="SDCA"]/*' /> | ||
public sealed class SdcaRegressionTrainer : SdcaTrainerBase<SdcaRegressionTrainer.Arguments, RegressionPredictionTransformer<LinearRegressionModelParameters>, LinearRegressionModelParameters> | ||
public sealed class SdcaRegressionTrainer : SdcaTrainerBase<SdcaRegressionTrainer.Options, RegressionPredictionTransformer<LinearRegressionModelParameters>, LinearRegressionModelParameters> | ||
{ | ||
internal const string LoadNameValue = "SDCAR"; | ||
internal const string UserNameValue = "Fast Linear Regression (SA-SDCA)"; | ||
internal const string ShortName = "sasdcar"; | ||
internal const string Summary = "The SDCA linear regression trainer."; | ||
|
||
public sealed class Arguments : ArgumentsBase | ||
public sealed class Options : ArgumentsBase | ||
{ | ||
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] | ||
public ISupportSdcaRegressionLossFactory LossFunction = new SquaredLossFactory(); | ||
|
||
public Arguments() | ||
public Options() | ||
{ | ||
// Using a higher default tolerance for better RMS. | ||
ConvergenceTolerance = 0.01f; | ||
|
@@ -61,40 +61,35 @@ public Arguments() | |
/// <param name="l2Const">The L2 regularization hyperparameter.</param> | ||
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param> | ||
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param> | ||
/// <param name="advancedSettings">A delegate to set more settings. | ||
/// The settings here will override the ones provided in the direct method signature, | ||
/// if both are present and have different values. | ||
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param> | ||
public SdcaRegressionTrainer(IHostEnvironment env, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Should this constructor go away? #Resolved There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
string labelColumn = DefaultColumnNames.Label, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string weights = null, | ||
ISupportSdcaRegressionLoss loss = null, | ||
float? l2Const = null, | ||
float? l1Threshold = null, | ||
int? maxIterations = null, | ||
Action<Arguments> advancedSettings = null) | ||
: base(env, featureColumn, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), advancedSettings, | ||
l2Const, l1Threshold, maxIterations) | ||
int? maxIterations = null) | ||
: base(env, featureColumn, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), | ||
l2Const, l1Threshold, maxIterations) | ||
{ | ||
Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
_loss = loss ?? Args.LossFunction.CreateComponent(env); | ||
Loss = _loss; | ||
} | ||
|
||
internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) | ||
: base(env, args, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) | ||
internal SdcaRegressionTrainer(IHostEnvironment env, Options options, string featureColumn, string labelColumn, string weightColumn = null) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'd remove this one too. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
: base(env, options, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) | ||
{ | ||
Host.CheckValue(labelColumn, nameof(labelColumn)); | ||
Host.CheckValue(featureColumn, nameof(featureColumn)); | ||
|
||
_loss = args.LossFunction.CreateComponent(env); | ||
_loss = options.LossFunction.CreateComponent(env); | ||
Loss = _loss; | ||
} | ||
|
||
internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args) | ||
: this(env, args, args.FeatureColumn, args.LabelColumn) | ||
internal SdcaRegressionTrainer(IHostEnvironment env, Options options) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Isn't there supposed to be a constructor with common arguments and one with options? If we make the other two constructors internal (and eventually remove them), then we have just this constructor with the options. Is that OK? Does SdcaRegressionTrainer have any common args? #Pending There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I thought that the main point of the API changes is that we end up with two constructors that are publicly available to the user:
Now the first constructor can use the options constructor under the covers, i.e. `Foo(arg1, arg2) : this(new Options() { arg1 = arg1, arg2 = arg2 })`, but we would still expose two different constructors publicly that can be used. Is that not the case for SDCA? In reply to: 248870888 [](ancestors = 248870888,248865494) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
In reply to: 248880304 [](ancestors = 248880304,248870888,248865494) |
||
: this(env, options, options.FeatureColumn, options.LabelColumn) | ||
{ | ||
} | ||
|
||
|
@@ -178,14 +173,14 @@ public static partial class Sdca | |
ShortName = SdcaRegressionTrainer.ShortName, | ||
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />", | ||
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentRegressor""]/*' />" })] | ||
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, SdcaRegressionTrainer.Arguments input) | ||
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, SdcaRegressionTrainer.Options input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainSDCA"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<SdcaRegressionTrainer.Arguments, CommonOutputs.RegressionOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<SdcaRegressionTrainer.Options, CommonOutputs.RegressionOutput>(host, input, | ||
() => new SdcaRegressionTrainer(host, input), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we planning to rename this too? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At this point the focus is to get the public APIs fixed.
Rename of internal classes / variables is low priority at this point.
In reply to: 248102673 [](ancestors = 248102673)