diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs new file mode 100644 index 0000000000..767d398dc6 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs @@ -0,0 +1,44 @@ +using Microsoft.ML; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class AveragedPerceptron + { + // In this example we will use the adult income dataset. The goal is to predict + // if a person's income is above $50K or not, based on different pieces of information about that person. + // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Download and featurize the dataset. + var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext); + + // Leave out 10% of data for testing. + var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); + + // Create data training pipeline. + var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numIterations: 10); + + // Fit this pipeline to the training data. + var model = pipeline.Fit(trainTestData.TrainSet); + + // Evaluate how the model is doing on the test data. 
+ var dataWithPredictions = model.Transform(trainTestData.TestSet); + var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Accuracy: 0.86 + // AUC: 0.91 + // F1 Score: 0.68 + // Negative Precision: 0.90 + // Negative Recall: 0.91 + // Positive Precision: 0.70 + // Positive Recall: 0.66 + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs new file mode 100644 index 0000000000..ee568bff92 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs @@ -0,0 +1,55 @@ +using Microsoft.ML; +using Microsoft.ML.Trainers.Online; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class AveragedPerceptronWithOptions + { + // In this example we will use the adult income dataset. The goal is to predict + // if a person's income is above $50K or not, based on different pieces of information about that person. + // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Download and featurize the dataset. + var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext); + + // Leave out 10% of data for testing. + var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); + + // Define the trainer options. 
+ var options = new AveragedPerceptronTrainer.Options() + { + LossFunction = new SmoothedHingeLoss.Arguments(), + LearningRate = 0.1f, + DoLazyUpdates = false, + RecencyGain = 0.1f, + NumberOfIterations = 10 + }; + + // Create data training pipeline. + var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options); + + // Fit this pipeline to the training data. + var model = pipeline.Fit(trainTestData.TrainSet); + + // Evaluate how the model is doing on the test data. + var dataWithPredictions = model.Transform(trainTestData.TestSet); + var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Accuracy: 0.86 + // AUC: 0.90 + // F1 Score: 0.66 + // Negative Precision: 0.89 + // Negative Recall: 0.93 + // Positive Precision: 0.72 + // Positive Recall: 0.61 + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs new file mode 100644 index 0000000000..83fafd8658 --- /dev/null +++ b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.ML.Data; + +namespace Microsoft.ML.SamplesUtils +{ + /// + /// Utilities for creating console outputs in samples' code. + /// + public static class ConsoleUtils + { + /// + /// Pretty-print BinaryClassificationMetrics objects. + /// + /// Binary classification metrics. 
+ public static void PrintMetrics(BinaryClassificationMetrics metrics) + { + Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}"); + Console.WriteLine($"AUC: {metrics.Auc:F2}"); + Console.WriteLine($"F1 Score: {metrics.F1Score:F2}"); + Console.WriteLine($"Negative Precision: {metrics.NegativePrecision:F2}"); + Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}"); + Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}"); + Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}"); + } + } +} diff --git a/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj b/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj index e4d6c5d504..0bdb047d42 100644 --- a/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj +++ b/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj @@ -6,7 +6,9 @@ + + diff --git a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs index 14f0beb094..203bd6e6bd 100644 --- a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs +++ b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs @@ -7,6 +7,7 @@ using System.IO; using System.Net; using Microsoft.Data.DataView; +using Microsoft.ML; using Microsoft.ML.Data; namespace Microsoft.ML.SamplesUtils @@ -86,6 +87,65 @@ public static string DownloadSentimentDataset() public static string DownloadAdultDataset() => Download("https://raw.githubusercontent.com/dotnet/machinelearning/244a8c2ac832657af282aa312d568211698790aa/test/data/adult.train", "adult.txt"); + /// + /// Downloads the Adult UCI dataset and featurizes it to be suitable for classification tasks. + /// + /// used for data loading and processing. + /// Featurized dataset. + /// + /// For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult. 
+ /// + public static IDataView LoadFeaturizedAdultDataset(MLContext mlContext) + { + // Download the file + string dataFile = DownloadAdultDataset(); + + // Define the columns to read + var reader = mlContext.Data.CreateTextLoader( + columns: new[] + { + new TextLoader.Column("age", DataKind.R4, 0), + new TextLoader.Column("workclass", DataKind.TX, 1), + new TextLoader.Column("fnlwgt", DataKind.R4, 2), + new TextLoader.Column("education", DataKind.TX, 3), + new TextLoader.Column("education-num", DataKind.R4, 4), + new TextLoader.Column("marital-status", DataKind.TX, 5), + new TextLoader.Column("occupation", DataKind.TX, 6), + new TextLoader.Column("relationship", DataKind.TX, 7), + new TextLoader.Column("ethnicity", DataKind.TX, 8), + new TextLoader.Column("sex", DataKind.TX, 9), + new TextLoader.Column("capital-gain", DataKind.R4, 10), + new TextLoader.Column("capital-loss", DataKind.R4, 11), + new TextLoader.Column("hours-per-week", DataKind.R4, 12), + new TextLoader.Column("native-country", DataKind.R4, 13), // NOTE(review): one-hot encoded below like the DataKind.TX columns, yet loaded as DataKind.R4 - confirm this should not be DataKind.TX + new TextLoader.Column("IsOver50K", DataKind.BL, 14), + }, + separatorChar: ',', + hasHeader: true + ); + + // Create data featurizing pipeline + var pipeline = mlContext.Transforms.CopyColumns("Label", "IsOver50K") + // Convert categorical features to one-hot vectors + .Append(mlContext.Transforms.Categorical.OneHotEncoding("workclass")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("education")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("marital-status")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("occupation")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("relationship")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("ethnicity")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("native-country")) + // Combine all features into one feature vector + .Append(mlContext.Transforms.Concatenate("Features", "workclass", "education", "marital-status", + "occupation", "relationship", 
"ethnicity", "native-country", "age", "education-num", + "capital-gain", "capital-loss", "hours-per-week")) + // Min-max normalize all the features + .Append(mlContext.Transforms.Normalize("Features")); + + var data = reader.Read(dataFile); + var featurizedData = pipeline.Fit(data).Transform(data); + return featurizedData; + } + /// /// Downloads the breast cancer dataset from the ML.NET repo. /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index 457fa7e9ae..519ec4005f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -14,40 +14,94 @@ namespace Microsoft.ML.Trainers.Online { + /// + /// Arguments class for averaged linear trainers. + /// public abstract class AveragedLinearArguments : OnlineLinearArguments { + /// + /// Learning rate. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)] [TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")] [TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })] public float LearningRate = AveragedDefaultArgs.LearningRate; + /// + /// Determine whether to decrease the or not. + /// + /// + /// to decrease the as iterations progress; otherwise, . + /// Default is . The learning rate will be reduced with every weight update proportional to the square root of the number of updates. 
+ /// [Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)] [TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")] [TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })] public bool DecreaseLearningRate = AveragedDefaultArgs.DecreaseLearningRate; + /// + /// Number of examples after which weights will be reset to the current average. + /// + /// + /// Default is , which disables this feature. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Number of examples after which weights will be reset to the current average", ShortName = "numreset")] public long? ResetWeightsAfterXExamples = null; + /// + /// Determines when to update averaged weights. + /// + /// + /// to update averaged weights only when loss is nonzero. + /// to update averaged weights on every example. + /// Default is . + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy")] public bool DoLazyUpdates = true; + /// + /// L2 weight for regularization. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)] [TGUI(Label = "L2 Regularization Weight")] [TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)] public float L2RegularizerWeight = AveragedDefaultArgs.L2RegularizerWeight; + /// + /// Extra weight given to more recent updates. + /// + /// + /// Default is 0, i.e. no extra gain. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Extra weight given to more recent updates", ShortName = "rg")] public float RecencyGain = 0; + /// + /// Determines whether is multiplicative or additive. + /// + /// + /// means is multiplicative. + /// means is additive. + /// Default is . + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. 
additive)", ShortName = "rgm")] public bool RecencyGainMulti = false; + /// + /// Determines whether to do averaging or not. + /// + /// + /// to do averaging; otherwise, . + /// Default is . + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Do averaging?", ShortName = "avg")] public bool Averaged = true; + /// + /// The inexactness tolerance for averaging. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The inexactness tolerance for averaging", ShortName = "avgtol")] - public float AveragedTolerance = (float)1e-2; + internal float AveragedTolerance = (float)1e-2; [BestFriend] internal class AveragedDefaultArgs : OnlineDefaultArgs diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index a165008327..5436427123 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -23,12 +23,27 @@ namespace Microsoft.ML.Trainers.Online { - // This is an averaged perceptron classifier. - // Configurable subcomponents: - // - Loss function. By default, hinge loss (aka max-margin avgd perceptron) - // - Feature normalization. By default, rescaling between min and max values for every feature - // - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration. - /// + /// + /// The for the averaged perceptron trainer. + /// + /// + /// The perceptron is a classification algorithm that makes its predictions by finding a separating hyperplane. + /// For instance, with feature values f0, f1,..., f_D-1, the prediction is given by determining what side of the hyperplane the point falls into. + /// That is the same as the sign of sigma[0, D-1] (w_i * f_i), where w_0, w_1,..., w_D-1 are the weights computed by the algorithm. 
+ /// + /// The perceptron is an online algorithm, which means it processes the instances in the training set one at a time. + /// It starts with a set of initial weights (zero, random, or initialized from a previous learner). Then, for each example in the training set, the weighted sum of the features (sigma[0, D-1] (w_i * f_i)) is computed. + /// If this value has the same sign as the label of the current example, the weights remain the same. If they have opposite signs, + /// the weights vector is updated by either adding or subtracting (if the label is positive or negative, respectively) the feature vector of the current example, + /// multiplied by a factor 0 < a <= 1, called the learning rate. In a generalization of this algorithm, the weights are updated by adding the feature vector multiplied by the learning rate, + /// and by the gradient of some loss function (in the specific case described above, the loss is hinge-loss, whose gradient is 1 when it is non-zero). + /// + /// In Averaged Perceptron (aka voted-perceptron), for each iteration, i.e. pass through the training data, a weight vector is calculated as explained above. + /// The final prediction is then calculated by averaging the weighted sum from each weight vector and looking at the sign of the result. + /// + /// For more information see Wikipedia entry for Perceptron + /// or Large Margin Classification Using the Perceptron Algorithm + /// public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer, LinearBinaryModelParameters> { internal const string LoadNameValue = "AveragedPerceptron"; @@ -38,14 +53,26 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer + /// Options for the averaged perceptron trainer. + /// public sealed class Options : AveragedLinearArguments { + /// + /// A custom loss. 
+ /// [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments(); + /// + /// The calibrator for producing probabilities. Default is exponential (aka Platt) calibration. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] internal ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); + /// + /// The maximum number of examples to use when training the calibrator. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] internal int MaxCalibrationExamples = 1000000; @@ -98,9 +125,9 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options) /// The name of the label column. /// The name of the feature column. /// The learning rate. - /// Wheather to decrease learning rate as iterations progress. + /// Whether to decrease learning rate as iterations progress. /// L2 Regularization Weight. - /// The number of training iteraitons. + /// The number of training iterations. internal AveragedPerceptronTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -116,7 +143,7 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, LearningRate = learningRate, DecreaseLearningRate = decreaseLearningRate, L2RegularizerWeight = l2RegularizerWeight, - NumIterations = numIterations, + NumberOfIterations = numIterations, LossFunction = new TrivialFactory(lossFunction ?? 
new HingeLoss()) }) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index f2066fe69f..74dc162515 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -244,7 +244,7 @@ internal LinearSvmTrainer(IHostEnvironment env, LabelColumn = labelColumn, FeatureColumn = featureColumn, WeightColumn = weightColumn, - NumIterations = numIterations, + NumberOfIterations = numIterations, }) { } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index e1381864fc..4b1e32bb19 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -111,7 +111,7 @@ internal OnlineGradientDescentTrainer(IHostEnvironment env, LearningRate = learningRate, DecreaseLearningRate = decreaseLearningRate, L2RegularizerWeight = l2RegularizerWeight, - NumIterations = numIterations, + NumberOfIterations = numIterations, LabelColumn = labelColumn, FeatureColumn = featureColumn, LossFunction = new TrivialFactory(lossFunction ?? new SquaredLoss()) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 09a9e15b1f..d8ff4b58be 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -16,29 +16,49 @@ namespace Microsoft.ML.Trainers.Online { + /// + /// Arguments class for online linear trainers. 
+ /// public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel { - [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter", SortOrder = 50)] + /// + /// Number of passes through the training dataset. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter, numIterations", SortOrder = 50)] [TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")] [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)] - public int NumIterations = OnlineDefaultArgs.NumIterations; + public int NumberOfIterations = OnlineDefaultArgs.NumIterations; + /// + /// Initial weights and bias, comma-separated. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Initial Weights and bias, comma-separated", ShortName = "initweights")] [TGUI(NoSweep = true)] internal string InitialWeights; - [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts", SortOrder = 140)] + /// + /// Initial weights and bias scale. + /// + /// + /// This property is only used if the provided value is positive and is not specified. + /// The weights and bias will be randomly selected from InitialWeightsDiameter * [-0.5,0.5] interval with uniform distribution. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts, initWtsDiameter", SortOrder = 140)] [TGUI(Label = "Initial Weights Scale", SuggestedSweeps = "0,0.1,0.5,1")] [TlcModule.SweepableFloatParamAttribute("InitWtsDiameter", 0.0f, 1.0f, numSteps: 5)] - public float InitWtsDiameter = 0; + public float InitialWeightsDiameter = 0; + /// + /// Determines whether to shuffle data for each training iteration. + /// + /// + /// to shuffle data for each training iteration; otherwise, . + /// Default is . 
+ /// [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to shuffle for each training iteration", ShortName = "shuf")] [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[] { false, true })] public bool Shuffle = true; - [Argument(ArgumentType.AtMostOnce, HelpText = "Size of cache when trained in Scope", ShortName = "cache")] - public int StreamingCacheSize = 1000000; - [BestFriend] internal class OnlineDefaultArgs { @@ -133,13 +153,13 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre Weights = new VBuffer(numFeatures, weightValues); Bias = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture); } - else if (parent.Args.InitWtsDiameter > 0) + else if (parent.Args.InitialWeightsDiameter > 0) { var weightValues = new float[numFeatures]; for (int i = 0; i < numFeatures; i++) - weightValues[i] = parent.Args.InitWtsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); + weightValues[i] = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); Weights = new VBuffer(numFeatures, weightValues); - Bias = parent.Args.InitWtsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); + Bias = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); } else if (numFeatures <= 1000) Weights = VBufferUtils.CreateDense(numFeatures); @@ -237,9 +257,8 @@ private protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironme : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights)) { Contracts.CheckValue(args, nameof(args)); - Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive); - Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative); - Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive); + 
Contracts.CheckUserArg(args.NumberOfIterations > 0, nameof(args.NumberOfIterations), UserErrorPositive); + Contracts.CheckUserArg(args.InitialWeightsDiameter >= 0, nameof(args.InitialWeightsDiameter), UserErrorNonNegative); Args = args; Name = name; @@ -304,7 +323,7 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state) var cursorFactory = new FloatLabelCursor.Factory(data, cursorOpt); long numBad = 0; - while (state.Iteration < Args.NumIterations) + while (state.Iteration < Args.NumberOfIterations) { state.BeginIteration(ch); @@ -322,7 +341,7 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state) { ch.Warning( "Skipped {0} instances with missing features during training (over {1} iterations; {2} inst/iter)", - numBad, Args.NumIterations, numBad / Args.NumIterations); + numBad, Args.NumberOfIterations, numBad / Args.NumberOfIterations); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml b/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml index 8e8f5dc2ba..292aeface5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml @@ -25,44 +25,5 @@ - - - - Averaged Perceptron Binary Classifier. - - - Perceptron is a classification algorithm that makes its predictions based on a linear function. - I.e., for an instance with feature values f0, f1,..., f_D-1, , the prediction is given by the sign of sigma[0,D-1] ( w_i * f_i), where w_0, w_1,...,w_D-1 are the weights computed by the algorithm. - - Perceptron is an online algorithm, i.e., it processes the instances in the training set one at a time. - The weights are initialized to be 0, or some random values. Then, for each example in the training set, the value of sigma[0, D-1] (w_i * f_i) is computed. - If this value has the same sign as the label of the current example, the weights remain the same. 
If they have opposite signs, - the weights vector is updated by either subtracting or adding (if the label is negative or positive, respectively) the feature vector of the current example, - multiplied by a factor 0 < a <= 1, called the learning rate. In a generalization of this algorithm, the weights are updated by adding the feature vector multiplied by the learning rate, - and by the gradient of some loss function (in the specific case described above, the loss is hinge-loss, whose gradient is 1 when it is non-zero). - - - In Averaged Perceptron (AKA voted-perceptron), the weight vectors are stored, - together with a weight that counts the number of iterations it survived (this is equivalent to storing the weight vector after every iteration, regardless of whether it was updated or not). - The prediction is then calculated by taking the weighted average of all the sums sigma[0, D-1] (w_i * f_i) or the different weight vectors. - - For more information see: - Wikipedia entry for Perceptron - Large Margin Classification Using the Perceptron Algorithm - - - - - - new AveragedPerceptronBinaryClassifier - { - NumIterations = 10, - L2RegularizerWeight = 0.01f, - LossFunction = new ExpLossClassificationLossFunction() - } - - - - diff --git a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs index 1ab03f627a..53eeb33a02 100644 --- a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs @@ -241,16 +241,26 @@ public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this Multicla } /// - /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer. + /// Predict a target using a linear binary classification model trained with . /// /// The binary classification catalog trainer object. /// The name of the label column, or dependent variable. 
/// The features, or independent variables. - /// The custom loss. - /// The learning Rate. - /// Decrease learning rate as iterations progress. - /// L2 regularization weight. - /// Number of training iterations through the data. + /// A custom loss. If , hinge loss will be used resulting in max-margin averaged perceptron. + /// Learning rate. + /// + /// to decrease the as iterations progress; otherwise, . + /// Default is . + /// + /// L2 weight for regularization. + /// Number of passes through the training dataset. + /// + /// + /// + /// + /// public static AveragedPerceptronTrainer AveragedPerceptron( this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string labelColumn = DefaultColumnNames.Label, @@ -268,10 +278,17 @@ public static AveragedPerceptronTrainer AveragedPerceptron( } /// - /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer. + /// Predict a target using a linear binary classification model trained with and advanced options. /// /// The binary classification catalog trainer object. - /// Advanced arguments to the algorithm. + /// Trainer options. 
+ /// + /// + /// + /// + /// public static AveragedPerceptronTrainer AveragedPerceptron( this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, AveragedPerceptronTrainer.Options options) { diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 9ba33dbe57..35dd5c085b 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -4306,11 +4306,12 @@ } }, { - "Name": "NumIterations", + "Name": "NumberOfIterations", "Type": "Int", "Desc": "Number of iterations", "Aliases": [ - "iter" + "iter", + "numIterations" ], "Required": false, "SortOrder": 50.0, @@ -4325,11 +4326,12 @@ } }, { - "Name": "InitWtsDiameter", + "Name": "InitialWeightsDiameter", "Type": "Float", "Desc": "Init weights diameter", "Aliases": [ - "initwts" + "initwts", + "initWtsDiameter" ], "Required": false, "SortOrder": 140.0, @@ -4467,18 +4469,6 @@ true ] } - }, - { - "Name": "StreamingCacheSize", - "Type": "Int", - "Desc": "Size of cache when trained in Scope", - "Aliases": [ - "cache" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1000000 } ], "Outputs": [ @@ -13295,11 +13285,12 @@ } }, { - "Name": "NumIterations", + "Name": "NumberOfIterations", "Type": "Int", "Desc": "Number of iterations", "Aliases": [ - "iter" + "iter", + "numIterations" ], "Required": false, "SortOrder": 50.0, @@ -13314,11 +13305,12 @@ } }, { - "Name": "InitWtsDiameter", + "Name": "InitialWeightsDiameter", "Type": "Float", "Desc": "Init weights diameter", "Aliases": [ - "initwts" + "initwts", + "initWtsDiameter" ], "Required": false, "SortOrder": 140.0, @@ -13401,18 +13393,6 @@ ] } }, - { - "Name": "StreamingCacheSize", - "Type": "Int", - "Desc": "Size of cache when trained in Scope", - "Aliases": [ - "cache" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1000000 - }, { "Name": 
"BatchSize", "Type": "Int", @@ -14320,11 +14300,12 @@ } }, { - "Name": "NumIterations", + "Name": "NumberOfIterations", "Type": "Int", "Desc": "Number of iterations", "Aliases": [ - "iter" + "iter", + "numIterations" ], "Required": false, "SortOrder": 50.0, @@ -14339,11 +14320,12 @@ } }, { - "Name": "InitWtsDiameter", + "Name": "InitialWeightsDiameter", "Type": "Float", "Desc": "Init weights diameter", "Aliases": [ - "initwts" + "initwts", + "initWtsDiameter" ], "Required": false, "SortOrder": 140.0, @@ -14458,18 +14440,6 @@ true ] } - }, - { - "Name": "StreamingCacheSize", - "Type": "Int", - "Desc": "Size of cache when trained in Scope", - "Aliases": [ - "cache" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1000000 } ], "Outputs": [ diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 661b4fbe80..bfe7d8c4f7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -5446,11 +5446,10 @@ public void TestOvaMacroWithUncalibratedLearner() 'RecencyGainMulti': false, 'Averaged': true, 'AveragedTolerance': 0.01, - 'NumIterations': 1, + 'NumberOfIterations': 1, 'InitialWeights': null, - 'InitWtsDiameter': 0.0, + 'InitialWeightsDiameter': 0.0, 'Shuffle': false, - 'StreamingCacheSize': 1000000, 'LabelColumn': 'Label', 'TrainingData': '$Var_9ccc8bce4f6540eb8a244ab40585602a', 'FeatureColumn': 'Features', diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 464fd0dc59..f509c71908 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -744,7 +744,7 @@ public void TestEnsembleCombiner() { FeatureColumn = "Features", LabelColumn = DefaultColumnNames.Label, - NumIterations = 2, + NumberOfIterations = 2, TrainingData = dataView, 
NormalizeFeatures = NormalizeOption.No }).PredictorModel, diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index 4b7fdbfaff..ac54587b65 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -182,7 +182,7 @@ public void TrainAveragedPerceptronWithCache() var cached = mlContext.Data.Cache(xf); var estimator = mlContext.BinaryClassification.Trainers.AveragedPerceptron( - new AveragedPerceptronTrainer.Options { NumIterations = 2 }); + new AveragedPerceptronTrainer.Options { NumberOfIterations = 2 }); estimator.Fit(cached).Transform(cached); diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 4754191d09..51a37eccd7 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -131,7 +131,7 @@ public void OvaLinearSvm() // Pipeline var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll( - mlContext.BinaryClassification.Trainers.LinearSupportVectorMachines(new LinearSvmTrainer.Options { NumIterations = 100 }), + mlContext.BinaryClassification.Trainers.LinearSupportVectorMachines(new LinearSvmTrainer.Options { NumberOfIterations = 100 }), useProbabilities: false); var model = pipeline.Fit(data);