Commit a16eb30

Author: Shahab Moradi

Added samples & docs for BinaryClassification.StochasticGradientDescent (#2688)
* Added samples & docs for BinaryClassification.StochasticGradientDescent, plus a bunch of typo fixing.
* Addressed PR comments.
* Mentioned Hogwild.
* Updates to exampleWeightColumnName.
* Fixed trailing whitespaces.
1 parent 2ef0614 commit a16eb30

File tree: 16 files changed, +226 −45 lines


docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs (+1 −1)

@@ -5,7 +5,7 @@ namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
     public static class AveragedPerceptron
     {
         // In this example we will use the adult income dataset. The goal is to predict
-        // if a person's income is above $50K or not, based on different pieces of information about that person.
+        // if a person's income is above $50K or not, based on demographic information about that person.
         // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult.
         public static void Example()
         {

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs (+1 −1)

@@ -6,7 +6,7 @@ namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
     public static class AveragedPerceptronWithOptions
     {
         // In this example we will use the adult income dataset. The goal is to predict
-        // if a person's income is above $50K or not, based on different pieces of information about that person.
+        // if a person's income is above $50K or not, based on demographic information about that person.
         // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult.
         public static void Example()
         {
docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticGradientDescent.cs (+47 −0)

@@ -0,0 +1,47 @@
+using Microsoft.ML;
+
+namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
+{
+    public static class StochasticGradientDescent
+    {
+        // In this example we will use the adult income dataset. The goal is to predict
+        // if a person's income is above $50K or not, based on demographic information about that person.
+        // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult.
+        public static void Example()
+        {
+            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+            // as a catalog of available operations and as the source of randomness.
+            // Setting the seed to a fixed number in this example to make outputs deterministic.
+            var mlContext = new MLContext(seed: 0);
+
+            // Download and featurize the dataset.
+            var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
+
+            // Leave out 10% of the data for testing.
+            var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
+
+            // Create the training pipeline.
+            var pipeline = mlContext.BinaryClassification.Trainers.StochasticGradientDescent();
+
+            // Fit this pipeline to the training data.
+            var model = pipeline.Fit(trainTestData.TrainSet);
+
+            // Evaluate how the model is doing on the test data.
+            var dataWithPredictions = model.Transform(trainTestData.TestSet);
+            var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions);
+            SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
+
+            // Expected output:
+            // Accuracy: 0.85
+            // AUC: 0.90
+            // F1 Score: 0.67
+            // Negative Precision: 0.90
+            // Negative Recall: 0.91
+            // Positive Precision: 0.68
+            // Positive Recall: 0.65
+            // LogLoss: 0.48
+            // LogLossReduction: 38.31
+            // Entropy: 0.78
+        }
+    }
+}
docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticGradientDescentWithOptions.cs (+59 −0)

@@ -0,0 +1,59 @@
+using Microsoft.ML;
+using Microsoft.ML.Trainers;
+
+namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
+{
+    public static class StochasticGradientDescentWithOptions
+    {
+        // In this example we will use the adult income dataset. The goal is to predict
+        // if a person's income is above $50K or not, based on demographic information about that person.
+        // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult.
+        public static void Example()
+        {
+            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+            // as a catalog of available operations and as the source of randomness.
+            // Setting the seed to a fixed number in this example to make outputs deterministic.
+            var mlContext = new MLContext(seed: 0);
+
+            // Download and featurize the dataset.
+            var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
+
+            // Leave out 10% of the data for testing.
+            var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
+
+            // Define the trainer options.
+            var options = new SgdBinaryTrainer.Options()
+            {
+                // Make the convergence tolerance tighter.
+                ConvergenceTolerance = 5e-5,
+                // Increase the maximum number of passes over the training data.
+                MaxIterations = 30,
+                // Give the instances of the positive class slightly more weight.
+                PositiveInstanceWeight = 1.2f,
+            };
+
+            // Create the training pipeline.
+            var pipeline = mlContext.BinaryClassification.Trainers.StochasticGradientDescent(options);
+
+            // Fit this pipeline to the training data.
+            var model = pipeline.Fit(trainTestData.TrainSet);
+
+            // Evaluate how the model is doing on the test data.
+            var dataWithPredictions = model.Transform(trainTestData.TestSet);
+            var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions);
+            SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
+
+            // Expected output:
+            // Accuracy: 0.85
+            // AUC: 0.90
+            // F1 Score: 0.67
+            // Negative Precision: 0.91
+            // Negative Recall: 0.89
+            // Positive Precision: 0.65
+            // Positive Recall: 0.70
+            // LogLoss: 0.48
+            // LogLossReduction: 37.52
+            // Entropy: 0.78
+        }
+    }
+}
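
A note on the options used above: PositiveInstanceWeight scales the loss contributed by each positive example, which helps when the classes are imbalanced, as they are in this dataset. The sketch below shows how such a weight typically enters a logistic loss; it is illustrative only, not ML.NET's internal implementation, and the class and method names are hypothetical.

using System;

internal static class WeightedLossSketch
{
    // Illustrative only: a positive-class weight scaling the logistic loss.
    // 'positiveInstanceWeight' plays the role of Options.PositiveInstanceWeight.
    public static double WeightedLogisticLoss(double score, bool isPositive, double positiveInstanceWeight)
    {
        double y = isPositive ? 1.0 : -1.0;                  // label encoded as -1/+1
        double loss = Math.Log(1.0 + Math.Exp(-y * score));  // standard logistic loss
        return isPositive ? positiveInstanceWeight * loss : loss;
    }
}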

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SymbolicStochasticGradientDescent.cs (+1 −1)

@@ -4,7 +4,7 @@ public static class SymbolicStochasticGradientDescent
     {
         // This example requires installation of additional nuget package <a href="https://www.nuget.org/packages/Microsoft.ML.HalLearners/">Microsoft.ML.HalLearners</a>.
         // In this example we will use the adult income dataset. The goal is to predict
-        // if a person's income is above $50K or not, based on different pieces of information about that person.
+        // if a person's income is above $50K or not, based on demographic information about that person.
         // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult
         public static void Example()
         {

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SymbolicStochasticGradientDescentWithOptions.cs (+1 −1)

@@ -4,7 +4,7 @@ public static class SymbolicStochasticGradientDescentWithOptions
     {
         // This example requires installation of additional nuget package <a href="https://www.nuget.org/packages/Microsoft.ML.HalLearners/">Microsoft.ML.HalLearners</a>.
         // In this example we will use the adult income dataset. The goal is to predict
-        // if a person's income is above $50K or not, based on different pieces of information about that person.
+        // if a person's income is above $50K or not, based on demographic information about that person.
        // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult
         public static void Example()
         {

src/Microsoft.ML.Data/EntryPoints/InputBase.cs (+1 −1)

@@ -95,7 +95,7 @@ public abstract class LearnerInputBaseWithLabel : LearnerInputBase
     public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
    {
         /// <summary>
-        /// Column to use for example weight.
+        /// The name of the example weight column.
         /// </summary>
         [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
         public string WeightColumn = null;

src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs (+12 −0)

@@ -23,6 +23,18 @@ public static void PrintMetrics(BinaryClassificationMetrics metrics)
             Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}");
         }
 
+        /// <summary>
+        /// Pretty-print CalibratedBinaryClassificationMetrics objects.
+        /// </summary>
+        /// <param name="metrics"><see cref="CalibratedBinaryClassificationMetrics"/> object.</param>
+        public static void PrintMetrics(CalibratedBinaryClassificationMetrics metrics)
+        {
+            PrintMetrics(metrics as BinaryClassificationMetrics);
+            Console.WriteLine($"LogLoss: {metrics.LogLoss:F2}");
+            Console.WriteLine($"LogLossReduction: {metrics.LogLossReduction:F2}");
+            Console.WriteLine($"Entropy: {metrics.Entropy:F2}");
+        }
+
         /// <summary>
         /// Pretty-print RegressionMetrics objects.
         /// </summary>

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs (+1 −1)

@@ -60,7 +60,7 @@ public abstract class AveragedLinearOptions : OnlineLinearOptions
         public bool DoLazyUpdates = true;
 
         /// <summary>
-        /// L2 weight for <a href='tmpurl_regularization'>regularization</a>.
+        /// The L2 weight for <a href='tmpurl_regularization'>regularization</a>.
         /// </summary>
         [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)]
         [TGUI(Label = "L2 Regularization Weight")]

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs (+1 −1)

@@ -54,7 +54,7 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPred
         private readonly Options _args;
 
         /// <summary>
-        /// Options for the averaged perceptron trainer.
+        /// Options for the <see cref="AveragedPerceptronTrainer"/>.
         /// </summary>
         public sealed class Options : AveragedLinearOptions
         {

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs (+2 −2)

@@ -24,7 +24,7 @@ public abstract class OnlineLinearOptions : LearnerInputBaseWithLabel
         /// <summary>
         /// Number of passes through the training dataset.
         /// </summary>
-        [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter, numIterations", SortOrder = 50)]
+        [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter,numIterations", SortOrder = 50)]
         [TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")]
         [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)]
         public int NumberOfIterations = OnlineDefault.NumIterations;

@@ -43,7 +43,7 @@ public abstract class OnlineLinearOptions : LearnerInputBaseWithLabel
         /// This property is only used if the provided value is positive and <see cref="InitialWeights"/> is not specified.
         /// The weights and bias will be randomly selected from InitialWeights * [-0.5,0.5] interval with uniform distribution.
         /// </value>
-        [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts, initWtsDiameter", SortOrder = 140)]
+        [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts,initWtsDiameter", SortOrder = 140)]
         [TGUI(Label = "Initial Weights Scale", SuggestedSweeps = "0,0.1,0.5,1")]
         [TlcModule.SweepableFloatParamAttribute("InitWtsDiameter", 0.0f, 1.0f, numSteps: 5)]
         public float InitialWeightsDiameter = 0;

src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs (+52 −3)

@@ -1723,36 +1723,77 @@ public abstract class SgdBinaryTrainerBase<TModel> :
     {
         public class OptionsBase : LearnerInputBaseWithWeight
         {
+            /// <summary>
+            /// The L2 weight for <a href='tmpurl_regularization'>regularization</a>.
+            /// </summary>
             [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization constant", ShortName = "l2", SortOrder = 50)]
             [TGUI(Label = "L2 Regularization Constant", SuggestedSweeps = "1e-7,5e-7,1e-6,5e-6,1e-5")]
             [TlcModule.SweepableDiscreteParam("L2Const", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f })]
             public float L2Weight = Defaults.L2Weight;
 
+            /// <summary>
+            /// The degree of lock-free parallelism used by SGD.
+            /// </summary>
+            /// <value>
+            /// Defaults to automatic depending on data sparseness. Determinism is not guaranteed.
+            /// </value>
             [Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Defaults to automatic depending on data sparseness. Determinism not guaranteed.", ShortName = "nt,t,threads", SortOrder = 50)]
             [TGUI(Label = "Number of threads", SuggestedSweeps = "1,2,4")]
             public int? NumThreads;
 
+            /// <summary>
+            /// The convergence tolerance. If the exponential moving average of loss reductions falls below this tolerance,
+            /// the algorithm is deemed to have converged and will stop.
+            /// </summary>
             [Argument(ArgumentType.AtMostOnce, HelpText = "Exponential moving averaged improvement tolerance for convergence", ShortName = "tol")]
             [TGUI(SuggestedSweeps = "1e-2,1e-3,1e-4,1e-5")]
             [TlcModule.SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f })]
             public double ConvergenceTolerance = 1e-4;
 
+            /// <summary>
+            /// The maximum number of passes through the training dataset.
+            /// </summary>
+            /// <value>
+            /// Set to 1 to simulate online learning.
+            /// </value>
             [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations; set to 1 to simulate online learning.", ShortName = "iter")]
             [TGUI(Label = "Max number of iterations", SuggestedSweeps = "1,5,10,20")]
             [TlcModule.SweepableDiscreteParam("MaxIterations", new object[] { 1, 5, 10, 20 })]
             public int MaxIterations = Defaults.MaxIterations;
 
+            /// <summary>
+            /// The initial <a href="tmpurl_lr">learning rate</a> used by SGD.
+            /// </summary>
             [Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate (only used by SGD)", ShortName = "ilr,lr")]
             [TGUI(Label = "Initial Learning Rate (for SGD)")]
             public double InitLearningRate = Defaults.InitLearningRate;
 
+            /// <summary>
+            /// Determines whether to shuffle data for each training iteration.
+            /// </summary>
+            /// <value>
+            /// <see langword="true" /> to shuffle data for each training iteration; otherwise, <see langword="false" />.
+            /// Default is <see langword="true" />.
+            /// </value>
             [Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data every epoch?", ShortName = "shuf")]
             [TlcModule.SweepableDiscreteParam("Shuffle", null, isBool: true)]
             public bool Shuffle = true;
 
+            /// <summary>
+            /// The weight to be applied to the positive class. This is useful for training with imbalanced data.
+            /// </summary>
+            /// <value>
+            /// The default value is 1, which means no extra weight.
+            /// </value>
             [Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
             public float PositiveInstanceWeight = 1;
 
+            /// <summary>
+            /// Determines the frequency of checking for convergence in terms of number of iterations.
+            /// </summary>
+            /// <value>
+            /// Default equals <see cref="NumThreads"/>.
+            /// </value>
             [Argument(ArgumentType.AtMostOnce, HelpText = "Convergence check frequency (in terms of number of iterations). Default equals number of threads", ShortName = "checkFreq")]
             public int? CheckFrequency;
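
The ConvergenceTolerance documented above is a threshold on an exponential moving average of per-pass loss improvements. The sketch below shows that style of stopping rule; the 0.5 decay factor and the 'runOnePass' callback are assumptions for illustration, not ML.NET's internals.

using System;

internal static class ConvergenceSketch
{
    // Illustrative stopping rule: track an exponential moving average (EMA)
    // of per-pass loss improvements and stop once it drops below the
    // tolerance. 'runOnePass' stands in for one SGD pass over the data that
    // returns the average training loss.
    public static int TrainUntilConverged(Func<double> runOnePass, double convergenceTolerance, int maxIterations)
    {
        double previousLoss = double.NaN;
        double emaImprovement = double.NaN;
        int iter = 1;
        for (; iter <= maxIterations; iter++)
        {
            double loss = runOnePass();
            if (!double.IsNaN(previousLoss))
            {
                double improvement = previousLoss - loss;
                emaImprovement = double.IsNaN(emaImprovement)
                    ? improvement
                    : 0.5 * emaImprovement + 0.5 * improvement;
                if (emaImprovement < convergenceTolerance)
                    break; // deemed converged
            }
            previousLoss = loss;
        }
        return iter; // number of passes actually run
    }
}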

@@ -1802,7 +1843,7 @@ internal static class Defaults
         /// <param name="env">The environment to use.</param>
         /// <param name="featureColumn">The name of the feature column.</param>
         /// <param name="labelColumn">The name of the label column.</param>
-        /// <param name="weightColumn">The name for the example weight column.</param>
+        /// <param name="weightColumn">The name of the example weight column.</param>
         /// <param name="maxIterations">The maximum number of iterations; set to 1 to simulate online learning.</param>
         /// <param name="initLearningRate">The initial learning rate used by SGD.</param>
         /// <param name="l2Weight">The L2 regularizer constant.</param>
@@ -2077,13 +2118,21 @@ private protected override void CheckLabel(RoleMappedData examples, out int weig
     }
 
     /// <summary>
-    /// Train logistic regression using a parallel stochastic gradient method.
+    /// The <see cref="IEstimator{TTransformer}"/> for training logistic regression using a parallel stochastic gradient method.
+    /// The trained model is <a href='tmpurl_calib'>calibrated</a> and can produce probabilities by feeding the output of the
+    /// linear function to a <see cref="PlattCalibrator"/>.
     /// </summary>
+    /// <remarks>
+    /// Stochastic Gradient Descent (SGD) is a popular stochastic optimization procedure that can be integrated
+    /// into many machine learning tasks to achieve state-of-the-art performance. This trainer implements Hogwild SGD for binary classification,
+    /// which supports multi-threading without any locking. If the associated optimization problem is sparse, Hogwild SGD achieves a nearly optimal
+    /// rate of convergence. For more details about Hogwild SGD, please refer to http://arxiv.org/pdf/1106.5730v2.pdf.
+    /// </remarks>
     public sealed class SgdBinaryTrainer :
         SgdBinaryTrainerBase<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>
     {
         /// <summary>
-        /// Options available to training logistic regression using the implemented stochastic gradient method.
+        /// Options for the <see cref="SgdBinaryTrainer"/>.
         /// </summary>
         public sealed class Options : OptionsBase
         {
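
The Hogwild reference in the new remarks describes lock-free parallel SGD: threads read and update a shared weight vector without synchronization, and on sparse problems the resulting races are rare enough not to hurt convergence. Below is a minimal sketch of the idea in C#; it is illustrative only and is not ML.NET's actual implementation.

using System;
using System.Threading.Tasks;

internal static class HogwildSketch
{
    // One lock-free parallel pass of SGD over a logistic loss, in the spirit
    // of Hogwild (http://arxiv.org/pdf/1106.5730v2.pdf). Threads update the
    // shared 'weights' array without locks; races on individual components
    // are tolerated.
    public static void TrainOnePass(float[][] features, float[] labels, float[] weights, float learningRate)
    {
        Parallel.For(0, features.Length, i =>
        {
            // Score with a possibly stale snapshot of the weights.
            float score = 0;
            for (int j = 0; j < weights.Length; j++)
                score += weights[j] * features[i][j];

            // Gradient scale of the logistic loss; labels are in {-1, +1}.
            float gradScale = -labels[i] / (1 + (float)Math.Exp(labels[i] * score));

            // Unsynchronized update of the shared weight vector.
            for (int j = 0; j < weights.Length; j++)
                weights[j] -= learningRate * gradScale * features[i][j];
        });
    }
}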
