Skip to content

Commit 17bdb98

Browse files
authored
Add an example for static pipeline with in-memory data and show how to get class probabilities (#1953)
* Add an example for static pipeline with in-memory data and show how to get class probabilities * Really extract labels from learned pipeline * Extend a comment to mention that extracting labels is a temporal solution * Bypass 32-bit LightGbm test * Address comments * Move example data structure to SamplesUtils
1 parent 94964cf commit 17bdb98

File tree

5 files changed

+262
-0
lines changed

5 files changed

+262
-0
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
using Microsoft.ML.Data;
2+
using Microsoft.ML.LightGBM.StaticPipe;
3+
using Microsoft.ML.SamplesUtils;
4+
using Microsoft.ML.StaticPipe;
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
9+
namespace Microsoft.ML.Samples.Static
{
    /// <summary>
    /// Sample showing how to train a multi-class LightGBM model with the statically-typed pipeline API on
    /// in-memory data, and how to read the predicted label and the per-class probabilities back into
    /// native C# objects.
    /// </summary>
    class LightGBMMulticlassWithInMemoryData
    {
        /// <summary>
        /// Generates in-memory examples, trains and evaluates a multi-class LightGBM model on them, and prints
        /// the evaluation metrics together with the class probabilities predicted for the 3rd test example.
        /// </summary>
        public void MultiClassLightGbmStaticPipelineWithInMemoryData()
        {
            // Create a general context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext();

            // Create in-memory examples as C# native class.
            var examples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000);

            // Convert native C# class to IDataView, a consumable format to ML.NET functions.
            var dataView = ComponentCreation.CreateDataView(mlContext, examples);

            // IDataView is the data format used in dynamic-typed pipeline. To use static-typed pipeline, we need to convert
            // IDataView to DataView by calling AssertStatic(...). The basic idea is to specify the static type for each column
            // in IDataView in a lambda function.
            var staticDataView = dataView.AssertStatic(mlContext, c => (
                Features: c.R4.Vector,
                Label: c.Text.Scalar));

            // Create static pipeline. First, we make an estimator out of static DataView as the starting of a pipeline.
            // Then, we append necessary transforms and a classifier to the starting estimator.
            var pipe = staticDataView.MakeNewEstimator()
                .Append(mapper: r => (
                    r.Label,
                    // Train multi-class LightGBM. The trained model maps Features to Label and probability of each class.
                    // The call of ToKey() is needed to convert string labels to integer indexes.
                    Predictions: mlContext.MulticlassClassification.Trainers.LightGbm(r.Label.ToKey(), r.Features)
                ))
                .Append(r => (
                    // Actual label.
                    r.Label,
                    // Labels are converted to keys when training LightGBM so we convert it here again for calling evaluation function.
                    LabelIndex: r.Label.ToKey(),
                    // Used to compute metrics such as accuracy.
                    r.Predictions,
                    // Assign a new name to predicted class index.
                    PredictedLabelIndex: r.Predictions.predictedLabel,
                    // Assign a new name to class probabilities.
                    Scores: r.Predictions.score
                ));

            // Split the static-typed data into training and test sets. Only training set is used in fitting
            // the created pipeline. Metrics are computed on the test.
            var (trainingData, testingData) = mlContext.MulticlassClassification.TrainTestSplit(staticDataView, testFraction: 0.5);

            // Train the model.
            var model = pipe.Fit(trainingData);

            // Do prediction on the test set.
            var prediction = model.Transform(testingData);

            // Evaluate the trained model on the test set.
            var metrics = mlContext.MulticlassClassification.Evaluate(prediction, r => r.LabelIndex, r => r.Predictions);

            // Print the metrics that were actually computed, so the output reflects the trained model.
            // Console output (recorded with MLContext(seed: 1, conc: 1); an unseeded context may give different numbers):
            //   Macro accuracy: 0.863482146891263, Micro accuracy: 0.86309523809523814.
            Console.WriteLine("Macro accuracy: {0}, Micro accuracy: {1}.", metrics.AccuracyMacro, metrics.AccuracyMicro);

            // Convert prediction in ML.NET format to native C# class.
            var nativePredictions = new List<DatasetUtils.MulticlassClassificationExample>(prediction.AsDynamic.AsEnumerable<DatasetUtils.MulticlassClassificationExample>(mlContext, false));

            // Get schema object out of the prediction. It contains metadata such as the mapping from predicted label index
            // (e.g., 1) to its actual label (e.g., "AA"). The call to "AsDynamic" converts our statically-typed pipeline into
            // a dynamically-typed one only for extracting metadata. In the future, metadata in statically-typed pipeline should
            // be accessible without dynamically-typed things.
            var schema = prediction.AsDynamic.Schema;

            // Retrieve the mapping from label indexes to labels.
            var labelBuffer = new VBuffer<ReadOnlyMemory<char>>();
            schema[nameof(DatasetUtils.MulticlassClassificationExample.PredictedLabelIndex)].Metadata.GetValue("KeyValues", ref labelBuffer);
            // nativeLabels is { "AA" , "BB", "CC", "DD" }, so nativeLabels[nativePrediction.PredictedLabelIndex - 1]
            // is the original label indexed by nativePrediction.PredictedLabelIndex.
            var nativeLabels = labelBuffer.DenseValues().ToArray();

            // Show prediction result for the 3rd example.
            var nativePrediction = nativePredictions[2];
            var predictedSlot = (int)nativePrediction.PredictedLabelIndex - 1;
            // Console output:
            //   Our predicted label to this example is AA with probability 0.922597349
            Console.WriteLine("Our predicted label to this example is {0} with probability {1}",
                nativeLabels[predictedSlot],
                nativePrediction.Scores[predictedSlot]);

            // Scores and nativeLabels are two parallel attributes; that is, Scores[i] is the probability of being nativeLabels[i].
            // Console output:
            //   The probability of being class AA is 0.922597349.
            //   The probability of being class BB is 0.07508608.
            //   The probability of being class CC is 0.00221699756.
            //   The probability of being class DD is 9.95488E-05.
            for (int i = 0; i < nativeLabels.Length; ++i)
            {
                Console.WriteLine("The probability of being class {0} is {1}.", nativeLabels[i], nativePrediction.Scores[i]);
            }
        }
    }
}

src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,13 @@ public static Scalar<float> LightGbm<TVal>(this RankingContext.RankingTrainers c
181181
/// the linear model that was trained. Note that this action cannot change the
182182
/// result in any way; it is only a way for the caller to be informed about what was learnt.</param>
183183
/// <returns>The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label.</returns>
184+
/// <example>
185+
/// <format type="text/markdown">
186+
/// <![CDATA[
187+
/// [!code-csharp[MF](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Static/LightGBMMulticlassWithInMemoryData.cs)]
188+
/// ]]>
189+
/// </format>
190+
/// </example>
184191
public static (Vector<float> score, Key<uint, TVal> predictedLabel)
185192
LightGbm<TVal>(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx,
186193
Key<uint, TVal> label,

src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,5 +237,68 @@ public static IEnumerable<SampleVectorOfNumbersData> GetVectorOfNumbersData()
237237
});
238238
return data;
239239
}
240+
241+
/// <summary>
242+
/// Length of the feature vector in <see cref="MulticlassClassificationExample"/>.
243+
/// </summary>
244+
private const int _featureVectorLength = 10;
245+
246+
/// <summary>
/// In-memory example for multi-class classification. <see cref="Features"/> and <see cref="Label"/> carry the
/// input data; the remaining fields are placeholders for storing prediction results.
/// </summary>
public class MulticlassClassificationExample
{
    // Feature values; the array is allocated up front with the declared vector length.
    [VectorType(_featureVectorLength)]
    public float[] Features = new float[_featureVectorLength];

    // Class name of this example, exposed as the "Label" column.
    [ColumnName("Label")]
    public string Label;

    // Placeholder for the label's index.
    public uint LabelIndex;

    // Placeholder for the predicted label's index.
    public uint PredictedLabelIndex;

    // The probabilities of being "AA", "BB", "CC", and "DD".
    [VectorType(4)]
    public float[] Scores;
}
263+
264+
/// <summary>
/// Helper function used to generate random <see cref="MulticlassClassificationExample"/>s.
/// </summary>
/// <remarks>
/// The random generator is created with a fixed seed (0), so the same <paramref name="count"/> always yields the
/// same examples. The i-th example gets the class at position i % 4 of { "AA", "BB", "CC", "DD" }, and each of its
/// features is a random value in [0, 1) shifted by 0.2 times that class position.
/// </remarks>
/// <param name="count">Number of generated examples.</param>
/// <returns>A list of random examples.</returns>
public static List<MulticlassClassificationExample> GenerateRandomMulticlassClassificationExamples(int count)
{
    // Class names, indexed by the class position (i % number of classes).
    string[] labels = { "AA", "BB", "CC", "DD" };

    // Presize the list; a non-positive count still produces an empty list, as before.
    var examples = new List<MulticlassClassificationExample>(Math.Max(count, 0));
    var rnd = new Random(0);
    for (int i = 0; i < count; ++i)
    {
        // The class of an example is determined by its position in the sequence, not by its feature values.
        var res = i % labels.Length;

        var example = new MulticlassClassificationExample
        {
            Label = labels[res],
            // The following three attributes are just placeholders for storing prediction results.
            LabelIndex = default,
            PredictedLabelIndex = default,
            Scores = new float[labels.Length]
        };

        // Generate random float feature values, shifted per class so that the classes differ.
        for (int j = 0; j < _featureVectorLength; ++j)
        {
            example.Features[j] = (float)rnd.NextDouble() + res * 0.2f;
        }

        examples.Add(example);
    }
    return examples;
}
240303
}
241304
}

test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
<ProjectReference Include="..\..\src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
1010
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM.StaticPipe\Microsoft.ML.LightGBM.StaticPipe.csproj" />
1111
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
12+
<ProjectReference Include="..\..\src\Microsoft.ML.SamplesUtils\Microsoft.ML.SamplesUtils.csproj" />
1213
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
1314
<ProjectReference Include="..\..\src\Microsoft.ML.StaticPipe\Microsoft.ML.StaticPipe.csproj" />
1415
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />

test/Microsoft.ML.StaticPipelineTesting/Training.cs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using System.Linq;
78
using Microsoft.ML;
89
using Microsoft.ML.Data;
@@ -13,6 +14,7 @@
1314
using Microsoft.ML.LightGBM;
1415
using Microsoft.ML.LightGBM.StaticPipe;
1516
using Microsoft.ML.RunTests;
17+
using Microsoft.ML.SamplesUtils;
1618
using Microsoft.ML.StaticPipe;
1719
using Microsoft.ML.Trainers;
1820
using Microsoft.ML.Trainers.FastTree;
@@ -1009,5 +1011,89 @@ public void MatrixFactorization()
10091011
// Naive test. Just make sure the pipeline runs.
10101012
Assert.InRange(metrics.L2, 0, 0.5);
10111013
}
1014+
1015+
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only
public void MultiClassLightGbmStaticPipelineWithInMemoryData()
{
    // General context for ML.NET operations (exception tracking, logging, catalog of operations, source of
    // randomness). It is created with a fixed seed and conc: 1; the exact values asserted below were
    // recorded with this configuration.
    var mlContext = new MLContext(seed: 1, conc: 1);

    // Build the in-memory examples and expose them as an IDataView, the format consumed by ML.NET functions.
    var rawExamples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000);
    var dynamicData = ComponentCreation.CreateDataView(mlContext, rawExamples);

    // The static-typed pipeline needs a statically-typed view of the data: AssertStatic(...) declares the
    // static type of each column we want to use.
    var typedData = dynamicData.AssertStatic(mlContext, col => (
        Features: col.R4.Vector,
        Label: col.Text.Scalar));

    // Stage 1 trains multi-class LightGBM; ToKey() converts the string labels to integer indexes.
    // Stage 2 gives the outputs the column names expected by MulticlassClassificationExample.
    var pipeline = typedData.MakeNewEstimator()
        .Append(mapper: row => (
            row.Label,
            Predictions: mlContext.MulticlassClassification.Trainers.LightGbm(row.Label.ToKey(), row.Features)
        ))
        .Append(row => (
            // Actual label.
            row.Label,
            // The label converted to a key again, as required by the evaluation function.
            LabelIndex: row.Label.ToKey(),
            // Used to compute metrics such as accuracy.
            row.Predictions,
            // Predicted class index under a new name.
            PredictedLabelIndex: row.Predictions.predictedLabel,
            // Class probabilities under a new name.
            Scores: row.Predictions.score
        ));

    // Half of the data is used to fit the pipeline, the other half to compute metrics.
    var (trainSet, testSet) = mlContext.MulticlassClassification.TrainTestSplit(typedData, testFraction: 0.5);
    var trainedModel = pipeline.Fit(trainSet);
    var scored = trainedModel.Transform(testSet);

    // Evaluate the trained model on the test set and check that the metrics are reasonable.
    var metrics = mlContext.MulticlassClassification.Evaluate(scored, row => row.LabelIndex, row => row.Predictions);
    Assert.Equal(0.863482146891263, metrics.AccuracyMacro, 6);
    Assert.Equal(0.86309523809523814, metrics.AccuracyMicro, 6);

    // Convert the predictions from ML.NET format to the native C# class.
    var nativePredictions = scored.AsDynamic
        .AsEnumerable<DatasetUtils.MulticlassClassificationExample>(mlContext, false)
        .ToList();

    // The schema of the prediction carries metadata such as the mapping from a predicted label index
    // (e.g., 1) to its actual label (e.g., "AA"); it is stored under "KeyValues" on the predicted-label column.
    var predictedColumn = nameof(DatasetUtils.MulticlassClassificationExample.PredictedLabelIndex);
    var keyValues = new VBuffer<ReadOnlyMemory<char>>();
    scored.AsDynamic.Schema[predictedColumn].Metadata.GetValue("KeyValues", ref keyValues);
    // labelNames[PredictedLabelIndex - 1] is the original label for a given PredictedLabelIndex.
    var labelNames = keyValues.DenseValues().ToList();

    // Check the class probabilities of the 3rd example. Scores and labelNames are parallel:
    // Scores[i] is the probability of being labelNames[i].
    var thirdRow = nativePredictions[2];
    var expectedScores = new float[] { 0.922597349f, 0.07508608f, 0.00221699756f, 9.95488E-05f };
    for (int slot = 0; slot < keyValues.Length; ++slot)
    {
        Assert.Equal(expectedScores[slot], thirdRow.Scores[slot], 6);
    }

    // The predicted label below should be with probability 0.922597349.
    var predictedSlot = (int)thirdRow.PredictedLabelIndex - 1;
    Console.WriteLine("Our predicted label to this example is {0} with probability {1}",
        labelNames[predictedSlot],
        thirdRow.Scores[predictedSlot]);
}
10121098
}
10131099
}

0 commit comments

Comments
 (0)