Skip to content

Commit 9d2e7f0

Browse files
committed
Add an example for static pipeline with in-memory data and show how to get class probabilities
1 parent 2e67134 commit 9d2e7f0

File tree

2 files changed

+310
-0
lines changed

2 files changed

+310
-0
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
using Microsoft.ML.Data;
2+
using Microsoft.ML.LightGBM.StaticPipe;
3+
using Microsoft.ML.Runtime.Data;
4+
using Microsoft.ML.StaticPipe;
5+
using System;
6+
using System.Collections.Generic;
7+
8+
namespace Microsoft.ML.Samples.Static
9+
{
10+
class LightGBMMulticlassWithInMemoryData
11+
{
12+
/// <summary>
13+
/// Number of features per example used in <see cref="MultiClassLightGbmStaticPipelineWithInMemoryData"/>.
14+
/// </summary>
15+
private const int _featureVectorLength = 10;
16+
17+
/// <summary>
18+
/// Data point used in <see cref="MultiClassLightGbmStaticPipelineWithInMemoryData"/>. A data set there
19+
/// is a collection of <see cref="NativeExample"/>.
20+
/// </summary>
21+
private class NativeExample
22+
{
23+
[VectorType(_featureVectorLength)]
24+
public float[] Features;
25+
[ColumnName("Label")]
26+
// One of "AA", "BB", "CC", and "DD".
27+
public string Label;
28+
public uint LabelIndex;
29+
// One of "AA", "BB", "CC", and "DD".
30+
public string PredictedLabel;
31+
[VectorType(4)]
32+
// The probabilities of being "AA", "BB", "CC", and "DD".
33+
public float[] Scores;
34+
35+
public NativeExample()
36+
{
37+
Features = new float[_featureVectorLength];
38+
}
39+
}
40+
41+
/// <summary>
42+
/// Helper function used to generate <see cref="NativeExample"/>s.
43+
/// </summary>
44+
private static List<NativeExample> GenerateRandomExamples(int count)
45+
{
46+
var examples = new List<NativeExample>();
47+
var rnd = new Random(0);
48+
for (int i = 0; i < count; ++i)
49+
{
50+
var example = new NativeExample();
51+
var res = i % 4;
52+
// Generate random float feature values.
53+
for (int j = 0; j < _featureVectorLength; ++j)
54+
{
55+
var value = (float)rnd.NextDouble() + res * 0.2f;
56+
example.Features[j] = value;
57+
}
58+
59+
// Generate label based on feature sum.
60+
if (res == 0)
61+
example.Label = "AA";
62+
else if (res == 1)
63+
example.Label = "BB";
64+
else if (res == 2)
65+
example.Label = "CC";
66+
else
67+
example.Label = "DD";
68+
69+
// The following three attributes are just placeholder for storing prediction results.
70+
example.LabelIndex = default;
71+
example.PredictedLabel = null;
72+
example.Scores = new float[4];
73+
74+
examples.Add(example);
75+
}
76+
return examples;
77+
}
78+
79+
public static void MultiClassLightGbmStaticPipelineWithInMemoryData()
80+
{
81+
// Create a general context for ML.NET operations. It can be used for exception tracking and logging,
82+
// as a catalog of available operations and as the source of randomness.
83+
var mlContext = new MLContext(seed: 1, conc: 1);
84+
85+
// Context for calling static classifiers. It contains constructors of classifiers and evaluation utilities.
86+
var ctx = new MulticlassClassificationContext(mlContext);
87+
88+
// Create in-memory examples as C# native class.
89+
var examples = GenerateRandomExamples(1000);
90+
91+
// Convert native C# class to IDataView, a consumble format to ML.NET functions.
92+
var dataView = ComponentCreation.CreateDataView(mlContext, examples);
93+
94+
// IDataView is the data format used in dynamic-typed pipeline. To use static-typed pipeline, we need to convert
95+
// IDataView to DataView by calling AssertStatic(...). The basic idea is to specify the static type for each column
96+
// in IDataView in a lambda function.
97+
var staticDataView = dataView.AssertStatic(mlContext, c => (
98+
Features: c.R4.Vector,
99+
Label: c.Text.Scalar));
100+
101+
// Create static pipeline. First, we make an estimator out of static DataView as the starting of a pipeline.
102+
// Then, we append necessary transforms and a classifier to the starting estimator.
103+
var pipe = staticDataView.MakeNewEstimator()
104+
.Append(mapper: r => (
105+
r.Label,
106+
// Train multi-class LightGBM. The trained model maps Features to Label and probability of each class.
107+
// The call of ToKey() is needed to convert string labels to integer indexes.
108+
Predictions: ctx.Trainers.LightGbm(r.Label.ToKey(), r.Features)
109+
))
110+
.Append(r => (
111+
// Actual label.
112+
r.Label,
113+
// Labels are converted to keys when training LightGBM so we convert it here again for calling evaluation function.
114+
LabelIndex: r.Label.ToKey(),
115+
// Instance of ClassificationVectorData returned
116+
r.Predictions,
117+
// ToValue() is used to get the original out from the class indexes computed by ToKey().
118+
// For example, if label "AA" is maped to index 0 via ToKey(), then ToValue() produces "AA" from 0.
119+
PredictedLabel: r.Predictions.predictedLabel.ToValue(),
120+
// Assign a new name to class probabilities.
121+
Scores: r.Predictions.score
122+
));
123+
124+
// Split the static-typed data into training and test sets. Only training set is used in fitting
125+
// the created pipeline. Metrics are computed on the test.
126+
var (trainingData, testingData) = ctx.TrainTestSplit(staticDataView, testFraction: 0.5);
127+
128+
// Train the model.
129+
var model = pipe.Fit(trainingData);
130+
131+
// Do prediction on the test set.
132+
var prediction = model.Transform(testingData);
133+
134+
// Evaluate the trained model is the test set.
135+
var metrics = ctx.Evaluate(prediction, r => r.LabelIndex, r => r.Predictions);
136+
137+
// Check if metrics are resonable.
138+
Console.WriteLine(metrics.AccuracyMacro); // expected value: 0.863482146891263
139+
Console.WriteLine(metrics.AccuracyMicro); // expected value: 0.86309523809523814
140+
141+
// Convert prediction in ML.NET format to native C# class.
142+
var nativePredictions = new List<NativeExample>(prediction.AsDynamic.AsEnumerable<NativeExample>(mlContext, false));
143+
144+
// Check predicted label and class probabilities of second-first example.
145+
// If you see a label with LabelIndex 1, its means its probability is the 1st element in the Scores field.
146+
// For example, if "AA" is indexed by 1, "BB" indexed by 2, "CC" indexed by 3, and "DD" indexed by 4, Scores is
147+
// ["AA" probability, "BB" probability, "CC" probability, "DD" probability].
148+
var nativePrediction = nativePredictions[2];
149+
var probAA = nativePrediction.Scores[0];
150+
var probBB = nativePrediction.Scores[1];
151+
var probCC = nativePrediction.Scores[2];
152+
var probDD = nativePrediction.Scores[3];
153+
154+
Console.WriteLine(probAA); // expected value: 0.922597349
155+
Console.WriteLine(probBB); // expected value: 0.07508608
156+
Console.WriteLine(probCC); // expected value: 0.00221699756
157+
Console.WriteLine(probDD); // expected value: 9.95488E-05
158+
}
159+
}
160+
}

test/Microsoft.ML.StaticPipelineTesting/Training.cs

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using System.Linq;
78
using Microsoft.ML;
89
using Microsoft.ML.Data;
@@ -1009,5 +1010,154 @@ public void MatrixFactorization()
10091010
// Naive test. Just make sure the pipeline runs.
10101011
Assert.InRange(metrics.L2, 0, 0.5);
10111012
}
1013+
1014+
/// <summary>
1015+
/// Number of features per example used in <see cref="MultiClassLightGbmStaticPipelineWithInMemoryData"/>.
1016+
/// </summary>
1017+
private const int _featureVectorLength = 10;
1018+
1019+
/// <summary>
1020+
/// Data point used in <see cref="MultiClassLightGbmStaticPipelineWithInMemoryData"/>. A data set there
1021+
/// is a collection of <see cref="NativeExample"/>.
1022+
/// </summary>
1023+
private class NativeExample
1024+
{
1025+
[VectorType(_featureVectorLength)]
1026+
public float[] Features;
1027+
[ColumnName("Label")]
1028+
// One of "AA", "BB", "CC", and "DD".
1029+
public string Label;
1030+
public uint LabelIndex;
1031+
// One of "AA", "BB", "CC", and "DD".
1032+
public string PredictedLabel;
1033+
[VectorType(4)]
1034+
// The probabilities of being "AA", "BB", "CC", and "DD".
1035+
public float[] Scores;
1036+
1037+
public NativeExample()
1038+
{
1039+
Features = new float[_featureVectorLength];
1040+
}
1041+
}
1042+
1043+
/// <summary>
1044+
/// Helper function used to generate <see cref="NativeExample"/>s.
1045+
/// </summary>
1046+
private static List<NativeExample> GenerateRandomExamples(int count)
1047+
{
1048+
var examples = new List<NativeExample>();
1049+
var rnd = new Random(0);
1050+
for (int i = 0; i < count; ++i)
1051+
{
1052+
var example = new NativeExample();
1053+
var res = i % 4;
1054+
// Generate random float feature values.
1055+
for (int j = 0; j < _featureVectorLength; ++j)
1056+
{
1057+
var value = (float)rnd.NextDouble() + res * 0.2f;
1058+
example.Features[j] = value;
1059+
}
1060+
1061+
// Generate label based on feature sum.
1062+
if (res == 0)
1063+
example.Label = "AA";
1064+
else if (res == 1)
1065+
example.Label = "BB";
1066+
else if (res == 2)
1067+
example.Label = "CC";
1068+
else
1069+
example.Label = "DD";
1070+
1071+
// The following three attributes are just placeholder for storing prediction results.
1072+
example.LabelIndex = default;
1073+
example.PredictedLabel = null;
1074+
example.Scores = new float[4];
1075+
1076+
examples.Add(example);
1077+
}
1078+
return examples;
1079+
}
1080+
1081+
[Fact()]
1082+
public void MultiClassLightGbmStaticPipelineWithInMemoryData()
1083+
{
1084+
// Create a general context for ML.NET operations. It can be used for exception tracking and logging,
1085+
// as a catalog of available operations and as the source of randomness.
1086+
var mlContext = new MLContext(seed: 1, conc: 1);
1087+
1088+
// Context for calling static classifiers. It contains constructors of classifiers and evaluation utilities.
1089+
var ctx = new MulticlassClassificationContext(mlContext);
1090+
1091+
// Create in-memory examples as C# native class.
1092+
var examples = GenerateRandomExamples(1000);
1093+
1094+
// Convert native C# class to IDataView, a consumble format to ML.NET functions.
1095+
var dataView = ComponentCreation.CreateDataView(mlContext, examples);
1096+
1097+
// IDataView is the data format used in dynamic-typed pipeline. To use static-typed pipeline, we need to convert
1098+
// IDataView to DataView by calling AssertStatic(...). The basic idea is to specify the static type for each column
1099+
// in IDataView in a lambda function.
1100+
var staticDataView = dataView.AssertStatic(mlContext, c => (
1101+
Features: c.R4.Vector,
1102+
Label: c.Text.Scalar));
1103+
1104+
// Create static pipeline. First, we make an estimator out of static DataView as the starting of a pipeline.
1105+
// Then, we append necessary transforms and a classifier to the starting estimator.
1106+
var pipe = staticDataView.MakeNewEstimator()
1107+
.Append(mapper: r => (
1108+
r.Label,
1109+
// Train multi-class LightGBM. The trained model maps Features to Label and probability of each class.
1110+
// The call of ToKey() is needed to convert string labels to integer indexes.
1111+
Predictions: ctx.Trainers.LightGbm(r.Label.ToKey(), r.Features)
1112+
))
1113+
.Append(r => (
1114+
// Actual label.
1115+
r.Label,
1116+
// Labels are converted to keys when training LightGBM so we convert it here again for calling evaluation function.
1117+
LabelIndex: r.Label.ToKey(),
1118+
// Instance of ClassificationVectorData returned
1119+
r.Predictions,
1120+
// ToValue() is used to get the original out from the class indexes computed by ToKey().
1121+
// For example, if label "AA" is maped to index 0 via ToKey(), then ToValue() produces "AA" from 0.
1122+
PredictedLabel: r.Predictions.predictedLabel.ToValue(),
1123+
// Assign a new name to class probabilities.
1124+
Scores: r.Predictions.score
1125+
));
1126+
1127+
// Split the static-typed data into training and test sets. Only training set is used in fitting
1128+
// the created pipeline. Metrics are computed on the test.
1129+
var (trainingData, testingData) = ctx.TrainTestSplit(staticDataView, testFraction: 0.5);
1130+
1131+
// Train the model.
1132+
var model = pipe.Fit(trainingData);
1133+
1134+
// Do prediction on the test set.
1135+
var prediction = model.Transform(testingData);
1136+
1137+
// Evaluate the trained model is the test set.
1138+
var metrics = ctx.Evaluate(prediction, r => r.LabelIndex, r => r.Predictions);
1139+
1140+
// Check if metrics are resonable.
1141+
Assert.Equal(0.863482146891263, metrics.AccuracyMacro, 6);
1142+
Assert.Equal(0.86309523809523814, metrics.AccuracyMicro, 6);
1143+
1144+
// Convert prediction in ML.NET format to native C# class.
1145+
var nativePredictions = new List<NativeExample>(prediction.AsDynamic.AsEnumerable<NativeExample>(mlContext, false));
1146+
1147+
// Check predicted label and class probabilities of second-first example.
1148+
// If you see a label with LabelIndex 1, its means its probability is the 1st element in the Scores field.
1149+
// For example, if "AA" is indexed by 1, "BB" indexed by 2, "CC" indexed by 3, and "DD" indexed by 4, Scores is
1150+
// ["AA" probability, "BB" probability, "CC" probability, "DD" probability].
1151+
var nativePrediction = nativePredictions[2];
1152+
var probAA = nativePrediction.Scores[0];
1153+
var probBB = nativePrediction.Scores[1];
1154+
var probCC = nativePrediction.Scores[2];
1155+
var probDD = nativePrediction.Scores[3];
1156+
1157+
Assert.Equal(0.922597349, probAA, 6);
1158+
Assert.Equal(0.07508608, probBB, 6);
1159+
Assert.Equal(0.00221699756, probCC, 6);
1160+
Assert.Equal(9.95488E-05, probDD, 6);
1161+
}
10121162
}
10131163
}

0 commit comments

Comments
 (0)