Skip to content

Commit 5fbe385

Browse files
committed
Binary LR samples using T4 templates (dotnet#3099)
1 parent 62a5b34 commit 5fbe385

13 files changed

+414
-189
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs

-86
This file was deleted.
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,97 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Microsoft.ML.Data;
5+
<# if (TrainerOptions != null) { #>
6+
<#=OptionsInclude#>
7+
<# } #>
8+
9+
namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
10+
{
11+
// Sample that trains a binary classification trainer on synthetic data, prints a few predictions
// and then prints the overall evaluation metrics.
public static class <#=ClassName#>
12+
{<#=Comments#>
13+
// Entry point of the sample: builds the data, fits the trainer, scores a separate test set and prints the results.
public static void Example()
14+
{
15+
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
16+
// as a catalog of available operations and as the source of randomness.
17+
// Setting the seed to a fixed number in this example to make outputs deterministic.
18+
var mlContext = new MLContext(seed: 0);
19+
20+
// Create 1000 training data points. The sequence is produced lazily by an iterator method.
21+
var dataPoints = GenerateRandomDataPoints(1000);
22+
23+
// Convert the data points to an IDataView object, which is consumable by the ML.NET API.
24+
var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);
25+
26+
<# if (TrainerOptions == null) { #>
27+
// Define the trainer (no options object is supplied).
28+
var pipeline = mlContext.BinaryClassification.Trainers.<#=Trainer#>();
29+
<# } else { #>
30+
// Define trainer options.
31+
var options = new <#=TrainerOptions#>;
32+
33+
// Define the trainer, configured through the options object created above.
34+
var pipeline = mlContext.BinaryClassification.Trainers.<#=Trainer#>(options);
35+
<# } #>
36+
37+
// Train the model.
38+
var model = pipeline.Fit(trainingData);
39+
40+
// Create testing data. Use a different random seed to make it different from the training data.
41+
var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123));
42+
43+
// Run the model on the test data set.
44+
var transformedTestData = model.Transform(testData);
45+
46+
// Convert the IDataView object to a list. reuseRowObject: false is passed so that the rows
47+
// stored in the list are not one shared instance (NOTE(review): confirm against the CreateEnumerable docs).
var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();
48+
49+
// Look at the first 5 predictions.
50+
foreach (var p in predictions.Take(5))
51+
Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");
52+
53+
<#=ExpectedOutputPerInstance#>
54+
<# string Evaluator = IsCalibrated ? "Evaluate" : "EvaluateNonCalibrated"; #>
55+
56+
// Evaluate the overall metrics on the scored test data and print them.
57+
var metrics = mlContext.BinaryClassification.<#=Evaluator#>(transformedTestData);
58+
SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
59+
60+
<#=ExpectedOutput#>
61+
}
62+
63+
// Lazily yields `count` synthetic examples. Every value comes from a Random created with `seed`,
// and that Random is created anew each time the returned sequence is enumerated.
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed=0)
64+
{
65+
var random = new Random(seed);
66+
// Local helper: the next random double, narrowed to float.
float randomFloat() => (float)random.NextDouble();
67+
for (int i = 0; i < count; i++)
68+
{
69+
// The label is drawn first, before any of the feature values of this example.
var label = randomFloat() > 0.5f;
70+
yield return new DataPoint
71+
{
72+
Label = label,
73+
// Create random features that are correlated with the label.
74+
// For data points with false label, the feature values are slightly increased by adding a constant.
75+
Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + <#=DataSepValue#>).ToArray()
76+
};
77+
}
78+
}
79+
80+
// Example with label and 50 feature values. A data set is a collection of such examples.
81+
private class DataPoint
82+
{
83+
public bool Label { get; set; }
84+
// The attribute argument matches the 50 values produced per example by GenerateRandomDataPoints.
[VectorType(50)]
85+
public float[] Features { get; set; }
86+
}
87+
88+
// Class used to capture predictions.
89+
private class Prediction
90+
{
91+
// Original label.
92+
public bool Label { get; set; }
93+
// Predicted label from the trainer.
94+
public bool PredictedLabel { get; set; }
95+
}
96+
}
97+
}

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.tt

-1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
<#@ include file="TreeSamplesTemplate.ttinclude"#>
2-
32
<#+
43
string ClassName="FastForest";
54
string Trainer = "FastForest";

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.tt

-1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
<#@ include file="TreeSamplesTemplate.ttinclude"#>
2-
32
<#+
43
string ClassName="FastForestWithOptions";
54
string Trainer = "FastForest";

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.tt

-1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
<#@ include file="TreeSamplesTemplate.ttinclude"#>
2-
32
<#+
43
string ClassName="FastTree";
54
string Trainer = "FastTree";

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.tt

-1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
<#@ include file="TreeSamplesTemplate.ttinclude"#>
2-
32
<#+
43
string ClassName="FastTreeWithOptions";
54
string Trainer = "FastTree";
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,100 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Microsoft.ML.Data;
5+
6+
namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
7+
{
8+
public static class LbfgsLogisticRegression
{
    /// <summary>
    /// Trains the LbfgsLogisticRegression binary classification trainer on synthetic data,
    /// scores a separately generated test set, prints the first five predictions and then
    /// prints the overall evaluation metrics.
    /// </summary>
    public static void Example()
    {
        // The MLContext is the entry point to the ML.NET operations used below.
        // A fixed seed keeps the printed output the same from run to run.
        var context = new MLContext(seed: 0);

        // 1000 lazily generated training examples, exposed to ML.NET as an IDataView.
        IEnumerable<DataPoint> trainingPoints = GenerateRandomDataPoints(1000);
        var trainingView = context.Data.LoadFromEnumerable(trainingPoints);

        // Build the trainer without an options object and fit it to the training data.
        var trainer = context.BinaryClassification.Trainers.LbfgsLogisticRegression();
        var trainedModel = trainer.Fit(trainingView);

        // The test set uses another seed so that it differs from the training set.
        IEnumerable<DataPoint> testPoints = GenerateRandomDataPoints(500, seed: 123);
        var testView = context.Data.LoadFromEnumerable(testPoints);

        // Score the test set with the trained model.
        var scoredTestView = trainedModel.Transform(testView);

        // Materialize every scored row as a strongly typed object.
        List<Prediction> scoredRows = context.Data
            .CreateEnumerable<Prediction>(scoredTestView, reuseRowObject: false)
            .ToList();

        // Print at most the first five rows.
        int rowsToShow = Math.Min(5, scoredRows.Count);
        for (int row = 0; row < rowsToShow; row++)
        {
            Prediction current = scoredRows[row];
            Console.WriteLine($"Label: {current.Label}, Prediction: {current.PredictedLabel}");
        }

        // Expected output:
        //   Label: True, Prediction: True
        //   Label: False, Prediction: True
        //   Label: True, Prediction: True
        //   Label: True, Prediction: True
        //   Label: False, Prediction: False

        // Compute the aggregate metrics over the scored test set and print them.
        var metrics = context.BinaryClassification.Evaluate(scoredTestView);
        SamplesUtils.ConsoleUtils.PrintMetrics(metrics);

        // Expected output:
        //   Accuracy: 0.88
        //   AUC: 0.96
        //   F1 Score: 0.87
        //   Negative Precision: 0.90
        //   Negative Recall: 0.87
        //   Positive Precision: 0.86
        //   Positive Recall: 0.89
        //   Log Loss: 0.38
        //   Log Loss Reduction: 0.62
        //   Entropy: 1.00
    }

    /// <summary>
    /// Lazily produces <paramref name="count"/> synthetic examples.
    /// </summary>
    /// <param name="count">Number of examples to yield; zero or a negative value yields nothing.</param>
    /// <param name="seed">Seed of the random generator that drives every label and feature value.</param>
    /// <returns>A deferred sequence; each enumeration restarts from a freshly seeded generator.</returns>
    private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
    {
        const int featureCount = 50;
        const float falseLabelOffset = 0.1f;

        // This must stay inside the iterator body so that each enumeration reseeds the generator.
        var generator = new Random(seed);

        for (int produced = 0; produced < count; produced++)
        {
            // The label is drawn first, followed by the feature values in index order.
            bool isPositive = (float)generator.NextDouble() > 0.5f;

            // Features are correlated with the label: examples labelled false get
            // every feature value shifted upwards by a small constant.
            var featureValues = new float[featureCount];
            for (int slot = 0; slot < featureValues.Length; slot++)
            {
                float draw = (float)generator.NextDouble();
                featureValues[slot] = isPositive ? draw : draw + falseLabelOffset;
            }

            yield return new DataPoint
            {
                Label = isPositive,
                Features = featureValues
            };
        }
    }

    // One example: a boolean label plus 50 feature values. A data set is a collection of these.
    private class DataPoint
    {
        public bool Label { get; set; }

        [VectorType(50)]
        public float[] Features { get; set; }
    }

    // Shape used to read scored rows back out of the transformed data.
    private class Prediction
    {
        // The label the example was generated with.
        public bool Label { get; set; }

        // The label predicted by the trained model.
        public bool PredictedLabel { get; set; }
    }
}
100+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,30 @@
1+
<#@ include file="BinaryClassification.ttinclude"#>
2+
<#+
// Per-sample settings for the LbfgsLogisticRegression sample.
// NOTE(review): presumably read by the included BinaryClassification.ttinclude - confirm.
// The usage notes below describe how the binary-classification template in this change uses each name.
3+
// Name of the generated static sample class.
string ClassName="LbfgsLogisticRegression";
4+
// Name of the method invoked on mlContext.BinaryClassification.Trainers to create the trainer.
string Trainer = "LbfgsLogisticRegression";
5+
// null selects the template branch that creates the trainer without an options object.
string TrainerOptions = null;
6+
// true makes the template call Evaluate; false would make it call EvaluateNonCalibrated.
bool IsCalibrated = true;
7+
8+
// Text of the constant added to every feature value of examples whose label is false.
string DataSepValue = "0.1f";
9+
// Extra text emitted after the using directives; the template only emits it when TrainerOptions is not null.
string OptionsInclude = "";
10+
// Text emitted immediately after the opening brace of the generated class.
string Comments= "";
11+
12+
// Comment block emitted after the code that prints the first five predictions.
string ExpectedOutputPerInstance= @"// Expected output:
13+
// Label: True, Prediction: True
14+
// Label: False, Prediction: True
15+
// Label: True, Prediction: True
16+
// Label: True, Prediction: True
17+
// Label: False, Prediction: False";
18+
19+
// Comment block emitted after the code that prints the evaluation metrics.
string ExpectedOutput = @"// Expected output:
20+
// Accuracy: 0.88
21+
// AUC: 0.96
22+
// F1 Score: 0.87
23+
// Negative Precision: 0.90
24+
// Negative Recall: 0.87
25+
// Positive Precision: 0.86
26+
// Positive Recall: 0.89
27+
// Log Loss: 0.38
28+
// Log Loss Reduction: 0.62
29+
// Entropy: 1.00";
30+
#>

0 commit comments

Comments
 (0)