diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index 3fe853aff5..344adbe39f 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -51,7 +51,7 @@ public Arguments()
BasePredictors = new[]
{
ComponentFactoryUtils.CreateFromFunction(
- env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments()))
+ env => new OnlineGradientDescentTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features))
};
}
}
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs
index 2a64ef8d31..d75ddddd74 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs
@@ -2,13 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.KMeans;
-using Microsoft.ML.StaticPipe;
using Microsoft.ML.StaticPipe.Runtime;
using System;
-namespace Microsoft.ML.Trainers
+namespace Microsoft.ML.StaticPipe
{
///
/// The trainer context extensions for the .
@@ -35,16 +35,22 @@ public static (Vector score, Key predictedLabel) KMeans(this Cluste
Action advancedSettings = null,
Action onFit = null)
{
- var rec = new TrainerEstimatorReconciler.Clustering(
- (env, featuresName, weightsName) =>
- {
- var trainer = new KMeansPlusPlusTrainer(env, featuresName, clustersCount, weightsName, advancedSettings);
+ Contracts.CheckValue(features, nameof(features));
+ Contracts.CheckValueOrNull(weights);
+ Contracts.CheckParam(clustersCount > 1, nameof(clustersCount), "If provided, must be greater than 1.");
+ Contracts.CheckValueOrNull(onFit);
+ Contracts.CheckValueOrNull(advancedSettings);
- if (onFit != null)
- return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
- else
- return trainer;
- }, features, weights);
+ var rec = new TrainerEstimatorReconciler.Clustering(
+ (env, featuresName, weightsName) =>
+ {
+ var trainer = new KMeansPlusPlusTrainer(env, featuresName, clustersCount, weightsName, advancedSettings);
+
+ if (onFit != null)
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+ else
+ return trainer;
+ }, features, weights);
return rec.Output;
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
index 8722b79ea6..29e2e1e89f 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
@@ -1359,7 +1359,7 @@ public void Add(Double summand)
public sealed class LinearClassificationTrainer : SdcaTrainerBase, TScalarPredictor>
{
public const string LoadNameValue = "SDCA";
- public const string UserNameValue = "Fast Linear (SA-SDCA)";
+ internal const string UserNameValue = "Fast Linear (SA-SDCA)";
public sealed class Arguments : ArgumentsBase
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
index 4286f83379..864e0ba3ab 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
@@ -26,29 +26,29 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2", SortOrder = 50)]
[TGUI(Label = "L2 Weight", Description = "Weight of L2 regularizer term", SuggestedSweeps = "0,0.1,1")]
[TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)]
- public float L2Weight = 1;
+ public float L2Weight = Defaults.L2Weight;
[Argument(ArgumentType.AtMostOnce, HelpText = "L1 regularization weight", ShortName = "l1", SortOrder = 50)]
[TGUI(Label = "L1 Weight", Description = "Weight of L1 regularizer term", SuggestedSweeps = "0,0.1,1")]
[TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)]
- public float L1Weight = 1;
+ public float L1Weight = Defaults.L1Weight;
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for optimization convergence. Lower = slower, more accurate",
ShortName = "ot", SortOrder = 50)]
[TGUI(Label = "Optimization Tolerance", Description = "Threshold for optimizer convergence", SuggestedSweeps = "1e-4,1e-7")]
[TlcModule.SweepableDiscreteParamAttribute(new object[] { 1e-4f, 1e-7f })]
- public float OptTol = 1e-7f;
+ public float OptTol = Defaults.OptTol;
[Argument(ArgumentType.AtMostOnce, HelpText = "Memory size for L-BFGS. Lower=faster, less accurate",
ShortName = "m", SortOrder = 50)]
[TGUI(Description = "Memory size for L-BFGS", SuggestedSweeps = "5,20,50")]
[TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[] { 5, 20, 50 })]
- public int MemorySize = 20;
+ public int MemorySize = Defaults.MemorySize;
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum iterations.", ShortName = "maxiter")]
[TGUI(Label = "Max Number of Iterations")]
[TlcModule.SweepableLongParamAttribute("MaxIterations", 1, int.MaxValue)]
- public int MaxIterations = int.MaxValue;
+ public int MaxIterations = Defaults.MaxIterations;
[Argument(ArgumentType.AtMostOnce, HelpText = "Run SGD to initialize LR weights, converging to this tolerance",
ShortName = "sgd")]
@@ -90,7 +90,17 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
public bool DenseOptimizer = false;
[Argument(ArgumentType.AtMostOnce, HelpText = "Enforce non-negative weights", ShortName = "nn", SortOrder = 90)]
- public bool EnforceNonNegativity = false;
+ public bool EnforceNonNegativity = Defaults.EnforceNonNegativity;
+
+ internal static class Defaults
+ {
+ internal const float L2Weight = 1;
+ internal const float L1Weight = 1;
+ internal const float OptTol = 1e-7f;
+ internal const int MemorySize = 20;
+ internal const int MaxIterations = int.MaxValue;
+ internal const bool EnforceNonNegativity = false;
+ }
}
private const string RegisterName = nameof(LbfgsTrainerBase);
@@ -142,32 +152,48 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
public override TrainerInfo Info => _info;
internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
- string weightColumn = null, Action advancedSettings = null)
- : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn)
+ string weightColumn, Action advancedSettings, float l1Weight,
+ float l2Weight,
+ float optimizationTolerance,
+ int memorySize,
+ bool enforceNoNegativity)
+ : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn,
+ l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
{
}
- internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn)
+ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn,
+ float? l1Weight = null,
+ float? l2Weight = null,
+ float? optimizationTolerance = null,
+ int? memorySize = null,
+ bool? enforceNoNegativity = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
Host.CheckValue(args, nameof(args));
Args = args;
- Contracts.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
+ Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
- Contracts.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
- Contracts.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
- Contracts.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
- Contracts.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
- Contracts.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
- Contracts.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
- Contracts.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");
-
- L2Weight = Args.L2Weight;
- L1Weight = Args.L1Weight;
- OptTol = Args.OptTol;
- MemorySize = Args.MemorySize;
+ Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
+ Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
+ Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
+ Host.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
+ Host.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
+ Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
+ Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");
+
+ Host.CheckParam(!(l2Weight < 0), nameof(l2Weight), "Must be non-negative, if provided.");
+ Host.CheckParam(!(l1Weight < 0), nameof(l1Weight), "Must be non-negative, if provided");
+ Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided.");
+ Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided.");
+
+ // Review: Warn about the overriding behavior
+ L2Weight = l2Weight ?? Args.L2Weight;
+ L1Weight = l1Weight ?? Args.L1Weight;
+ OptTol = optimizationTolerance ?? Args.OptTol;
+ MemorySize = memorySize ?? Args.MemorySize;
MaxIterations = Args.MaxIterations;
SgdInitializationTolerance = Args.SgdInitializationTolerance;
Quiet = Args.Quiet;
@@ -175,7 +201,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l
UseThreads = Args.UseThreads;
NumThreads = Args.NumThreads;
DenseOptimizer = Args.DenseOptimizer;
- EnforceNonNegativity = Args.EnforceNonNegativity;
+ EnforceNonNegativity = enforceNoNegativity ?? Args.EnforceNonNegativity;
if (EnforceNonNegativity && ShowTrainingStats)
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatics.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatics.cs
new file mode 100644
index 0000000000..71227d2155
--- /dev/null
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatics.cs
@@ -0,0 +1,196 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Calibration;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.StaticPipe.Runtime;
+
+namespace Microsoft.ML.StaticPipe
+{
+ using Arguments = LogisticRegression.Arguments;
+
+ ///
+ /// Binary Classification trainer estimators.
+ ///
+ public static partial class BinaryClassificationTrainers
+ {
+ ///
+ /// Predict a target using a linear binary classification model trained with the logistic regression trainer.
+ ///
+ /// The binary classification context trainer object.
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// The optional example weights.
+ /// Enforce non-negative weights.
+ /// Weight of L1 regularization term.
+ /// Weight of L2 regularization term.
+ /// Memory size for L-BFGS. Lower=faster, less accurate.
+ /// Threshold for optimizer convergence.
+ /// A delegate that is called every time the
+ /// method is called on the
+ /// instance created out of this. This delegate will receive
+ /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to
+ /// be informed about what was learnt.
+ /// The predicted output.
+ public static (Scalar score, Scalar probability, Scalar predictedLabel) LogisticRegressionBinaryClassifier(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
+ Scalar label,
+ Vector features,
+ Scalar weights = null,
+ float l1Weight = Arguments.Defaults.L1Weight,
+ float l2Weight = Arguments.Defaults.L2Weight,
+ float optimizationTolerance = Arguments.Defaults.OptTol,
+ int memorySize = Arguments.Defaults.MemorySize,
+ bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
+ Action onFit = null)
+ {
+ LbfgsStaticUtils.ValidateParams(label, features, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity, onFit);
+
+ var rec = new TrainerEstimatorReconciler.BinaryClassifier(
+ (env, labelName, featuresName, weightsName) =>
+ {
+ var trainer = new LogisticRegression(env, featuresName, labelName, weightsName,
+ l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity);
+
+ if (onFit != null)
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+ return trainer;
+
+ }, label, features, weights);
+
+ return rec.Output;
+ }
+ }
+
+ ///
+ /// Regression trainer estimators.
+ ///
+ public static partial class RegressionTrainers
+ {
+
+ ///
+ /// Predict a target using a linear regression model trained with the Poisson regression trainer.
+ ///
+ /// The regression context trainer object.
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// The optional example weights.
+ /// Enforce non-negative weights.
+ /// Weight of L1 regularization term.
+ /// Weight of L2 regularization term.
+ /// Memory size for L-BFGS. Lower=faster, less accurate.
+ /// Threshold for optimizer convergence.
+ /// A delegate that is called every time the
+ /// method is called on the
+ /// instance created out of this. This delegate will receive
+ /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to
+ /// be informed about what was learnt.
+ /// The predicted output.
+ public static Scalar PoissonRegression(this RegressionContext.RegressionTrainers ctx,
+ Scalar label,
+ Vector features,
+ Scalar weights = null,
+ float l1Weight = Arguments.Defaults.L1Weight,
+ float l2Weight = Arguments.Defaults.L2Weight,
+ float optimizationTolerance = Arguments.Defaults.OptTol,
+ int memorySize = Arguments.Defaults.MemorySize,
+ bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
+ Action onFit = null)
+ {
+ LbfgsStaticUtils.ValidateParams(label, features, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity, onFit);
+
+ var rec = new TrainerEstimatorReconciler.Regression(
+ (env, labelName, featuresName, weightsName) =>
+ {
+ var trainer = new PoissonRegression(env, featuresName, labelName, weightsName,
+ l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity);
+
+ if (onFit != null)
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+
+ return trainer;
+ }, label, features, weights);
+
+ return rec.Score;
+ }
+ }
+
+ ///
+ /// MultiClass Classification trainer estimators.
+ ///
+ public static partial class MultiClassClassificationTrainers
+ {
+
+ ///
+ /// Predict a target using a linear multiclass classification model trained with the multiclass logistic regression trainer.
+ ///
+ /// The multiclass classification context trainer object.
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// The optional example weights.
+ /// Enforce non-negative weights.
+ /// Weight of L1 regularization term.
+ /// Weight of L2 regularization term.
+ /// Memory size for L-BFGS. Lower=faster, less accurate.
+ /// Threshold for optimizer convergence.
+ /// A delegate that is called every time the
+ /// method is called on the
+ /// instance created out of this. This delegate will receive
+ /// the linear model that was trained. Note that this action cannot change the
+ /// result in any way; it is only a way for the caller to be informed about what was learnt.
+ /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label.
+ public static (Vector score, Key predictedLabel)
+ MultiClassLogisticRegression(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx,
+ Key label,
+ Vector features,
+ Scalar weights = null,
+ float l1Weight = Arguments.Defaults.L1Weight,
+ float l2Weight = Arguments.Defaults.L2Weight,
+ float optimizationTolerance = Arguments.Defaults.OptTol,
+ int memorySize = Arguments.Defaults.MemorySize,
+ bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
+ Action onFit = null)
+ {
+ LbfgsStaticUtils.ValidateParams(label, features, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity, onFit);
+
+ var rec = new TrainerEstimatorReconciler.MulticlassClassifier(
+ (env, labelName, featuresName, weightsName) =>
+ {
+ var trainer = new MulticlassLogisticRegression(env, featuresName, labelName, weightsName,
+ l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity);
+
+ if (onFit != null)
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+ return trainer;
+ }, label, features, weights);
+
+ return rec.Output;
+ }
+
+ }
+
+ internal static class LbfgsStaticUtils{
+
+ internal static void ValidateParams(PipelineColumn label,
+ Vector features,
+ Scalar weights = null,
+ float l1Weight = Arguments.Defaults.L1Weight,
+ float l2Weight = Arguments.Defaults.L2Weight,
+ float optimizationTolerance = Arguments.Defaults.OptTol,
+ int memorySize = Arguments.Defaults.MemorySize,
+ bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
+ Delegate onFit = null)
+ {
+ Contracts.CheckValue(label, nameof(label));
+ Contracts.CheckValue(features, nameof(features));
+ Contracts.CheckParam(l2Weight >= 0, nameof(l2Weight), "Must be non-negative");
+ Contracts.CheckParam(l1Weight >= 0, nameof(l1Weight), "Must be non-negative");
+ Contracts.CheckParam(optimizationTolerance > 0, nameof(optimizationTolerance), "Must be positive");
+ Contracts.CheckParam(memorySize > 0, nameof(memorySize), "Must be positive");
+ Contracts.CheckValueOrNull(onFit);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
index 0914c63c33..87c1c1bde9 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
@@ -54,10 +54,24 @@ public sealed class Arguments : ArgumentsBase
/// The name of the label column.
/// The name of the feature column.
/// The name for the example weight column.
+ /// Enforce non-negative weights.
+ /// Weight of L1 regularizer term.
+ /// Weight of L2 regularizer term.
+ /// Memory size for L-BFGS. Lower=faster, less accurate.
+ /// Threshold for optimizer convergence.
/// A delegate to apply all the advanced arguments to the algorithm.
- public LogisticRegression(IHostEnvironment env, string featureColumn, string labelColumn,
- string weightColumn = null, Action advancedSettings = null)
- : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), weightColumn, advancedSettings)
+ public LogisticRegression(IHostEnvironment env,
+ string featureColumn,
+ string labelColumn,
+ string weightColumn = null,
+ float l1Weight = Arguments.Defaults.L1Weight,
+ float l2Weight = Arguments.Defaults.L2Weight,
+ float optimizationTolerance = Arguments.Defaults.OptTol,
+ int memorySize = Arguments.Defaults.MemorySize,
+ bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
+ Action advancedSettings = null)
+ : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn, advancedSettings,
+ l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
@@ -70,7 +84,7 @@ public LogisticRegression(IHostEnvironment env, string featureColumn, string lab
/// Initializes a new instance of
///
internal LogisticRegression(IHostEnvironment env, Arguments args)
- : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn))
+ : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
{
_posWeight = 0;
ShowTrainingStats = Args.ShowTrainingStats;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index 52d1688638..40e3090c74 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -76,10 +76,22 @@ public sealed class Arguments : ArgumentsBase
/// The name of the label column.
/// The name of the feature column.
/// The name for the example weight column.
+ /// Enforce non-negative weights.
+ /// Weight of L1 regularizer term.
+ /// Weight of L2 regularizer term.
+ /// Memory size for L-BFGS. Lower=faster, less accurate.
+ /// Threshold for optimizer convergence.
/// A delegate to apply all the advanced arguments to the algorithm.
public MulticlassLogisticRegression(IHostEnvironment env, string featureColumn, string labelColumn,
- string weightColumn = null, Action advancedSettings = null)
- : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), weightColumn, advancedSettings)
+ string weightColumn = null,
+ float l1Weight = Arguments.Defaults.L1Weight,
+ float l2Weight = Arguments.Defaults.L2Weight,
+ float optimizationTolerance = Arguments.Defaults.OptTol,
+ int memorySize = Arguments.Defaults.MemorySize,
+ bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
+ Action advancedSettings = null)
+ : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), weightColumn, advancedSettings,
+ l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs
index 70ee279a1c..5871c932dd 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs
@@ -22,12 +22,12 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)]
[TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")]
[TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })]
- public Float LearningRate = 1;
+ public Float LearningRate = AveragedDefaultArgs.LearningRate;
[Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)]
[TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")]
[TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })]
- public bool DecreaseLearningRate = false;
+ public bool DecreaseLearningRate = AveragedDefaultArgs.DecreaseLearningRate;
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of examples after which weights will be reset to the current average", ShortName = "numreset")]
public long? ResetWeightsAfterXExamples = null;
@@ -38,7 +38,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)]
[TGUI(Label = "L2 Regularization Weight")]
[TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)]
- public Float L2RegularizerWeight = 0;
+ public Float L2RegularizerWeight = AveragedDefaultArgs.L2RegularizerWeight;
[Argument(ArgumentType.AtMostOnce, HelpText = "Extra weight given to more recent updates", ShortName = "rg")]
public Float RecencyGain = 0;
@@ -51,6 +51,13 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
[Argument(ArgumentType.AtMostOnce, HelpText = "The inexactness tolerance for averaging", ShortName = "avgtol")]
public Float AveragedTolerance = (Float)1e-2;
+
+ internal class AveragedDefaultArgs : OnlineDefaultArgs
+ {
+ internal const Float LearningRate = 1;
+ internal const bool DecreaseLearningRate = false;
+ internal const Float L2RegularizerWeight = 0;
+ }
}
public abstract class AveragedLinearTrainer : OnlineLinearTrainer
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
index f5a1dc44f5..8cbb1978a5 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
@@ -57,22 +57,21 @@ public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
{
_args = args;
LossFunction = _args.LossFunction.CreateComponent(env);
-
- _outputColumns = new[]
- {
- new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
- new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))),
- new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
- };
}
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
protected override bool NeedCalibration => true;
- private readonly SchemaShape.Column[] _outputColumns;
-
- protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns;
+ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
+ {
+ return new[]
+ {
+ // REVIEW AP is currently not calibrating. Add the probability column after fixing the behavior.
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
+ new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
+ };
+ }
protected override void CheckLabel(RoleMappedData data)
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
index ad8442433d..d312c6b52b 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
@@ -34,8 +34,8 @@ namespace Microsoft.ML.Runtime.Learners
///
public sealed class LinearSvm : OnlineLinearTrainer, LinearBinaryPredictor>
{
- public const string LoadNameValue = "LinearSVM";
- public const string ShortName = "svm";
+ internal const string LoadNameValue = "LinearSVM";
+ internal const string ShortName = "svm";
internal const string UserNameValue = "SVM (Pegasos-Linear)";
internal const string Summary = "The idea behind support vector machines, is to map the instances into a high dimensional space "
+ "in which instances of the two classes are linearly separable, i.e., there exists a hyperplane such that all the positive examples are on one side of it, "
@@ -92,8 +92,13 @@ public LinearSvm(IHostEnvironment env, Arguments args)
Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);
Args = args;
+ }
+
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- _outputColumns = new[]
+ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
+ {
+ return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
@@ -101,12 +106,6 @@ public LinearSvm(IHostEnvironment env, Arguments args)
};
}
- public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
-
- private readonly SchemaShape.Column[] _outputColumns;
-
- protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns;
-
protected override void CheckLabel(RoleMappedData data)
{
Contracts.AssertValue(data);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
index 768bb1da75..00dfb59cbd 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
@@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
@@ -46,27 +44,69 @@ public sealed class Arguments : AveragedLinearArguments
///
public Arguments()
{
- LearningRate = (Float)0.1;
- DecreaseLearningRate = true;
+ LearningRate = OgdDefaultArgs.LearningRate;
+ DecreaseLearningRate = OgdDefaultArgs.DecreaseLearningRate;
+ }
+
+ internal class OgdDefaultArgs : AveragedDefaultArgs
+ {
+ internal new const float LearningRate = 0.1f;
+ internal new const bool DecreaseLearningRate = true;
}
}
- public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args)
- : base(args, env, UserNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn))
+ ///
+ /// Initializes a new instance of <see cref="OnlineGradientDescentTrainer"/>.
+ ///
+ /// The private instance of <see cref="IHostEnvironment"/>.
+ /// Name of the label column.
+ /// Name of the feature column.
+ /// The learning rate.
+ /// Decrease learning rate as iterations progress.
+ /// L2 Regularization Weight.
+ /// Number of training iterations through the data.
+ /// The name of the weights column.
+ /// The custom loss function. Defaults to <see cref="SquaredLoss"/> if not provided.
+ public OnlineGradientDescentTrainer(IHostEnvironment env,
+ string labelColumn,
+ string featureColumn,
+ float learningRate = Arguments.OgdDefaultArgs.LearningRate,
+ bool decreaseLearningRate = Arguments.OgdDefaultArgs.DecreaseLearningRate,
+ float l2RegularizerWeight = Arguments.OgdDefaultArgs.L2RegularizerWeight,
+ int numIterations = Arguments.OgdDefaultArgs.NumIterations,
+ string weightsColumn = null,
+ IRegressionLoss lossFunction = null)
+ : base(new Arguments
+ {
+ LearningRate = learningRate,
+ DecreaseLearningRate = decreaseLearningRate,
+ L2RegularizerWeight = l2RegularizerWeight,
+ NumIterations = numIterations,
+ LabelColumn = labelColumn,
+ FeatureColumn = featureColumn,
+ InitialWeights = weightsColumn
+
+ }, env, UserNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn))
+ {
+ LossFunction = lossFunction ?? new SquaredLoss();
+ }
+
+ internal OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args)
+ : base(args, env, UserNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn))
{
LossFunction = args.LossFunction.CreateComponent(env);
+ }
+
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
- _outputColumns = new[]
+ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
+ {
+ return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
};
}
- public override PredictionKind PredictionKind => PredictionKind.Regression;
-
- private readonly SchemaShape.Column[] _outputColumns;
- protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns;
-
protected override void CheckLabel(RoleMappedData data)
{
data.CheckRegressionLabel();
@@ -75,8 +115,8 @@ protected override void CheckLabel(RoleMappedData data)
protected override LinearRegressionPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);
- VBuffer weights = default(VBuffer);
- Float bias;
+ VBuffer weights = default(VBuffer);
+ float bias;
if (!Args.Averaged)
{
@@ -86,8 +126,8 @@ protected override LinearRegressionPredictor CreatePredictor()
else
{
TotalWeights.CopyTo(ref weights);
- VectorUtils.ScaleBy(ref weights, 1 / (Float)NumWeightUpdates);
- bias = TotalBias / (Float)NumWeightUpdates;
+ VectorUtils.ScaleBy(ref weights, 1 / (float)NumWeightUpdates);
+ bias = TotalBias / (float)NumWeightUpdates;
}
return new LinearRegressionPredictor(Host, ref weights, bias);
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs
new file mode 100644
index 0000000000..de286a59c8
--- /dev/null
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs
@@ -0,0 +1,174 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.StaticPipe.Runtime;
+using System;
+
+namespace Microsoft.ML.StaticPipe
+{
+ ///
+ /// Binary Classification trainer estimators.
+ ///
+ public static partial class BinaryClassificationTrainers
+ {
+ ///
+ /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer, and a custom loss.
+ ///
+ /// The binary classification context trainer object.
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// The custom loss.
+ /// The optional example weights.
+ /// The learning rate.
+ /// Decrease learning rate as iterations progress.
+ /// L2 regularization weight.
+ /// Number of training iterations through the data.
+ /// A delegate that is called every time the Fit
+ /// method is called on the estimator
+ /// instance created out of this. This delegate will receive
+ /// the linear model that was trained. Note that this action cannot change the
+ /// result in any way; it is only a way for the caller to be informed about what was learnt.
+ /// The set of output columns including in order the predicted binary classification score (which will range
+ /// from negative to positive infinity), and the predicted label.
+ /// .
+ public static (Scalar score, Scalar predictedLabel) AveragedPerceptron(
+ this BinaryClassificationContext.BinaryClassificationTrainers ctx,
+ IClassificationLoss lossFunction,
+ Scalar label, Vector features, Scalar weights = null,
+ float learningRate = AveragedLinearArguments.AveragedDefaultArgs.LearningRate,
+ bool decreaseLearningRate = AveragedLinearArguments.AveragedDefaultArgs.DecreaseLearningRate,
+ float l2RegularizerWeight = AveragedLinearArguments.AveragedDefaultArgs.L2RegularizerWeight,
+ int numIterations = AveragedLinearArguments.AveragedDefaultArgs.NumIterations,
+ Action onFit = null
+ )
+ {
+ OnlineLinearStaticUtils.CheckUserParams(label, features, weights, learningRate, l2RegularizerWeight, numIterations, onFit);
+
+ bool hasProbs = lossFunction is HingeLoss; // NOTE(review): hasProbs is true only for HingeLoss - confirm this is the intended condition for the no-calibration reconciler.
+
+ var args = new AveragedPerceptronTrainer.Arguments()
+ {
+ LearningRate = learningRate,
+ DecreaseLearningRate = decreaseLearningRate,
+ L2RegularizerWeight = l2RegularizerWeight,
+ NumIterations = numIterations
+ };
+
+ if (lossFunction != null) // A null loss leaves the loss already set on the Arguments object untouched.
+ args.LossFunction = new TrivialClassificationLossFactory(lossFunction);
+
+ var rec = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration(
+ (env, labelName, featuresName, weightsName) => // Column names arrive as arguments of this factory, so they are filled in here.
+ {
+ args.FeatureColumn = featuresName;
+ args.LabelColumn = labelName;
+ args.InitialWeights = weightsName;
+
+ var trainer = new AveragedPerceptronTrainer(env, args);
+
+ if (onFit != null) // Forward the trained model to the caller's delegate, if one was given.
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+ else
+ return trainer;
+
+ }, label, features, weights, hasProbs);
+
+ return rec.Output;
+ }
+
+ private sealed class TrivialClassificationLossFactory : ISupportClassificationLossFactory // Wraps an existing loss instance so it can be assigned to Arguments.LossFunction.
+ {
+ private readonly IClassificationLoss _loss;
+
+ public TrivialClassificationLossFactory(IClassificationLoss loss)
+ {
+ _loss = loss;
+ }
+
+ public IClassificationLoss CreateComponent(IHostEnvironment env)
+ {
+ return _loss; // Always hands back the wrapped instance; env is not used.
+ }
+ }
+ }
+
+ ///
+ /// Regression trainer estimators.
+ ///
+ public static partial class RegressionTrainers
+ {
+ ///
+ /// Predict a target using a linear regression model trained with the OnlineGradientDescentTrainer.
+ ///
+ /// The regression context trainer object.
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// The optional example weights.
+ /// The custom loss. Defaults to SquaredLoss if not provided.
+ /// The learning rate.
+ /// Decrease learning rate as iterations progress.
+ /// L2 regularization weight.
+ /// Number of training iterations through the data.
+ /// A delegate that is called every time the Fit
+ /// method is called on the estimator
+ /// instance created out of this. This delegate will receive
+ /// the linear model that was trained. Note that this action cannot change the
+ /// result in any way; it is only a way for the caller to be informed about what was learnt.
+ /// The regression score predicted by the trained linear model,
+ /// returned as a single scalar column (no predicted label is produced).
+ /// .
+ /// The predicted output.
+ public static Scalar OnlineGradientDescent(this RegressionContext.RegressionTrainers ctx,
+ Scalar label,
+ Vector features,
+ Scalar weights = null,
+ IRegressionLoss lossFunction = null,
+ float learningRate = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.LearningRate,
+ bool decreaseLearningRate = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.DecreaseLearningRate,
+ float l2RegularizerWeight = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.L2RegularizerWeight,
+ int numIterations = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.NumIterations,
+ Action onFit = null)
+ {
+ OnlineLinearStaticUtils.CheckUserParams(label, features, weights, learningRate, l2RegularizerWeight, numIterations, onFit);
+ Contracts.CheckValueOrNull(lossFunction); // The loss is optional; null is passed through to the trainer constructor.
+
+ var rec = new TrainerEstimatorReconciler.Regression(
+ (env, labelName, featuresName, weightsName) => // Column names arrive as arguments of this factory.
+ {
+ var trainer = new OnlineGradientDescentTrainer(env, labelName, featuresName, learningRate,
+ decreaseLearningRate, l2RegularizerWeight, numIterations, weightsName, lossFunction);
+
+ if (onFit != null) // Forward the trained model to the caller's delegate, if one was given.
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+
+ return trainer;
+ }, label, features, weights);
+
+ return rec.Score;
+ }
+ }
+
+ internal static class OnlineLinearStaticUtils // Argument validation shared by the online linear static-pipeline extensions in this file.
+ {
+ internal static void CheckUserParams(PipelineColumn label,
+ PipelineColumn features,
+ PipelineColumn weights,
+ float learningRate,
+ float l2RegularizerWeight,
+ int numIterations,
+ Delegate onFit)
+ {
+ Contracts.CheckValue(label, nameof(label)); // label and features are required.
+ Contracts.CheckValue(features, nameof(features));
+ Contracts.CheckValueOrNull(weights); // weights is optional.
+ Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive.");
+ Contracts.CheckParam(0 <= l2RegularizerWeight && l2RegularizerWeight < 0.5, nameof(l2RegularizerWeight), "must be in range [0, 0.5)");
+ Contracts.CheckParam(numIterations > 0, nameof(numIterations), "Must be positive, if specified.");
+ Contracts.CheckValueOrNull(onFit); // onFit is optional.
+ }
+ }
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
index 15bd5da290..41bf4a421a 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
@@ -24,7 +24,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter", SortOrder = 50)]
[TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")]
[TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)]
- public int NumIterations = 1;
+ public int NumIterations = OnlineDefaultArgs.NumIterations;
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial Weights and bias, comma-separated", ShortName = "initweights")]
[TGUI(NoSweep = true)]
@@ -41,6 +41,11 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
[Argument(ArgumentType.AtMostOnce, HelpText = "Size of cache when trained in Scope", ShortName = "cache")]
public int StreamingCacheSize = 1000000;
+
+ internal class OnlineDefaultArgs // Named constants for argument defaults, so the value is defined in one place.
+ {
+ internal const int NumIterations = 1; // Default used to initialize the NumIterations argument field.
+ }
}
public abstract class OnlineLinearTrainer : TrainerEstimatorBase
@@ -78,7 +83,7 @@ public abstract class OnlineLinearTrainer : TrainerEstimat
protected virtual bool NeedCalibration => false;
protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
- : base(Contracts.CheckRef(env, nameof(env)).Register(name), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.InitialWeights))
+ : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights))
{
Contracts.CheckValue(args, nameof(args));
Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive);
@@ -156,18 +161,6 @@ protected override TModel TrainModelCore(TrainContext context)
protected abstract void CheckLabel(RoleMappedData data);
- private static SchemaShape.Column MakeWeightColumn(string weightColumn)
- {
- if (weightColumn == null)
- return null;
- return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
- }
-
- private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
- {
- return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
- }
-
protected virtual void TrainCore(IChannel ch, RoleMappedData data)
{
bool shuffle = Args.Shuffle;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
index ed400aaeac..2e703bf4a0 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
@@ -46,10 +46,22 @@ public sealed class Arguments : ArgumentsBase
/// The name of the label column.
/// The name of the feature column.
/// The name for the example weight column.
+ /// Enforce non-negative weights.
+ /// Weight of L1 regularizer term.
+ /// Weight of L2 regularizer term.
+ /// Memory size for . Lower=faster, less accurate.
+ /// Threshold for optimizer convergence.
/// A delegate to apply all the advanced arguments to the algorithm.
public PoissonRegression(IHostEnvironment env, string featureColumn, string labelColumn,
- string weightColumn = null, Action advancedSettings = null)
- : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), weightColumn, advancedSettings)
+ string weightColumn = null,
+ float l1Weight = Arguments.Defaults.L1Weight,
+ float l2Weight = Arguments.Defaults.L2Weight,
+ float optimizationTolerance = Arguments.Defaults.OptTol,
+ int memorySize = Arguments.Defaults.MemorySize,
+ bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
+ Action advancedSettings = null)
+ : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), weightColumn, advancedSettings,
+ l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
index 113092ed26..6c4995bfa7 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
@@ -142,7 +142,7 @@ public static (Scalar score, Scalar probability, Scalar pred
/// The binary classification context trainer object.
/// The label, or dependent variable.
/// The features, or independent variables.
- /// /// The custom loss.
+ /// The custom loss.
/// The optional example weights.
/// The L2 regularization hyperparameter.
/// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.
@@ -212,7 +212,7 @@ public static (Scalar score, Scalar predictedLabel) Sdca(
/// The multiclass classification context trainer object.
/// The label, or dependent variable.
/// The features, or independent variables.
- /// /// The custom loss.
+ /// The custom loss.
/// The optional example weights.
/// The L2 regularization hyperparameter.
/// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
index 8ccea87539..867466d21a 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
@@ -12,7 +12,7 @@
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.LightGBM;
using Microsoft.ML.Runtime.RunTests;
-using Microsoft.ML.Runtime.Training;
+using Microsoft.ML.StaticPipe;
using Microsoft.ML.Trainers;
using System;
using System.Linq;
@@ -145,7 +145,7 @@ public void SdcaBinaryClassification()
}
[Fact]
- public void SdcaBinaryClassificationNoClaibration()
+ public void SdcaBinaryClassificationNoCalibration()
{
var env = new ConsoleEnvironment(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
@@ -187,6 +187,42 @@ public void SdcaBinaryClassificationNoClaibration()
Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
}
+ [Fact]
+ public void AveragePerceptronNoCalibration()
+ {
+ var env = new ConsoleEnvironment(seed: 0);
+ var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
+ var dataSource = new MultiFileSource(dataPath);
+ var ctx = new BinaryClassificationContext(env);
+
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9)));
+
+ LinearBinaryPredictor pred = null; // Captured by the onFit delegate below.
+
+ var loss = new HingeLoss(new HingeLoss.Arguments() { Margin = 1 }); // Explicit hinge loss with margin 1.
+
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, preds: ctx.Trainers.AveragedPerceptron(loss, r.label, r.features,
+ numIterations: 2, onFit: p => pred = p)));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred); // onFit must not have fired before Fit is called.
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred); // Fit must have invoked onFit with the trained model.
+ // 9 input features, so we ought to have 9 weights.
+ Assert.Equal(9, pred.Weights2.Count);
+
+ var data = model.Read(dataSource);
+
+ var metrics = ctx.Evaluate(data, r => r.label, r => r.preds);
+ // Run a sanity check against a few of the metrics.
+ Assert.InRange(metrics.Accuracy, 0, 1);
+ Assert.InRange(metrics.Auc, 0, 1);
+ Assert.InRange(metrics.Auprc, 0, 1);
+ }
+
[Fact]
public void FfmBinaryClassification()
{
@@ -453,6 +489,172 @@ public void LightGbmRegression()
Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity);
}
+ [Fact]
+ public void PoissonRegression()
+ {
+ var env = new ConsoleEnvironment(seed: 0);
+ var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
+ var dataSource = new MultiFileSource(dataPath);
+
+ var ctx = new RegressionContext(env);
+
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
+ separator: ';', hasHeader: true);
+
+ PoissonRegressionPredictor pred = null; // Captured by the onFit delegate below.
+
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, score: ctx.Trainers.PoissonRegression(r.label, r.features,
+ l1Weight: 2,
+ enoforceNoNegativity: true, // NOTE(review): spelling must match the extension's parameter name; looks like a typo of "enforce" - confirm against the declaration.
+ onFit: (p) => { pred = p; })));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred); // onFit must not have fired before Fit is called.
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred); // Fit must have invoked onFit with the trained model.
+ // 11 input features, so we ought to have 11 weights.
+ VBuffer weights = new VBuffer();
+ pred.GetFeatureWeights(ref weights);
+ Assert.Equal(11, weights.Length);
+
+ var data = model.Read(dataSource);
+
+ var metrics = ctx.Evaluate(data, r => r.label, r => r.score, new PoissonLoss());
+ // Run a sanity check against a few of the metrics.
+ Assert.InRange(metrics.L1, 0, double.PositiveInfinity);
+ Assert.InRange(metrics.L2, 0, double.PositiveInfinity);
+ Assert.InRange(metrics.Rms, 0, double.PositiveInfinity);
+ Assert.Equal(metrics.Rms * metrics.Rms, metrics.L2, 5); // Rms squared should equal L2 to 5 decimal places.
+ Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity);
+ }
+
+ [Fact]
+ public void LogisticRegressionBinaryClassification()
+ {
+ var env = new ConsoleEnvironment(seed: 0);
+ var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
+ var dataSource = new MultiFileSource(dataPath);
+ var ctx = new BinaryClassificationContext(env);
+
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9)));
+
+ IPredictorWithFeatureWeights pred = null; // Captured by the onFit delegate below.
+
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, preds: ctx.Trainers.LogisticRegressionBinaryClassifier(r.label, r.features,
+ l1Weight: 10,
+ onFit: (p) => { pred = p; })));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred); // onFit must not have fired before Fit is called.
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred); // Fit must have invoked onFit with the trained model.
+
+ // 9 input features, so we ought to have 9 weights.
+ VBuffer weights = new VBuffer();
+ pred.GetFeatureWeights(ref weights);
+ Assert.Equal(9, weights.Length);
+
+ var data = model.Read(dataSource);
+
+ var metrics = ctx.Evaluate(data, r => r.label, r => r.preds);
+ // Run a sanity check against a few of the metrics.
+ Assert.InRange(metrics.Accuracy, 0, 1);
+ Assert.InRange(metrics.Auc, 0, 1);
+ Assert.InRange(metrics.Auprc, 0, 1);
+ }
+
+ [Fact]
+ public void MulticlassLogisticRegression()
+ {
+ var env = new ConsoleEnvironment(seed: 0);
+ var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
+ var dataSource = new MultiFileSource(dataPath);
+
+ var ctx = new MulticlassClassificationContext(env);
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));
+
+ MulticlassLogisticRegressionPredictor pred = null; // Captured by the onFit delegate below.
+
+ // Map the text label to a key, then train the multiclass model without overriding any optional settings.
+ var est = reader.MakeNewEstimator()
+ .Append(r => (label: r.label.ToKey(), r.features))
+ .Append(r => (r.label, preds: ctx.Trainers.MultiClassLogisticRegression(
+ r.label,
+ r.features, onFit: p => pred = p)));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred); // onFit must not have fired before Fit is called.
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred); // Fit must have invoked onFit with the trained model.
+ VBuffer[] weights = default;
+ pred.GetWeights(ref weights, out int n);
+ Assert.True(n == 3 && n == weights.Length); // Expect three weight vectors, and n must agree with the array length.
+ foreach (var w in weights)
+ Assert.True(w.Length == 4); // Each vector has one entry per loaded feature column (1 through 4).
+
+ var data = model.Read(dataSource);
+
+ // Just output some data on the schema for fun.
+ var schema = data.AsDynamic.Schema;
+ for (int c = 0; c < schema.ColumnCount; ++c)
+ Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
+
+ var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2);
+ Assert.True(metrics.LogLoss > 0);
+ Assert.True(metrics.TopKAccuracy > 0);
+ }
+
+ [Fact]
+ public void OnlineGradientDescent()
+ {
+ var env = new ConsoleEnvironment(seed: 0);
+ var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
+ var dataSource = new MultiFileSource(dataPath);
+
+ var ctx = new RegressionContext(env);
+
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
+ separator: ';', hasHeader: true);
+
+ LinearRegressionPredictor pred = null; // Captured by the onFit delegate below.
+
+ var loss = new SquaredLoss(); // Passed explicitly below so the lossFunction parameter is exercised.
+
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, score: ctx.Trainers.OnlineGradientDescent(r.label, r.features,
+ lossFunction: loss,
+ onFit: (p) => { pred = p; })));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred); // onFit must not have fired before Fit is called.
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred); // Fit must have invoked onFit with the trained model.
+ // 11 input features, so we ought to have 11 weights.
+ VBuffer weights = new VBuffer();
+ pred.GetFeatureWeights(ref weights);
+ Assert.Equal(11, weights.Length);
+
+ var data = model.Read(dataSource);
+
+ var metrics = ctx.Evaluate(data, r => r.label, r => r.score, new PoissonLoss()); // NOTE(review): evaluated with PoissonLoss although training uses squared loss - confirm this is intended.
+ // Run a sanity check against a few of the metrics.
+ Assert.InRange(metrics.L1, 0, double.PositiveInfinity);
+ Assert.InRange(metrics.L2, 0, double.PositiveInfinity);
+ Assert.InRange(metrics.Rms, 0, double.PositiveInfinity);
+ Assert.Equal(metrics.Rms * metrics.Rms, metrics.L2, 5); // Rms squared should equal L2 to 5 decimal places.
+ Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity);
+ }
+
[Fact]
public void KMeans()
{
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs
index 8d4f29591f..fb54ad52ce 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs
@@ -3,11 +3,8 @@
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Core.Data;
-using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Learners;
-using Microsoft.ML.Runtime.RunTests;
using Xunit;
namespace Microsoft.ML.Tests.TrainerEstimators
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs
index 2dea99e131..693e3ef2c7 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs
@@ -25,7 +25,7 @@ public void OnlineLinearWorkout()
var trainData = pipe.Fit(data).Transform(data).AsDynamic;
- IEstimator est = new OnlineGradientDescentTrainer(Env, new OnlineGradientDescentTrainer.Arguments());
+ IEstimator est = new OnlineGradientDescentTrainer(Env, "Label", "Features");
TestEstimatorCore(est, trainData);
est = new AveragedPerceptronTrainer(Env, new AveragedPerceptronTrainer.Arguments());