-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Modify API for advanced settings (several learners) #2163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
961fe9d
91939d9
fed254d
c447d96
7d15dfb
5cf6c19
c89741f
5b316f2
ea9a5de
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,7 +60,10 @@ public Arguments() | |
BasePredictors = new[] | ||
{ | ||
ComponentFactoryUtils.CreateFromFunction( | ||
env => new MulticlassLogisticRegression(env, LabelColumn, FeatureColumn)) | ||
env => { | ||
var mlContext = new MLContext(); | ||
return mlContext.MulticlassClassification.Trainers.LogisticRegression(LabelColumn, FeatureColumn); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here too. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
}) | ||
}; | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,7 +54,10 @@ public Arguments() | |
BasePredictors = new[] | ||
{ | ||
ComponentFactoryUtils.CreateFromFunction( | ||
env => new OnlineGradientDescentTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features)) | ||
env => { | ||
var mlContext = new MLContext(); | ||
return mlContext.Regression.Trainers.OnlineGradientDescent(); | ||
}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one more #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
}; | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
using Microsoft.ML.Numeric; | ||
using Microsoft.ML.Training; | ||
|
||
[assembly: LoadableClass(LogisticRegression.Summary, typeof(LogisticRegression), typeof(LogisticRegression.Arguments), | ||
[assembly: LoadableClass(LogisticRegression.Summary, typeof(LogisticRegression), typeof(LogisticRegression.Options), | ||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, | ||
LogisticRegression.UserNameValue, | ||
LogisticRegression.LoadNameValue, | ||
|
@@ -30,15 +30,15 @@ namespace Microsoft.ML.Learners | |
|
||
/// <include file='doc.xml' path='doc/members/member[@name="LBFGS"]/*' /> | ||
/// <include file='doc.xml' path='docs/members/example[@name="LogisticRegressionBinaryClassifier"]/*' /> | ||
public sealed partial class LogisticRegression : LbfgsTrainerBase<LogisticRegression.Arguments, BinaryPredictionTransformer<ParameterMixingCalibratedPredictor>, ParameterMixingCalibratedPredictor> | ||
public sealed partial class LogisticRegression : LbfgsTrainerBase<LogisticRegression.Options, BinaryPredictionTransformer<ParameterMixingCalibratedPredictor>, ParameterMixingCalibratedPredictor> | ||
{ | ||
public const string LoadNameValue = "LogisticRegression"; | ||
internal const string UserNameValue = "Logistic Regression"; | ||
internal const string ShortName = "lr"; | ||
internal const string Summary = "Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can " | ||
+ "be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function."; | ||
|
||
public sealed class Arguments : ArgumentsBase | ||
public sealed class Options : ArgumentsBase | ||
{ | ||
/// <summary> | ||
/// If set to <value>true</value>training statistics will be generated at the end of training. | ||
|
@@ -53,7 +53,7 @@ public sealed class Arguments : ArgumentsBase | |
/// <summary> | ||
/// The instance of <see cref="ComputeLRTrainingStd"/> that computes the std of the training statistics, at the end of training. | ||
/// The calculations are not part of Microsoft.ML package, due to the size of MKL. | ||
/// If you need these calculations, add the Microsoft.ML.HalLearners package, and initialize <see cref="LogisticRegression.Arguments.StdComputer"/>. | ||
/// If you need these calculations, add the Microsoft.ML.HalLearners package, and initialize <see cref="LogisticRegression.Options.StdComputer"/>. | ||
/// to the <see cref="ComputeLRTrainingStd"/> implementation in the Microsoft.ML.HalLearners package. | ||
/// </summary> | ||
public ComputeLRTrainingStd StdComputer; | ||
|
@@ -74,18 +74,16 @@ public sealed class Arguments : ArgumentsBase | |
/// <param name="l2Weight">Weight of L2 regularizer term.</param> | ||
/// <param name="memorySize">Memory size for <see cref="LogisticRegression"/>. Low=faster, less accurate.</param> | ||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param> | ||
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param> | ||
public LogisticRegression(IHostEnvironment env, | ||
internal LogisticRegression(IHostEnvironment env, | ||
string labelColumn = DefaultColumnNames.Label, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string weights = null, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i'd just delete it.. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do see it being passed to base constructor In reply to: 248543741 [](ancestors = 248543741) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
float l1Weight = Arguments.Defaults.L1Weight, | ||
float l2Weight = Arguments.Defaults.L2Weight, | ||
float optimizationTolerance = Arguments.Defaults.OptTol, | ||
int memorySize = Arguments.Defaults.MemorySize, | ||
bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, | ||
Action<Arguments> advancedSettings = null) | ||
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weights, advancedSettings, | ||
float l1Weight = Options.Defaults.L1Weight, | ||
float l2Weight = Options.Defaults.L2Weight, | ||
float optimizationTolerance = Options.Defaults.OptTol, | ||
int memorySize = Options.Defaults.MemorySize, | ||
bool enforceNoNegativity = Options.Defaults.EnforceNonNegativity) | ||
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weights, | ||
l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) | ||
{ | ||
Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); | ||
|
@@ -98,8 +96,8 @@ public LogisticRegression(IHostEnvironment env, | |
/// <summary> | ||
/// Initializes a new instance of <see cref="LogisticRegression"/> | ||
/// </summary> | ||
internal LogisticRegression(IHostEnvironment env, Arguments args) | ||
: base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) | ||
internal LogisticRegression(IHostEnvironment env, Options options) | ||
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) | ||
{ | ||
_posWeight = 0; | ||
ShowTrainingStats = Args.ShowTrainingStats; | ||
|
@@ -410,14 +408,14 @@ protected override ParameterMixingCalibratedPredictor CreatePredictor() | |
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/member[@name=""LBFGS""]/*' />", | ||
@"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/example[@name=""LogisticRegressionBinaryClassifier""]/*' />"})] | ||
|
||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input) | ||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainLRBinary"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
() => new LogisticRegression(host, input), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); | ||
|
@@ -437,7 +435,7 @@ public abstract class ComputeLRTrainingStd | |
/// Computes the standard deviation matrix of each of the non-zero training weights, needed to calculate further the standard deviation, | ||
/// p-value and z-Score. | ||
/// The calculations are not part of Microsoft.ML package, due to the size of MKL. | ||
/// If you need these calculations, add the Microsoft.ML.HalLearners package, and initialize <see cref="LogisticRegression.Arguments.StdComputer"/> | ||
/// If you need these calculations, add the Microsoft.ML.HalLearners package, and initialize <see cref="LogisticRegression.Options.StdComputer"/> | ||
/// to the <see cref="ComputeLRTrainingStd"/> implementation in the Microsoft.ML.HalLearners package. | ||
/// Due to the existence of regularization, an approximation is used to compute the variances of the trained linear coefficients. | ||
/// </summary> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
using Microsoft.ML.Training; | ||
using Newtonsoft.Json.Linq; | ||
|
||
[assembly: LoadableClass(typeof(MulticlassLogisticRegression), typeof(MulticlassLogisticRegression.Arguments), | ||
[assembly: LoadableClass(typeof(MulticlassLogisticRegression), typeof(MulticlassLogisticRegression.Options), | ||
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) }, | ||
MulticlassLogisticRegression.UserNameValue, | ||
MulticlassLogisticRegression.LoadNameValue, | ||
|
@@ -38,14 +38,14 @@ namespace Microsoft.ML.Learners | |
{ | ||
/// <include file = 'doc.xml' path='doc/members/member[@name="LBFGS"]/*' /> | ||
/// <include file = 'doc.xml' path='docs/members/example[@name="LogisticRegressionClassifier"]/*' /> | ||
public sealed class MulticlassLogisticRegression : LbfgsTrainerBase<MulticlassLogisticRegression.Arguments, | ||
public sealed class MulticlassLogisticRegression : LbfgsTrainerBase<MulticlassLogisticRegression.Options, | ||
MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters> | ||
{ | ||
public const string LoadNameValue = "MultiClassLogisticRegression"; | ||
internal const string UserNameValue = "Multi-class Logistic Regression"; | ||
internal const string ShortName = "mlr"; | ||
|
||
public sealed class Arguments : ArgumentsBase | ||
public sealed class Options : ArgumentsBase | ||
{ | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Show statistics of training examples.", ShortName = "stat", SortOrder = 50)] | ||
public bool ShowTrainingStats = false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
xml #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
@@ -82,19 +82,16 @@ public sealed class Arguments : ArgumentsBase | |
/// <param name="l2Weight">Weight of L2 regularizer term.</param> | ||
/// <param name="memorySize">Memory size for <see cref="LogisticRegression"/>. Low=faster, less accurate.</param> | ||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param> | ||
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param> | ||
public MulticlassLogisticRegression(IHostEnvironment env, | ||
internal MulticlassLogisticRegression(IHostEnvironment env, | ||
string labelColumn = DefaultColumnNames.Label, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string weights = null, | ||
float l1Weight = Arguments.Defaults.L1Weight, | ||
float l2Weight = Arguments.Defaults.L2Weight, | ||
float optimizationTolerance = Arguments.Defaults.OptTol, | ||
int memorySize = Arguments.Defaults.MemorySize, | ||
bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, | ||
Action<Arguments> advancedSettings = null) | ||
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weights, advancedSettings, | ||
l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) | ||
float l1Weight = Options.Defaults.L1Weight, | ||
float l2Weight = Options.Defaults.L2Weight, | ||
float optimizationTolerance = Options.Defaults.OptTol, | ||
int memorySize = Options.Defaults.MemorySize, | ||
bool enforceNoNegativity = Options.Defaults.EnforceNonNegativity) | ||
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) | ||
{ | ||
Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
|
@@ -105,8 +102,8 @@ public MulticlassLogisticRegression(IHostEnvironment env, | |
/// <summary> | ||
/// Initializes a new instance of <see cref="MulticlassLogisticRegression"/> | ||
/// </summary> | ||
internal MulticlassLogisticRegression(IHostEnvironment env, Arguments args) | ||
: base(env, args, TrainerUtils.MakeU4ScalarColumn(args.LabelColumn)) | ||
internal MulticlassLogisticRegression(IHostEnvironment env, Options options) | ||
: base(env, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumn)) | ||
{ | ||
ShowTrainingStats = Args.ShowTrainingStats; | ||
} | ||
|
@@ -1007,14 +1004,14 @@ public partial class LogisticRegression | |
ShortName = MulticlassLogisticRegression.ShortName, | ||
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/member[@name=""LBFGS""]/*' />", | ||
@"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/example[@name=""LogisticRegressionClassifier""]/*' />" })] | ||
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, MulticlassLogisticRegression.Arguments input) | ||
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, MulticlassLogisticRegression.Options input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainLRMultiClass"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<MulticlassLogisticRegression.Arguments, CommonOutputs.MulticlassClassificationOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<MulticlassLogisticRegression.Options, CommonOutputs.MulticlassClassificationOutput>(host, input, | ||
() => new MulticlassLogisticRegression(host, input), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i don't think it is a good idea to create another context/environment.
Just use the constructor, for internal code.
Add the [BestFriend] annotation to LinearSVM, if it canot be accessed from here. #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks. for the tip about using the [BestFriend] attribute
In reply to: 248542758 [](ancestors = 248542758)