Skip to content

Commit 1942c8f

Browse files
authored
Cleaning and Fixing public API for set of learners. (#2765)
1 parent d65af0f commit 1942c8f

File tree

18 files changed

+127
-110
lines changed

18 files changed

+127
-110
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/PriorTrainerSample.cs renamed to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/PriorTrainerSample.cs

+27-22
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,22 @@ public class PriorTrainer
77
{
88
public static void Example()
99
{
10-
// Downloading the dataset from github.com/dotnet/machinelearning.
11-
// This will create a sentiment.tsv file in the filesystem.
12-
// You can open this file, if you want to see the data.
13-
string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset()[0];
10+
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
11+
// as a catalog of available operations and as the source of randomness.
12+
var mlContext = new MLContext();
13+
14+
// Download the dataset; the returned list holds the training file first and the test file second.
15+
var dataFiles = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
16+
var trainFile = dataFiles[0];
17+
var testFile = dataFiles[1];
1418

1519
// A preview of the data.
1620
// Sentiment SentimentText
1721
// 0 " :Erm, thank you. "
1822
// 1 ==You're cool==
1923

20-
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
21-
// as a catalog of available operations and as the source of randomness.
22-
var mlContext = new MLContext();
23-
24-
// Step 1: Load the data as an IDataView.
25-
// First, we define the loader: specify the data columns and where to find them in the text file.
24+
// Step 1: Read the data as an IDataView.
25+
// First, we define the loader: specify the data columns and where to find them in the text file.
2626
var loader = mlContext.Data.CreateTextLoader(
2727
columns: new[]
2828
{
@@ -31,12 +31,9 @@ public static void Example()
3131
},
3232
hasHeader: true
3333
);
34-
35-
// Load the data
36-
var data = loader.Load(dataFile);
3734

38-
// Split it between training and test data
39-
var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data);
35+
// Load the data
36+
var trainData = loader.Load(trainFile);
4037

4138
// Step 2: Pipeline
4239
// Featurize the text column through the FeaturizeText API.
@@ -47,19 +44,27 @@ public static void Example()
4744
.Append(mlContext.BinaryClassification.Trainers.Prior(labelColumnName: "Sentiment"));
4845

4946
// Step 3: Train the pipeline
50-
var trainedPipeline = pipeline.Fit(trainTestData.TrainSet);
47+
var trainedPipeline = pipeline.Fit(trainData);
5148

5249
// Step 4: Evaluate on the test set
53-
var transformedData = trainedPipeline.Transform(trainTestData.TestSet);
50+
var transformedData = trainedPipeline.Transform(loader.Load(testFile));
5451
var evalMetrics = mlContext.BinaryClassification.Evaluate(transformedData, label: "Sentiment");
55-
56-
// Step 5: Inspect the output
57-
Console.WriteLine("Accuracy: " + evalMetrics.Accuracy);
52+
SamplesUtils.ConsoleUtils.PrintMetrics(evalMetrics);
5853

5954
// The Prior trainer outputs the proportion of a label in the dataset as the probability of that label.
60-
// In this case it means that there is a split of around 64%-36% of positive and negative labels in the dataset.
55+
// In this case 'Accuracy: 0.50' means that there is a split of around 50%-50% of positive and negative labels in the test dataset.
6156
// Expected output:
62-
// Accuracy: 0.647058823529412
57+
58+
// Accuracy: 0.50
59+
// AUC: 0.50
60+
// F1 Score: 0.67
61+
// Negative Precision: 0.00
62+
// Negative Recall: 0.00
63+
// Positive Precision: 0.50
64+
// Positive Recall: 1.00
65+
// LogLoss: 1.05
66+
// LogLossReduction: -4.89
67+
// Entropy: 1.00
6368
}
6469
}
6570
}

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/RandomTrainerSample.cs renamed to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/RandomTrainerSample.cs

+29-24
Original file line numberDiff line numberDiff line change
@@ -7,59 +7,64 @@ public static class RandomTrainer
77
{
88
public static void Example()
99
{
10-
// Downloading the dataset from github.com/dotnet/machinelearning.
11-
// This will create a sentiment.tsv file in the filesystem.
12-
// You can open this file, if you want to see the data.
13-
string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset()[0];
10+
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
11+
// as a catalog of available operations and as the source of randomness.
12+
var mlContext = new MLContext(seed: 1);
13+
14+
// Download the dataset; the returned list holds the training file first and the test file second.
15+
var dataFiles = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
16+
var trainFile = dataFiles[0];
17+
var testFile = dataFiles[1];
1418

1519
// A preview of the data.
1620
// Sentiment SentimentText
1721
// 0 " :Erm, thank you. "
1822
// 1 ==You're cool==
1923

20-
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
21-
// as a catalog of available operations and as the source of randomness.
22-
var mlContext = new MLContext(seed: 1);
23-
24-
// Step 1: Load the data as an IDataView.
25-
// First, we define the loader: specify the data columns and where to find them in the text file.
26-
var loader = mlContext.Data.CreateTextLoader(
24+
// Step 1: Read the data as an IDataView.
25+
// First, we define the reader: specify the data columns and where to find them in the text file.
26+
var reader = mlContext.Data.CreateTextLoader(
2727
columns: new[]
2828
{
2929
new TextLoader.Column("Sentiment", DataKind.Single, 0),
3030
new TextLoader.Column("SentimentText", DataKind.String, 1)
3131
},
3232
hasHeader: true
3333
);
34-
35-
// Load the data
36-
var data = loader.Load(dataFile);
3734

38-
// Split it between training and test data
39-
var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data);
35+
// Read the data
36+
var trainData = reader.Load(trainFile);
4037

4138
// Step 2: Pipeline
4239
// Featurize the text column through the FeaturizeText API.
4340
// Then append a binary classifier, setting the "Label" column as the label of the dataset, and
4441
// the "Features" column produced by FeaturizeText as the features column.
4542
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
46-
.AppendCacheCheckpoint(mlContext) // Add a data-cache step within a pipeline.
43+
.AppendCacheCheckpoint(mlContext)
4744
.Append(mlContext.BinaryClassification.Trainers.Random());
4845

4946
// Step 3: Train the pipeline
50-
var trainedPipeline = pipeline.Fit(trainTestData.TrainSet);
47+
var trainedPipeline = pipeline.Fit(trainData);
5148

5249
// Step 4: Evaluate on the test set
53-
var transformedData = trainedPipeline.Transform(trainTestData.TestSet);
50+
var transformedData = trainedPipeline.Transform(reader.Load(testFile));
5451
var evalMetrics = mlContext.BinaryClassification.Evaluate(transformedData, label: "Sentiment");
55-
56-
// Step 5: Inspect the output
57-
Console.WriteLine("Accuracy: " + evalMetrics.Accuracy);
52+
SamplesUtils.ConsoleUtils.PrintMetrics(evalMetrics);
5853

5954
// We expect an output probability close to 0.5 as the Random trainer outputs a random prediction.
6055
// Regardless of the input features, the trainer will predict either positive or negative label with equal probability.
61-
// Expected output (close to 0.5):
62-
// Accuracy: 0.588235294117647
56+
// Expected output (close to 0.5):
57+
58+
// Accuracy: 0.56
59+
// AUC: 0.57
60+
// F1 Score: 0.60
61+
// Negative Precision: 0.57
62+
// Negative Recall: 0.44
63+
// Positive Precision: 0.55
64+
// Positive Recall: 0.67
65+
// LogLoss: 1.53
66+
// LogLossReduction: -53.37
67+
// Entropy: 1.00
6368
}
6469
}
6570
}

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SymbolicStochasticGradientDescentWithOptions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static void Example()
2020
var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
2121
// Create data training pipeline
2222
var pipeline = mlContext.BinaryClassification.Trainers.SymbolicStochasticGradientDescent(
23-
new ML.Trainers.HalLearners.SymSgdClassificationTrainer.Options()
23+
new ML.Trainers.HalLearners.SymbolicStochasticGradientDescentClassificationTrainer.Options()
2424
{
2525
LearningRate = 0.2f,
2626
NumberOfIterations = 10,

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public static void Example()
4444

4545
// Create the estimator, here we only need OrdinaryLeastSquares trainer
4646
// as data is already processed in a form consumable by the trainer
47-
var pipeline = mlContext.Regression.Trainers.OrdinaryLeastSquares(new OlsLinearRegressionTrainer.Options()
47+
var pipeline = mlContext.Regression.Trainers.OrdinaryLeastSquares(new OrdinaryLeastSquaresRegressionTrainer.Options()
4848
{
4949
L2Weight = 0.1f,
5050
PerParameterSignificance = false

src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace Microsoft.ML.Trainers
1111
{
12-
using Mkl = OlsLinearRegressionTrainer.Mkl;
12+
using Mkl = OrdinaryLeastSquaresRegressionTrainer.Mkl;
1313

1414
public sealed class ComputeLRTrainingStdThroughHal : ComputeLRTrainingStd
1515
{

src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs

+20-20
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
namespace Microsoft.ML
1010
{
1111
/// <summary>
12-
/// The trainer catalog extensions for the <see cref="OlsLinearRegressionTrainer"/> and <see cref="SymSgdClassificationTrainer"/>.
12+
/// The trainer catalog extensions for the <see cref="OrdinaryLeastSquaresRegressionTrainer"/> and <see cref="SymbolicStochasticGradientDescentClassificationTrainer"/>.
1313
/// </summary>
1414
public static class HalLearnersCatalog
1515
{
1616
/// <summary>
17-
/// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
17+
/// Predict a target using a linear regression model trained with the <see cref="OrdinaryLeastSquaresRegressionTrainer"/>.
1818
/// </summary>
1919
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
2020
/// <param name="labelColumnName">The name of the label column.</param>
@@ -27,48 +27,48 @@ public static class HalLearnersCatalog
2727
/// ]]>
2828
/// </format>
2929
/// </example>
30-
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionCatalog.RegressionTrainers catalog,
30+
public static OrdinaryLeastSquaresRegressionTrainer OrdinaryLeastSquares(this RegressionCatalog.RegressionTrainers catalog,
3131
string labelColumnName = DefaultColumnNames.Label,
3232
string featureColumnName = DefaultColumnNames.Features,
3333
string exampleWeightColumnName = null)
3434
{
3535
Contracts.CheckValue(catalog, nameof(catalog));
3636
var env = CatalogUtils.GetEnvironment(catalog);
37-
var options = new OlsLinearRegressionTrainer.Options
37+
var options = new OrdinaryLeastSquaresRegressionTrainer.Options
3838
{
3939
LabelColumnName = labelColumnName,
4040
FeatureColumnName = featureColumnName,
4141
ExampleWeightColumnName = exampleWeightColumnName
4242
};
4343

44-
return new OlsLinearRegressionTrainer(env, options);
44+
return new OrdinaryLeastSquaresRegressionTrainer(env, options);
4545
}
4646

4747
/// <summary>
48-
/// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
48+
/// Predict a target using a linear regression model trained with the <see cref="OrdinaryLeastSquaresRegressionTrainer"/>.
4949
/// </summary>
5050
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
51-
/// <param name="options">Algorithm advanced options. See <see cref="OlsLinearRegressionTrainer.Options"/>.</param>
51+
/// <param name="options">Algorithm advanced options. See <see cref="OrdinaryLeastSquaresRegressionTrainer.Options"/>.</param>
5252
/// <example>
5353
/// <format type="text/markdown">
5454
/// <![CDATA[
5555
/// [!code-csharp[OrdinaryLeastSquares](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs)]
5656
/// ]]>
5757
/// </format>
5858
/// </example>
59-
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(
59+
public static OrdinaryLeastSquaresRegressionTrainer OrdinaryLeastSquares(
6060
this RegressionCatalog.RegressionTrainers catalog,
61-
OlsLinearRegressionTrainer.Options options)
61+
OrdinaryLeastSquaresRegressionTrainer.Options options)
6262
{
6363
Contracts.CheckValue(catalog, nameof(catalog));
6464
Contracts.CheckValue(options, nameof(options));
6565

6666
var env = CatalogUtils.GetEnvironment(catalog);
67-
return new OlsLinearRegressionTrainer(env, options);
67+
return new OrdinaryLeastSquaresRegressionTrainer(env, options);
6868
}
6969

7070
/// <summary>
71-
/// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
71+
/// Predict a target using a linear binary classification model trained with the <see cref="SymbolicStochasticGradientDescentClassificationTrainer"/>.
7272
/// </summary>
7373
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
7474
/// <param name="labelColumnName">The name of the label column.</param>
@@ -81,43 +81,43 @@ public static OlsLinearRegressionTrainer OrdinaryLeastSquares(
8181
/// ]]>
8282
/// </format>
8383
/// </example>
84-
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
84+
public static SymbolicStochasticGradientDescentClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
8585
string labelColumnName = DefaultColumnNames.Label,
8686
string featureColumnName = DefaultColumnNames.Features,
87-
int numberOfIterations = SymSgdClassificationTrainer.Defaults.NumberOfIterations)
87+
int numberOfIterations = SymbolicStochasticGradientDescentClassificationTrainer.Defaults.NumberOfIterations)
8888
{
8989
Contracts.CheckValue(catalog, nameof(catalog));
9090
var env = CatalogUtils.GetEnvironment(catalog);
9191

92-
var options = new SymSgdClassificationTrainer.Options
92+
var options = new SymbolicStochasticGradientDescentClassificationTrainer.Options
9393
{
9494
LabelColumnName = labelColumnName,
9595
FeatureColumnName = featureColumnName,
9696
};
9797

98-
return new SymSgdClassificationTrainer(env, options);
98+
return new SymbolicStochasticGradientDescentClassificationTrainer(env, options);
9999
}
100100

101101
/// <summary>
102-
/// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
102+
/// Predict a target using a linear binary classification model trained with the <see cref="SymbolicStochasticGradientDescentClassificationTrainer"/>.
103103
/// </summary>
104104
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
105-
/// <param name="options">Algorithm advanced options. See <see cref="SymSgdClassificationTrainer.Options"/>.</param>
105+
/// <param name="options">Algorithm advanced options. See <see cref="SymbolicStochasticGradientDescentClassificationTrainer.Options"/>.</param>
106106
/// <example>
107107
/// <format type="text/markdown">
108108
/// <![CDATA[
109109
/// [!code-csharp[SymbolicStochasticGradientDescent](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SymbolicStochasticGradientDescentWithOptions.cs)]
110110
/// ]]>
111111
/// </format>
112112
/// </example>
113-
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(
113+
public static SymbolicStochasticGradientDescentClassificationTrainer SymbolicStochasticGradientDescent(
114114
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
115-
SymSgdClassificationTrainer.Options options)
115+
SymbolicStochasticGradientDescentClassificationTrainer.Options options)
116116
{
117117
Contracts.CheckValue(catalog, nameof(catalog));
118118
Contracts.CheckValue(options, nameof(options));
119119
var env = CatalogUtils.GetEnvironment(catalog);
120-
return new SymSgdClassificationTrainer(env, options);
120+
return new SymbolicStochasticGradientDescentClassificationTrainer(env, options);
121121
}
122122

123123
/// <summary>

0 commit comments

Comments
 (0)