
Commit 5975856
Commit message: multiclasssamples
Parent: 8bc1781

25 files changed, +1248 / -264 lines
@@ -0,0 +1,109 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    public static class LbfgsMaximumEntropy
    {
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations, and as the source of randomness.
            // Setting the seed to a fixed number in this example makes outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is consumable by the ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the training pipeline.
            var pipeline =
                // Convert the labels into key types expected by the trainer.
                mlContext.Transforms.Conversion.MapValueToKey("Label")
                // Apply the LbfgsMaximumEntropy multiclass trainer.
                .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy());

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use a different random seed to make it distinct from the training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on the test data set.
            var transformedTestData = model.Transform(testData);

            // Convert the IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();

            // Look at the first 5 predictions.
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");

            // Expected output:
            // Label: 1, Prediction: 1
            // Label: 2, Prediction: 2
            // Label: 3, Prediction: 2
            // Label: 2, Prediction: 2
            // Label: 3, Prediction: 3

            // Evaluate the overall metrics.
            var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            // Micro Accuracy: 0.91
            // Macro Accuracy: 0.91
            // Log Loss: 0.24
            // Log Loss Reduction: 0.78
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                // Generate labels that are integers 1, 2 or 3.
                var label = random.Next(1, 4);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    // Create random features that are correlated with the label.
                    // The feature values are slightly increased by adding a constant multiple of the label.
                    Features = Enumerable.Repeat(label, 20).Select(x => randomFloat() + label * 0.2f).ToArray()
                };
            }
        }

        // Example with a label and 20 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(20)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Predicted label from the trainer.
            public uint PredictedLabel { get; set; }
        }

        // Pretty-print MulticlassClassificationMetrics objects.
        public static void PrintMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");
        }
    }
}
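
Editor's note: the sample above scores an entire IDataView at once with Transform. For scoring one example at a time, ML.NET also provides a PredictionEngine. The sketch below is not part of this commit; it is a minimal illustration that reuses the mlContext, model, DataPoint, and Prediction definitions from the sample above, and the feature values are arbitrary placeholders.

// Minimal sketch (not part of this commit): score a single example with a PredictionEngine.
// Reuses mlContext, model, DataPoint, and Prediction from the sample above.
var predictionEngine = mlContext.Model.CreatePredictionEngine<DataPoint, Prediction>(model);

var sample = new DataPoint
{
    // Label is ignored at prediction time; the 20 feature values here are arbitrary placeholders.
    Features = Enumerable.Repeat(0.9f, 20).ToArray()
};

var singlePrediction = predictionEngine.Predict(sample);
Console.WriteLine($"Predicted label: {singlePrediction.PredictedLabel}");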
@@ -0,0 +1,24 @@
<#@ include file="MulticlassClassification.ttinclude"#>
<#+
string ClassName = "LbfgsMaximumEntropy";
string Trainer = "LbfgsMaximumEntropy";
string MetaTrainer = null;
string TrainerOptions = null;

string OptionsInclude = "";
string Comments = "";
bool CacheData = false;

string ExpectedOutputPerInstance = @"// Expected output:
// Label: 1, Prediction: 1
// Label: 2, Prediction: 2
// Label: 3, Prediction: 2
// Label: 2, Prediction: 2
// Label: 3, Prediction: 3";

string ExpectedOutput = @"// Expected output:
// Micro Accuracy: 0.91
// Macro Accuracy: 0.91
// Log Loss: 0.24
// Log Loss Reduction: 0.78";
#>
@@ -0,0 +1,119 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    public static class LbfgsMaximumEntropyWithOptions
    {
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations, and as the source of randomness.
            // Setting the seed to a fixed number in this example makes outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is consumable by the ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new LbfgsMaximumEntropyMulticlassTrainer.Options
            {
                HistorySize = 50,
                L1Regularization = 0.1f,
                NumberOfThreads = 1
            };

            // Define the training pipeline.
            var pipeline =
                // Convert the labels into key types expected by the trainer.
                mlContext.Transforms.Conversion.MapValueToKey("Label")
                // Apply the LbfgsMaximumEntropy multiclass trainer with the options defined above.
                .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(options));

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use a different random seed to make it distinct from the training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on the test data set.
            var transformedTestData = model.Transform(testData);

            // Convert the IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();

            // Look at the first 5 predictions.
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");

            // Expected output:
            // Label: 1, Prediction: 1
            // Label: 2, Prediction: 2
            // Label: 3, Prediction: 2
            // Label: 2, Prediction: 2
            // Label: 3, Prediction: 3

            // Evaluate the overall metrics.
            var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            // Micro Accuracy: 0.91
            // Macro Accuracy: 0.91
            // Log Loss: 0.22
            // Log Loss Reduction: 0.80
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                // Generate labels that are integers 1, 2 or 3.
                var label = random.Next(1, 4);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    // Create random features that are correlated with the label.
                    // The feature values are slightly increased by adding a constant multiple of the label.
                    Features = Enumerable.Repeat(label, 20).Select(x => randomFloat() + label * 0.2f).ToArray()
                };
            }
        }

        // Example with a label and 20 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(20)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Predicted label from the trainer.
            public uint PredictedLabel { get; set; }
        }

        // Pretty-print MulticlassClassificationMetrics objects.
        public static void PrintMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");
        }
    }
}
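
Editor's note: after evaluation, a common next step (not shown in this commit) is persisting the trained pipeline so it can be reloaded for inference later. The sketch below is only an illustration: it reuses mlContext, model, trainingData, and testData from the sample above, and "model.zip" is an arbitrary placeholder path.

// Minimal sketch (not part of this commit): save and reload the trained pipeline.
// "model.zip" is a placeholder path; mlContext, model, trainingData, and testData come from the sample above.
mlContext.Model.Save(model, trainingData.Schema, "model.zip");

// Reload the model (for example, in another process) and score data with the same input schema.
ITransformer loadedModel = mlContext.Model.Load("model.zip", out DataViewSchema inputSchema);
var reloadedPredictions = loadedModel.Transform(testData);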
@@ -0,0 +1,30 @@
<#@ include file="MulticlassClassification.ttinclude"#>
<#+
string ClassName = "LbfgsMaximumEntropyWithOptions";
string Trainer = "LbfgsMaximumEntropy";
string MetaTrainer = null;
string TrainerOptions = @"LbfgsMaximumEntropyMulticlassTrainer.Options
    {
        HistorySize = 50,
        L1Regularization = 0.1f,
        NumberOfThreads = 1
    }";

string OptionsInclude = "using Microsoft.ML.Trainers;";
string Comments = "";

bool CacheData = false;

string ExpectedOutputPerInstance = @"// Expected output:
// Label: 1, Prediction: 1
// Label: 2, Prediction: 2
// Label: 3, Prediction: 2
// Label: 2, Prediction: 2
// Label: 3, Prediction: 3";

string ExpectedOutput = @"// Expected output:
// Micro Accuracy: 0.91
// Macro Accuracy: 0.91
// Log Loss: 0.22
// Log Loss Reduction: 0.80";
#>
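
Editor's note: the options used in this commit (HistorySize, L1Regularization, NumberOfThreads) are only a subset of what LbfgsMaximumEntropyMulticlassTrainer.Options exposes. The sketch below is not part of this commit and merely illustrates a few additional commonly tuned properties; the property names are assumed from the public L-BFGS options surface in recent ML.NET releases, so verify them against the version in use.

// Minimal sketch (not part of this commit): a more heavily tuned options object.
// Assumes the same usings as the WithOptions sample above; property names are assumed from
// the public LbfgsMaximumEntropyMulticlassTrainer.Options surface.
var tunedOptions = new LbfgsMaximumEntropyMulticlassTrainer.Options
{
    HistorySize = 50,
    L1Regularization = 0.1f,
    L2Regularization = 0.01f,           // ridge penalty in addition to the L1 term above
    OptimizationTolerance = 1e-7f,      // tighter convergence criterion for L-BFGS
    MaximumNumberOfIterations = 100,    // cap on optimizer iterations
    NumberOfThreads = 1
};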
