diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 3fe853aff5..344adbe39f 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -51,7 +51,7 @@ public Arguments() BasePredictors = new[] { ComponentFactoryUtils.CreateFromFunction( - env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments())) + env => new OnlineGradientDescentTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features)) }; } } diff --git a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs index 2a64ef8d31..d75ddddd74 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs @@ -2,13 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.KMeans; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// The trainer context extensions for the . 
@@ -35,16 +35,22 @@ public static (Vector score, Key predictedLabel) KMeans(this Cluste Action advancedSettings = null, Action onFit = null) { - var rec = new TrainerEstimatorReconciler.Clustering( - (env, featuresName, weightsName) => - { - var trainer = new KMeansPlusPlusTrainer(env, featuresName, clustersCount, weightsName, advancedSettings); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckParam(clustersCount > 1, nameof(clustersCount), "If provided, must be greater than 1."); + Contracts.CheckValueOrNull(onFit); + Contracts.CheckValueOrNull(advancedSettings); - if (onFit != null) - return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); - else - return trainer; - }, features, weights); + var rec = new TrainerEstimatorReconciler.Clustering( + (env, featuresName, weightsName) => + { + var trainer = new KMeansPlusPlusTrainer(env, featuresName, clustersCount, weightsName, advancedSettings); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + else + return trainer; + }, features, weights); return rec.Output; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 8722b79ea6..29e2e1e89f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1359,7 +1359,7 @@ public void Add(Double summand) public sealed class LinearClassificationTrainer : SdcaTrainerBase, TScalarPredictor> { public const string LoadNameValue = "SDCA"; - public const string UserNameValue = "Fast Linear (SA-SDCA)"; + internal const string UserNameValue = "Fast Linear (SA-SDCA)"; public sealed class Arguments : ArgumentsBase { diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs 
b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 4286f83379..864e0ba3ab 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -26,29 +26,29 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2", SortOrder = 50)] [TGUI(Label = "L2 Weight", Description = "Weight of L2 regularizer term", SuggestedSweeps = "0,0.1,1")] [TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)] - public float L2Weight = 1; + public float L2Weight = Defaults.L2Weight; [Argument(ArgumentType.AtMostOnce, HelpText = "L1 regularization weight", ShortName = "l1", SortOrder = 50)] [TGUI(Label = "L1 Weight", Description = "Weight of L1 regularizer term", SuggestedSweeps = "0,0.1,1")] [TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)] - public float L1Weight = 1; + public float L1Weight = Defaults.L1Weight; [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for optimization convergence. Lower = slower, more accurate", ShortName = "ot", SortOrder = 50)] [TGUI(Label = "Optimization Tolerance", Description = "Threshold for optimizer convergence", SuggestedSweeps = "1e-4,1e-7")] [TlcModule.SweepableDiscreteParamAttribute(new object[] { 1e-4f, 1e-7f })] - public float OptTol = 1e-7f; + public float OptTol = Defaults.OptTol; [Argument(ArgumentType.AtMostOnce, HelpText = "Memory size for L-BFGS. 
Lower=faster, less accurate", ShortName = "m", SortOrder = 50)] [TGUI(Description = "Memory size for L-BFGS", SuggestedSweeps = "5,20,50")] [TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[] { 5, 20, 50 })] - public int MemorySize = 20; + public int MemorySize = Defaults.MemorySize; [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum iterations.", ShortName = "maxiter")] [TGUI(Label = "Max Number of Iterations")] [TlcModule.SweepableLongParamAttribute("MaxIterations", 1, int.MaxValue)] - public int MaxIterations = int.MaxValue; + public int MaxIterations = Defaults.MaxIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Run SGD to initialize LR weights, converging to this tolerance", ShortName = "sgd")] @@ -90,7 +90,17 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight public bool DenseOptimizer = false; [Argument(ArgumentType.AtMostOnce, HelpText = "Enforce non-negative weights", ShortName = "nn", SortOrder = 90)] - public bool EnforceNonNegativity = false; + public bool EnforceNonNegativity = Defaults.EnforceNonNegativity; + + internal static class Defaults + { + internal const float L2Weight = 1; + internal const float L1Weight = 1; + internal const float OptTol = 1e-7f; + internal const int MemorySize = 20; + internal const int MaxIterations = int.MaxValue; + internal const bool EnforceNonNegativity = false; + } } private const string RegisterName = nameof(LbfgsTrainerBase); @@ -142,32 +152,48 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight public override TrainerInfo Info => _info; internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, - string weightColumn = null, Action advancedSettings = null) - : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn) + string weightColumn, Action advancedSettings, float l1Weight, + float l2Weight, + float optimizationTolerance, + int memorySize, + bool enforceNoNegativity) + 
: this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn, + l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) { } - internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn) + internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn, + float? l1Weight = null, + float? l2Weight = null, + float? optimizationTolerance = null, + int? memorySize = null, + bool? enforceNoNegativity = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) { Host.CheckValue(args, nameof(args)); Args = args; - Contracts.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, + Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); - Contracts.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative"); - Contracts.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative"); - Contracts.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive"); - Contracts.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive"); - Contracts.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive"); - Contracts.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative"); - Contracts.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative"); - - L2Weight = Args.L2Weight; - L1Weight = Args.L1Weight; - OptTol = Args.OptTol; - MemorySize = Args.MemorySize; + Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative"); + Host.CheckUserArg(Args.L1Weight >= 0, 
nameof(Args.L1Weight), "Must be non-negative"); + Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive"); + Host.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive"); + Host.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive"); + Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative"); + Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative"); + + Host.CheckParam(!(l2Weight < 0), nameof(l2Weight), "Must be non-negative, if provided."); + Host.CheckParam(!(l1Weight < 0), nameof(l1Weight), "Must be non-negative, if provided"); + Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided."); + Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided."); + + // Review: Warn about the overriding behavior + L2Weight = l2Weight ?? Args.L2Weight; + L1Weight = l1Weight ?? Args.L1Weight; + OptTol = optimizationTolerance ?? Args.OptTol; + MemorySize = memorySize ?? Args.MemorySize; MaxIterations = Args.MaxIterations; SgdInitializationTolerance = Args.SgdInitializationTolerance; Quiet = Args.Quiet; @@ -175,7 +201,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l UseThreads = Args.UseThreads; NumThreads = Args.NumThreads; DenseOptimizer = Args.DenseOptimizer; - EnforceNonNegativity = Args.EnforceNonNegativity; + EnforceNonNegativity = enforceNoNegativity ?? 
Args.EnforceNonNegativity; if (EnforceNonNegativity && ShowTrainingStats) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatics.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatics.cs new file mode 100644 index 0000000000..71227d2155 --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatics.cs @@ -0,0 +1,196 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.StaticPipe.Runtime; + +namespace Microsoft.ML.StaticPipe +{ + using Arguments = LogisticRegression.Arguments; + + /// + /// Binary Classification trainer estimators. + /// + public static partial class BinaryClassificationTrainers + { + /// + /// Predict a target using a linear binary classification model trained with the trainer. + /// + /// The binary classificaiton context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Enforce non-negative weights. + /// Weight of L1 regularization term. + /// Weight of L2 regularization term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to + /// be informed about what was learnt. + /// The predicted output. 
+ public static (Scalar score, Scalar probability, Scalar predictedLabel) LogisticRegressionBinaryClassifier(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + Scalar label, + Vector features, + Scalar weights = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action onFit = null) + { + LbfgsStaticUtils.ValidateParams(label, features, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity, onFit); + + var rec = new TrainerEstimatorReconciler.BinaryClassifier( + (env, labelName, featuresName, weightsName) => + { + var trainer = new LogisticRegression(env, featuresName, labelName, weightsName, + l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + + }, label, features, weights); + + return rec.Output; + } + } + + /// + /// Regression trainer estimators. + /// + public static partial class RegressionTrainers + { + + /// + /// Predict a target using a linear regression model trained with the trainer. + /// + /// The regression context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Enforce non-negative weights. + /// Weight of L1 regularization term. + /// Weight of L2 regularization term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. 
Note that this action cannot change the result in any way; it is only a way for the caller to + /// be informed about what was learnt. + /// The predicted output. + public static Scalar PoissonRegression(this RegressionContext.RegressionTrainers ctx, + Scalar label, + Vector features, + Scalar weights = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action onFit = null) + { + LbfgsStaticUtils.ValidateParams(label, features, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity, onFit); + + var rec = new TrainerEstimatorReconciler.Regression( + (env, labelName, featuresName, weightsName) => + { + var trainer = new PoissonRegression(env, featuresName, labelName, weightsName, + l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + + return trainer; + }, label, features, weights); + + return rec.Score; + } + } + + /// + /// MultiClass Classification trainer estimators. + /// + public static partial class MultiClassClassificationTrainers + { + + /// + /// Predict a target using a linear multiclass classification model trained with the trainer. + /// + /// The multiclass classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Enforce non-negative weights. + /// Weight of L1 regularization term. + /// Weight of L2 regularization term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. 
This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) + MultiClassLogisticRegression(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + Key label, + Vector features, + Scalar weights = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action onFit = null) + { + LbfgsStaticUtils.ValidateParams(label, features, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity, onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassifier( + (env, labelName, featuresName, weightsName) => + { + var trainer = new MulticlassLogisticRegression(env, featuresName, labelName, weightsName, + l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Output; + } + + } + + internal static class LbfgsStaticUtils{ + + internal static void ValidateParams(PipelineColumn label, + Vector features, + Scalar weights = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Delegate onFit = null) + { + Contracts.CheckValue(label, nameof(label)); + 
Contracts.CheckValue(features, nameof(features)); + Contracts.CheckParam(l2Weight >= 0, nameof(l2Weight), "Must be non-negative"); + Contracts.CheckParam(l1Weight >= 0, nameof(l1Weight), "Must be non-negative"); + Contracts.CheckParam(optimizationTolerance > 0, nameof(optimizationTolerance), "Must be positive"); + Contracts.CheckParam(memorySize > 0, nameof(memorySize), "Must be positive"); + Contracts.CheckValueOrNull(onFit); + } + } +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index 0914c63c33..87c1c1bde9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -54,10 +54,24 @@ public sealed class Arguments : ArgumentsBase /// The name of the label column. /// The name of the feature column. /// The name for the example weight column. + /// Enforce non-negative weights. + /// Weight of L1 regularizer term. + /// Weight of L2 regularizer term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. /// A delegate to apply all the advanced arguments to the algorithm. 
- public LogisticRegression(IHostEnvironment env, string featureColumn, string labelColumn, - string weightColumn = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), weightColumn, advancedSettings) + public LogisticRegression(IHostEnvironment env, + string featureColumn, + string labelColumn, + string weightColumn = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action advancedSettings = null) + : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn, advancedSettings, + l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -70,7 +84,7 @@ public LogisticRegression(IHostEnvironment env, string featureColumn, string lab /// Initializes a new instance of /// internal LogisticRegression(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { _posWeight = 0; ShowTrainingStats = Args.ShowTrainingStats; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 52d1688638..40e3090c74 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -76,10 +76,22 @@ public sealed class Arguments : ArgumentsBase /// The name of the label column. /// The name of the feature column. 
/// The name for the example weight column. + /// Enforce non-negative weights. + /// Weight of L1 regularizer term. + /// Weight of L2 regularizer term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. /// A delegate to apply all the advanced arguments to the algorithm. public MulticlassLogisticRegression(IHostEnvironment env, string featureColumn, string labelColumn, - string weightColumn = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), weightColumn, advancedSettings) + string weightColumn = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action advancedSettings = null) + : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), weightColumn, advancedSettings, + l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index 70ee279a1c..5871c932dd 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -22,12 +22,12 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)] [TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")] [TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })] - public Float LearningRate = 1; + public Float LearningRate = 
AveragedDefaultArgs.LearningRate; [Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)] [TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")] [TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })] - public bool DecreaseLearningRate = false; + public bool DecreaseLearningRate = AveragedDefaultArgs.DecreaseLearningRate; [Argument(ArgumentType.AtMostOnce, HelpText = "Number of examples after which weights will be reset to the current average", ShortName = "numreset")] public long? ResetWeightsAfterXExamples = null; @@ -38,7 +38,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)] [TGUI(Label = "L2 Regularization Weight")] [TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)] - public Float L2RegularizerWeight = 0; + public Float L2RegularizerWeight = AveragedDefaultArgs.L2RegularizerWeight; [Argument(ArgumentType.AtMostOnce, HelpText = "Extra weight given to more recent updates", ShortName = "rg")] public Float RecencyGain = 0; @@ -51,6 +51,13 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments [Argument(ArgumentType.AtMostOnce, HelpText = "The inexactness tolerance for averaging", ShortName = "avgtol")] public Float AveragedTolerance = (Float)1e-2; + + internal class AveragedDefaultArgs : OnlineDefaultArgs + { + internal const Float LearningRate = 1; + internal const bool DecreaseLearningRate = false; + internal const Float L2RegularizerWeight = 0; + } } public abstract class AveragedLinearTrainer : OnlineLinearTrainer diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index f5a1dc44f5..8cbb1978a5 100644 --- 
a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -57,22 +57,21 @@ public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) { _args = args; LossFunction = _args.LossFunction.CreateComponent(env); - - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), - new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) - }; } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; protected override bool NeedCalibration => true; - private readonly SchemaShape.Column[] _outputColumns; - - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + // REVIEW AP is currently not calibrating. Add the probability column after fixing the behavior. 
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } protected override void CheckLabel(RoleMappedData data) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index ad8442433d..d312c6b52b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -34,8 +34,8 @@ namespace Microsoft.ML.Runtime.Learners /// public sealed class LinearSvm : OnlineLinearTrainer, LinearBinaryPredictor> { - public const string LoadNameValue = "LinearSVM"; - public const string ShortName = "svm"; + internal const string LoadNameValue = "LinearSVM"; + internal const string ShortName = "svm"; internal const string UserNameValue = "SVM (Pegasos-Linear)"; internal const string Summary = "The idea behind support vector machines, is to map the instances into a high dimensional space " + "in which instances of the two classes are linearly separable, i.e., there exists a hyperplane such that all the positive examples are on one side of it, " @@ -92,8 +92,13 @@ public LinearSvm(IHostEnvironment env, Arguments args) Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive); Args = args; + } + + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - _outputColumns = new[] + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] { new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), new SchemaShape.Column(DefaultColumnNames.Probability, 
SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), @@ -101,12 +106,6 @@ public LinearSvm(IHostEnvironment env, Arguments args) }; } - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - - private readonly SchemaShape.Column[] _outputColumns; - - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; - protected override void CheckLabel(RoleMappedData data) { Contracts.AssertValue(data); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index 768bb1da75..00dfb59cbd 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; @@ -46,27 +44,69 @@ public sealed class Arguments : AveragedLinearArguments /// public Arguments() { - LearningRate = (Float)0.1; - DecreaseLearningRate = true; + LearningRate = OgdDefaultArgs.LearningRate; + DecreaseLearningRate = OgdDefaultArgs.DecreaseLearningRate; + } + + internal class OgdDefaultArgs : AveragedDefaultArgs + { + internal new const float LearningRate = 0.1f; + internal new const bool DecreaseLearningRate = true; } } - public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args) - : base(args, env, UserNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) + /// + /// Trains a new . + /// + /// The pricate instance of . + /// Name of the label column. + /// Name of the feature column. + /// The learning Rate. + /// Decrease learning rate as iterations progress. + /// L2 Regularization Weight. 
+ /// Number of training iterations through the data. + /// The name of the weights column. + /// The custom loss functions. Defaults to if not provided. + public OnlineGradientDescentTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + float learningRate = Arguments.OgdDefaultArgs.LearningRate, + bool decreaseLearningRate = Arguments.OgdDefaultArgs.DecreaseLearningRate, + float l2RegularizerWeight = Arguments.OgdDefaultArgs.L2RegularizerWeight, + int numIterations = Arguments.OgdDefaultArgs.NumIterations, + string weightsColumn = null, + IRegressionLoss lossFunction = null) + : base(new Arguments + { + LearningRate = learningRate, + DecreaseLearningRate = decreaseLearningRate, + L2RegularizerWeight = l2RegularizerWeight, + NumIterations = numIterations, + LabelColumn = labelColumn, + FeatureColumn = featureColumn, + InitialWeights = weightsColumn + + }, env, UserNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn)) + { + LossFunction = lossFunction ?? new SquaredLoss(); + } + + internal OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args) + : base(args, env, UserNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { LossFunction = args.LossFunction.CreateComponent(env); + } + + public override PredictionKind PredictionKind => PredictionKind.Regression; - _outputColumns = new[] + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] { new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) }; } - public override PredictionKind PredictionKind => PredictionKind.Regression; - - private readonly SchemaShape.Column[] _outputColumns; - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; - protected override void CheckLabel(RoleMappedData data) { data.CheckRegressionLabel(); @@ -75,8 +115,8 @@ protected override 
void CheckLabel(RoleMappedData data) protected override LinearRegressionPredictor CreatePredictor() { Contracts.Assert(WeightsScale == 1); - VBuffer weights = default(VBuffer); - Float bias; + VBuffer weights = default(VBuffer); + float bias; if (!Args.Averaged) { @@ -86,8 +126,8 @@ protected override LinearRegressionPredictor CreatePredictor() else { TotalWeights.CopyTo(ref weights); - VectorUtils.ScaleBy(ref weights, 1 / (Float)NumWeightUpdates); - bias = TotalBias / (Float)NumWeightUpdates; + VectorUtils.ScaleBy(ref weights, 1 / (float)NumWeightUpdates); + bias = TotalBias / (float)NumWeightUpdates; } return new LinearRegressionPredictor(Host, ref weights, bias); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs new file mode 100644 index 0000000000..de286a59c8 --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs @@ -0,0 +1,174 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.StaticPipe.Runtime; +using System; + +namespace Microsoft.ML.StaticPipe +{ + /// + /// Binary Classification trainer estimators. + /// + public static partial class BinaryClassificationTrainers + { + /// + /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer, and a custom loss. + /// + /// The binary classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The custom loss. + /// The optional example weights. + /// The learning Rate. + /// Decrease learning rate as iterations progress. + /// L2 regularization weight. 
+ /// Number of training iterations through the data. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted binary classification score (which will range + /// from negative to positive infinity), and the predicted label. + /// . + public static (Scalar score, Scalar predictedLabel) AveragedPerceptron( + this BinaryClassificationContext.BinaryClassificationTrainers ctx, + IClassificationLoss lossFunction, + Scalar label, Vector features, Scalar weights = null, + float learningRate = AveragedLinearArguments.AveragedDefaultArgs.LearningRate, + bool decreaseLearningRate = AveragedLinearArguments.AveragedDefaultArgs.DecreaseLearningRate, + float l2RegularizerWeight = AveragedLinearArguments.AveragedDefaultArgs.L2RegularizerWeight, + int numIterations = AveragedLinearArguments.AveragedDefaultArgs.NumIterations, + Action onFit = null + ) + { + OnlineLinearStaticUtils.CheckUserParams(label, features, weights, learningRate, l2RegularizerWeight, numIterations, onFit); + + bool hasProbs = lossFunction is HingeLoss; + + var args = new AveragedPerceptronTrainer.Arguments() + { + LearningRate = learningRate, + DecreaseLearningRate = decreaseLearningRate, + L2RegularizerWeight = l2RegularizerWeight, + NumIterations = numIterations + }; + + if (lossFunction != null) + args.LossFunction = new TrivialClassificationLossFactory(lossFunction); + + var rec = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration( + (env, labelName, featuresName, weightsName) => + { + args.FeatureColumn = featuresName; + args.LabelColumn = labelName; + args.InitialWeights = weightsName; + + var trainer = new 
AveragedPerceptronTrainer(env, args); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + else + return trainer; + + }, label, features, weights, hasProbs); + + return rec.Output; + } + + private sealed class TrivialClassificationLossFactory : ISupportClassificationLossFactory + { + private readonly IClassificationLoss _loss; + + public TrivialClassificationLossFactory(IClassificationLoss loss) + { + _loss = loss; + } + + public IClassificationLoss CreateComponent(IHostEnvironment env) + { + return _loss; + } + } + } + + /// + /// Regression trainer estimators. + /// + public static partial class RegressionTrainers + { + /// + /// Predict a target using a linear regression model trained with the trainer. + /// + /// The regression context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// The custom loss. Defaults to if not provided. + /// The learning rate. + /// Decrease learning rate as iterations progress. + /// L2 regularization weight. + /// Number of training iterations through the data. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted regression score (which will range + /// from negative to positive infinity). + /// . + /// The predicted output. 
+ public static Scalar OnlineGradientDescent(this RegressionContext.RegressionTrainers ctx, + Scalar label, + Vector features, + Scalar weights = null, + IRegressionLoss lossFunction = null, + float learningRate = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.LearningRate, + bool decreaseLearningRate = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.DecreaseLearningRate, + float l2RegularizerWeight = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.L2RegularizerWeight, + int numIterations = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.NumIterations, + Action onFit = null) + { + OnlineLinearStaticUtils.CheckUserParams(label, features, weights, learningRate, l2RegularizerWeight, numIterations, onFit); + Contracts.CheckValueOrNull(lossFunction); + + var rec = new TrainerEstimatorReconciler.Regression( + (env, labelName, featuresName, weightsName) => + { + var trainer = new OnlineGradientDescentTrainer(env, labelName, featuresName, learningRate, + decreaseLearningRate, l2RegularizerWeight, numIterations, weightsName, lossFunction); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + + return trainer; + }, label, features, weights); + + return rec.Score; + } + } + + internal static class OnlineLinearStaticUtils{ + + internal static void CheckUserParams(PipelineColumn label, + PipelineColumn features, + PipelineColumn weights, + float learningRate, + float l2RegularizerWeight, + int numIterations, + Delegate onFit) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive."); + Contracts.CheckParam(0 <= l2RegularizerWeight && l2RegularizerWeight < 0.5, nameof(l2RegularizerWeight), "must be in range [0, 0.5)"); + Contracts.CheckParam(numIterations > 0, nameof(numIterations), "Must be positive, if specified."); + Contracts.CheckValueOrNull(onFit); 
+ } + } +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 15bd5da290..41bf4a421a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -24,7 +24,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter", SortOrder = 50)] [TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")] [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)] - public int NumIterations = 1; + public int NumIterations = OnlineDefaultArgs.NumIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Initial Weights and bias, comma-separated", ShortName = "initweights")] [TGUI(NoSweep = true)] @@ -41,6 +41,11 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel [Argument(ArgumentType.AtMostOnce, HelpText = "Size of cache when trained in Scope", ShortName = "cache")] public int StreamingCacheSize = 1000000; + + internal class OnlineDefaultArgs + { + internal const int NumIterations = 1; + } } public abstract class OnlineLinearTrainer : TrainerEstimatorBase @@ -78,7 +83,7 @@ public abstract class OnlineLinearTrainer : TrainerEstimat protected virtual bool NeedCalibration => false; protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.InitialWeights)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights)) { Contracts.CheckValue(args, 
nameof(args)); Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive); @@ -156,18 +161,6 @@ protected override TModel TrainModelCore(TrainContext context) protected abstract void CheckLabel(RoleMappedData data); - private static SchemaShape.Column MakeWeightColumn(string weightColumn) - { - if (weightColumn == null) - return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); - } - - private static SchemaShape.Column MakeFeatureColumn(string featureColumn) - { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); - } - protected virtual void TrainCore(IChannel ch, RoleMappedData data) { bool shuffle = Args.Shuffle; diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs index ed400aaeac..2e703bf4a0 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs @@ -46,10 +46,22 @@ public sealed class Arguments : ArgumentsBase /// The name of the label column. /// The name of the feature column. /// The name for the example weight column. + /// Enforce non-negative weights. + /// Weight of L1 regularizer term. + /// Weight of L2 regularizer term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. /// A delegate to apply all the advanced arguments to the algorithm. 
public PoissonRegression(IHostEnvironment env, string featureColumn, string labelColumn, - string weightColumn = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), weightColumn, advancedSettings) + string weightColumn = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action advancedSettings = null) + : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), weightColumn, advancedSettings, + l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index 113092ed26..6c4995bfa7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -142,7 +142,7 @@ public static (Scalar score, Scalar probability, Scalar pred /// The binary classification context trainer object. /// The label, or dependent variable. /// The features, or independent variables. - /// /// The custom loss. + /// The custom loss. /// The optional example weights. /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. @@ -212,7 +212,7 @@ public static (Scalar score, Scalar predictedLabel) Sdca( /// The multiclass classification context trainer object. /// The label, or dependent variable. /// The features, or independent variables. - /// /// The custom loss. + /// The custom loss. /// The optional example weights. /// The L2 regularization hyperparameter. 
/// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 8ccea87539..867466d21a 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -12,7 +12,7 @@ using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.RunTests; -using Microsoft.ML.Runtime.Training; +using Microsoft.ML.StaticPipe; using Microsoft.ML.Trainers; using System; using System.Linq; @@ -145,7 +145,7 @@ public void SdcaBinaryClassification() } [Fact] - public void SdcaBinaryClassificationNoClaibration() + public void SdcaBinaryClassificationNoCalibration() { var env = new ConsoleEnvironment(seed: 0); var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); @@ -187,6 +187,42 @@ public void SdcaBinaryClassificationNoClaibration() Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}"); } + [Fact] + public void AveragePerceptronNoCalibration() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); + var dataSource = new MultiFileSource(dataPath); + var ctx = new BinaryClassificationContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9))); + + LinearBinaryPredictor pred = null; + + var loss = new HingeLoss(new HingeLoss.Arguments() { Margin = 1 }); + + var est = reader.MakeNewEstimator() + .Append(r => (r.label, preds: ctx.Trainers.AveragedPerceptron(loss, r.label, r.features, + numIterations: 2, onFit: p => pred = p))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + // 9 input features, so we ought to have 9 weights. 
+ Assert.Equal(9, pred.Weights2.Count); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds); + // Run a sanity check against a few of the metrics. + Assert.InRange(metrics.Accuracy, 0, 1); + Assert.InRange(metrics.Auc, 0, 1); + Assert.InRange(metrics.Auprc, 0, 1); + } + [Fact] public void FfmBinaryClassification() { @@ -453,6 +489,172 @@ public void LightGbmRegression() Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity); } + [Fact] + public void PoissonRegression() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new RegressionContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), + separator: ';', hasHeader: true); + + PoissonRegressionPredictor pred = null; + + var est = reader.MakeNewEstimator() + .Append(r => (r.label, score: ctx.Trainers.PoissonRegression(r.label, r.features, + l1Weight: 2, + enoforceNoNegativity: true, + onFit: (p) => { pred = p; }))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + // 11 input features, so we ought to have 11 weights. + VBuffer weights = new VBuffer(); + pred.GetFeatureWeights(ref weights); + Assert.Equal(11, weights.Length); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.score, new PoissonLoss()); + // Run a sanity check against a few of the metrics. 
+ Assert.InRange(metrics.L1, 0, double.PositiveInfinity); + Assert.InRange(metrics.L2, 0, double.PositiveInfinity); + Assert.InRange(metrics.Rms, 0, double.PositiveInfinity); + Assert.Equal(metrics.Rms * metrics.Rms, metrics.L2, 5); + Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity); + } + + [Fact] + public void LogisticRegressionBinaryClassification() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); + var dataSource = new MultiFileSource(dataPath); + var ctx = new BinaryClassificationContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9))); + + IPredictorWithFeatureWeights pred = null; + + var est = reader.MakeNewEstimator() + .Append(r => (r.label, preds: ctx.Trainers.LogisticRegressionBinaryClassifier(r.label, r.features, + l1Weight: 10, + onFit: (p) => { pred = p; }))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + + // 9 input features, so we ought to have 9 weights. + VBuffer weights = new VBuffer(); + pred.GetFeatureWeights(ref weights); + Assert.Equal(9, weights.Length); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds); + // Run a sanity check against a few of the metrics. 
+ Assert.InRange(metrics.Accuracy, 0, 1); + Assert.InRange(metrics.Auc, 0, 1); + Assert.InRange(metrics.Auprc, 0, 1); + } + + [Fact] + public void MulticlassLogisticRegression() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new MulticlassClassificationContext(env); + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); + + MulticlassLogisticRegressionPredictor pred = null; + + // With a custom loss function we no longer get calibrated predictions. + var est = reader.MakeNewEstimator() + .Append(r => (label: r.label.ToKey(), r.features)) + .Append(r => (r.label, preds: ctx.Trainers.MultiClassLogisticRegression( + r.label, + r.features, onFit: p => pred = p))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + VBuffer[] weights = default; + pred.GetWeights(ref weights, out int n); + Assert.True(n == 3 && n == weights.Length); + foreach (var w in weights) + Assert.True(w.Length == 4); + + var data = model.Read(dataSource); + + // Just output some data on the schema for fun. 
+ var schema = data.AsDynamic.Schema; + for (int c = 0; c < schema.ColumnCount; ++c) + Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}"); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2); + Assert.True(metrics.LogLoss > 0); + Assert.True(metrics.TopKAccuracy > 0); + } + + [Fact] + public void OnlineGradientDescent() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new RegressionContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), + separator: ';', hasHeader: true); + + LinearRegressionPredictor pred = null; + + var loss = new SquaredLoss(); + + var est = reader.MakeNewEstimator() + .Append(r => (r.label, score: ctx.Trainers.OnlineGradientDescent(r.label, r.features, + // lossFunction:loss, + onFit: (p) => { pred = p; }))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + // 11 input features, so we ought to have 11 weights. + VBuffer weights = new VBuffer(); + pred.GetFeatureWeights(ref weights); + Assert.Equal(11, weights.Length); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.score, new PoissonLoss()); + // Run a sanity check against a few of the metrics. 
+ Assert.InRange(metrics.L1, 0, double.PositiveInfinity); + Assert.InRange(metrics.L2, 0, double.PositiveInfinity); + Assert.InRange(metrics.Rms, 0, double.PositiveInfinity); + Assert.Equal(metrics.Rms * metrics.Rms, metrics.L2, 5); + Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity); + } + [Fact] public void KMeans() { diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs index 8d4f29591f..fb54ad52ce 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs @@ -3,11 +3,8 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.TrainerEstimators diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs index 2dea99e131..693e3ef2c7 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs @@ -25,7 +25,7 @@ public void OnlineLinearWorkout() var trainData = pipe.Fit(data).Transform(data).AsDynamic; - IEstimator est = new OnlineGradientDescentTrainer(Env, new OnlineGradientDescentTrainer.Arguments()); + IEstimator est = new OnlineGradientDescentTrainer(Env, "Label", "Features"); TestEstimatorCore(est, trainData); est = new AveragedPerceptronTrainer(Env, new AveragedPerceptronTrainer.Arguments());