-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Documentation for BinaryClassification.AveragedPerceptron (V2) #2517
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bf9d946
4222e85
f4f03ba
ceb3aa2
2b96f6d
87be8dc
ba0abff
a5538ed
6b56065
c487733
5219d0c
59673b8
3daf7bd
b10d7e6
5900a41
056a887
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
using Microsoft.ML; | ||
|
||
namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification | ||
{ | ||
public static class AveragedPerceptron | ||
{ | ||
// In this examples we will use the adult income dataset. The goal is to predict | ||
// if a person's income is above $50K or not, based on different pieces of information about that person. | ||
// For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult. | ||
public static void Example() | ||
{ | ||
// Create a new context for ML.NET operations. It can be used for exception tracking and logging, | ||
// as a catalog of available operations and as the source of randomness. | ||
// Setting the seed to a fixed number in this example to make outputs deterministic. | ||
var mlContext = new MLContext(seed: 0); | ||
|
||
// Download and featurize the dataset. | ||
var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext); | ||
|
||
// Leave out 10% of data for testing. | ||
var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); | ||
|
||
// Create data training pipeline. | ||
var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numIterations: 10); | ||
|
||
// Fit this pipeline to the training data. | ||
var model = pipeline.Fit(trainTestData.TrainSet); | ||
|
||
// Evaluate how the model is doing on the test data. | ||
var dataWithPredictions = model.Transform(trainTestData.TestSet); | ||
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions); | ||
SamplesUtils.ConsoleUtils.PrintMetrics(metrics); | ||
|
||
// Expected output: | ||
// Accuracy: 0.86 | ||
// AUC: 0.91 | ||
// F1 Score: 0.68 | ||
// Negative Precision: 0.90 | ||
// Negative Recall: 0.91 | ||
// Positive Precision: 0.70 | ||
// Positive Recall: 0.66 | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
using Microsoft.ML; | ||
using Microsoft.ML.Trainers.Online; | ||
|
||
namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification | ||
{ | ||
public static class AveragedPerceptronWithOptions | ||
{ | ||
// In this examples we will use the adult income dataset. The goal is to predict | ||
// if a person's income is above $50K or not, based on different pieces of information about that person. | ||
// For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult. | ||
public static void Example() | ||
{ | ||
// Create a new context for ML.NET operations. It can be used for exception tracking and logging, | ||
// as a catalog of available operations and as the source of randomness. | ||
// Setting the seed to a fixed number in this example to make outputs deterministic. | ||
var mlContext = new MLContext(seed: 0); | ||
|
||
// Download and featurize the dataset. | ||
var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext); | ||
|
||
// Leave out 10% of data for testing. | ||
var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); | ||
|
||
// Define the trainer options. | ||
var options = new AveragedPerceptronTrainer.Options() | ||
{ | ||
LossFunction = new SmoothedHingeLoss.Arguments(), | ||
LearningRate = 0.1f, | ||
DoLazyUpdates = false, | ||
RecencyGain = 0.1f, | ||
NumberOfIterations = 10 | ||
}; | ||
|
||
// Create data training pipeline. | ||
var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options); | ||
|
||
// Fit this pipeline to the training data. | ||
var model = pipeline.Fit(trainTestData.TrainSet); | ||
|
||
// Evaluate how the model is doing on the test data. | ||
var dataWithPredictions = model.Transform(trainTestData.TestSet); | ||
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions); | ||
SamplesUtils.ConsoleUtils.PrintMetrics(metrics); | ||
|
||
// Expected output: | ||
// Accuracy: 0.86 | ||
// AUC: 0.90 | ||
// F1 Score: 0.66 | ||
// Negative Precision: 0.89 | ||
// Negative Recall: 0.93 | ||
// Positive Precision: 0.72 | ||
// Positive Recall: 0.61 | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.SamplesUtils | ||
{ | ||
/// <summary> | ||
/// Utilities for creating console outputs in samples' code. | ||
/// </summary> | ||
public static class ConsoleUtils | ||
{ | ||
/// <summary> | ||
/// Pretty-print BinaryClassificationMetrics objects. | ||
/// </summary> | ||
/// <param name="metrics">Binary classification metrics.</param> | ||
public static void PrintMetrics(BinaryClassificationMetrics metrics) | ||
{ | ||
Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}"); | ||
Console.WriteLine($"AUC: {metrics.Auc:F2}"); | ||
Console.WriteLine($"F1 Score: {metrics.F1Score:F2}"); | ||
Console.WriteLine($"Negative Precision: {metrics.NegativePrecision:F2}"); | ||
Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}"); | ||
Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}"); | ||
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}"); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
using System.IO; | ||
using System.Net; | ||
using Microsoft.Data.DataView; | ||
using Microsoft.ML; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.SamplesUtils | ||
|
@@ -86,6 +87,65 @@ public static string DownloadSentimentDataset() | |
public static string DownloadAdultDataset() | ||
=> Download("https://raw.githubusercontent.com/dotnet/machinelearning/244a8c2ac832657af282aa312d568211698790aa/test/data/adult.train", "adult.txt"); | ||
|
||
/// <summary> | ||
/// Downloads the Adult UCI dataset and featurizes it to be suitable for classification tasks. | ||
/// </summary> | ||
/// <param name="mlContext"><see cref="MLContext"/> used for data loading and processing.</param> | ||
/// <returns>Featurized dataset.</returns> | ||
/// <remarks> | ||
/// For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult. | ||
/// </remarks> | ||
public static IDataView LoadFeaturizedAdultDataset(MLContext mlContext) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
xml doc since this is public #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
// Download the file | ||
string dataFile = DownloadAdultDataset(); | ||
|
||
// Define the columns to read | ||
var reader = mlContext.Data.CreateTextLoader( | ||
columns: new[] | ||
{ | ||
new TextLoader.Column("age", DataKind.R4, 0), | ||
new TextLoader.Column("workclass", DataKind.TX, 1), | ||
new TextLoader.Column("fnlwgt", DataKind.R4, 2), | ||
new TextLoader.Column("education", DataKind.TX, 3), | ||
new TextLoader.Column("education-num", DataKind.R4, 4), | ||
new TextLoader.Column("marital-status", DataKind.TX, 5), | ||
new TextLoader.Column("occupation", DataKind.TX, 6), | ||
new TextLoader.Column("relationship", DataKind.TX, 7), | ||
new TextLoader.Column("ethnicity", DataKind.TX, 8), | ||
new TextLoader.Column("sex", DataKind.TX, 9), | ||
new TextLoader.Column("capital-gain", DataKind.R4, 10), | ||
new TextLoader.Column("capital-loss", DataKind.R4, 11), | ||
new TextLoader.Column("hours-per-week", DataKind.R4, 12), | ||
new TextLoader.Column("native-country", DataKind.R4, 13), | ||
new TextLoader.Column("IsOver50K", DataKind.BL, 14), | ||
}, | ||
separatorChar: ',', | ||
hasHeader: true | ||
); | ||
|
||
// Create data featurizing pipeline | ||
var pipeline = mlContext.Transforms.CopyColumns("Label", "IsOver50K") | ||
// Convert categorical features to one-hot vectors | ||
.Append(mlContext.Transforms.Categorical.OneHotEncoding("workclass")) | ||
.Append(mlContext.Transforms.Categorical.OneHotEncoding("education")) | ||
.Append(mlContext.Transforms.Categorical.OneHotEncoding("marital-status")) | ||
.Append(mlContext.Transforms.Categorical.OneHotEncoding("occupation")) | ||
.Append(mlContext.Transforms.Categorical.OneHotEncoding("relationship")) | ||
.Append(mlContext.Transforms.Categorical.OneHotEncoding("ethnicity")) | ||
.Append(mlContext.Transforms.Categorical.OneHotEncoding("native-country")) | ||
// Combine all features into one feature vector | ||
.Append(mlContext.Transforms.Concatenate("Features", "workclass", "education", "marital-status", | ||
"occupation", "relationship", "ethnicity", "native-country", "age", "education-num", | ||
"capital-gain", "capital-loss", "hours-per-week")) | ||
// Min-max normalized all the features | ||
.Append(mlContext.Transforms.Normalize("Features")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Can you add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
var data = reader.Read(dataFile); | ||
var featurizedData = pipeline.Fit(data).Transform(data); | ||
return featurizedData; | ||
} | ||
|
||
/// <summary> | ||
/// Downloads the breast cancer dataset from the ML.NET repo. | ||
/// </summary> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,40 +14,94 @@ | |
|
||
namespace Microsoft.ML.Trainers.Online | ||
{ | ||
/// <summary> | ||
/// Arguments class for averaged linear trainers. | ||
/// </summary> | ||
public abstract class AveragedLinearArguments : OnlineLinearArguments | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
xml #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
/// <summary> | ||
/// <a href="tmpurl_lr">Learning rate</a>. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Remarks? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO, if it is short, it can stay in summary. This does need more text, if that's what you mean by In reply to: 256223629 [](ancestors = 256223629) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i was thinking the same as Senja. The remark will be simply the link, so we might as well leave it in the summary. In reply to: 256251344 [](ancestors = 256251344,256223629) |
||
/// </summary> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)] | ||
[TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")] | ||
[TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })] | ||
public float LearningRate = AveragedDefaultArgs.LearningRate; | ||
|
||
/// <summary> | ||
/// Determine whether to decrease the <see cref="LearningRate"/> or not. | ||
/// </summary> | ||
/// <value> | ||
/// <see langword="true" /> to decrease the <see cref="LearningRate"/> as iterations progress; otherwise, <see langword="false" />. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Remarks on exactly how it decreases? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
/// Default is <see langword="false" />. The learning rate will be reduced with every weight update proportional to the square root of the number of updates. | ||
/// </value> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)] | ||
[TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")] | ||
[TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })] | ||
public bool DecreaseLearningRate = AveragedDefaultArgs.DecreaseLearningRate; | ||
|
||
/// <summary> | ||
/// Number of examples after which weights will be reset to the current average. | ||
/// </summary> | ||
/// <value> | ||
/// Default is <see langword="null" />, which disables this feature. | ||
/// </value> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of examples after which weights will be reset to the current average", ShortName = "numreset")] | ||
public long? ResetWeightsAfterXExamples = null; | ||
|
||
/// <summary> | ||
/// Determines when to update averaged weights. | ||
/// </summary> | ||
/// <value> | ||
/// <see langword="true" /> to update averaged weights only when loss is nonzero. | ||
/// <see langword="false" /> to update averaged weights on every example. | ||
/// Default is <see langword="true" />. | ||
/// </value> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy")] | ||
public bool DoLazyUpdates = true; | ||
|
||
/// <summary> | ||
/// L2 weight for <a href='tmpurl_regularization'>regularization</a>. | ||
/// </summary> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)] | ||
[TGUI(Label = "L2 Regularization Weight")] | ||
[TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)] | ||
public float L2RegularizerWeight = AveragedDefaultArgs.L2RegularizerWeight; | ||
|
||
/// <summary> | ||
/// Extra weight given to more recent updates. | ||
/// </summary> | ||
/// <value> | ||
/// Default is 0, i.e. no extra gain. | ||
/// </value> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Extra weight given to more recent updates", ShortName = "rg")] | ||
public float RecencyGain = 0; | ||
|
||
/// <summary> | ||
/// Determines whether <see cref="RecencyGain"/> is multiplicative or additive. | ||
/// </summary> | ||
/// <value> | ||
/// <see langword="true" /> means <see cref="RecencyGain"/> is multiplicative. | ||
/// <see langword="false" /> means <see cref="RecencyGain"/> is additive. | ||
/// Default is <see langword="false" />. | ||
/// </value> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm")] | ||
public bool RecencyGainMulti = false; | ||
|
||
/// <summary> | ||
/// Determines whether to do averaging or not. | ||
/// </summary> | ||
/// <value> | ||
/// <see langword="true" /> to do averaging; otherwise, <see langword="false" />. | ||
/// Default is <see langword="true" />. | ||
/// </value> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Do averaging?", ShortName = "avg")] | ||
public bool Averaged = true; | ||
|
||
/// <summary> | ||
/// The inexactness tolerance for averaging. | ||
/// </summary> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "The inexactness tolerance for averaging", ShortName = "avgtol")] | ||
public float AveragedTolerance = (float)1e-2; | ||
internal float AveragedTolerance = (float)1e-2; | ||
|
||
[BestFriend] | ||
internal class AveragedDefaultArgs : OnlineDefaultArgs | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
space #Pending