Skip to content

Commit 6f3c4dd

Browse files
committed
Really extract labels from learned pipeline
1 parent 9d2e7f0 commit 6f3c4dd

File tree

2 files changed

+45
-45
lines changed

2 files changed

+45
-45
lines changed

docs/samples/Microsoft.ML.Samples/Static/LightGBMMulticlassWithInMemoryData.cs

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using Microsoft.ML.StaticPipe;
55
using System;
66
using System.Collections.Generic;
7+
using System.Linq;
78

89
namespace Microsoft.ML.Samples.Static
910
{
@@ -23,11 +24,9 @@ private class NativeExample
2324
[VectorType(_featureVectorLength)]
2425
public float[] Features;
2526
[ColumnName("Label")]
26-
// One of "AA", "BB", "CC", and "DD".
2727
public string Label;
2828
public uint LabelIndex;
29-
// One of "AA", "BB", "CC", and "DD".
30-
public string PredictedLabel;
29+
public uint PredictedLabelIndex;
3130
[VectorType(4)]
3231
// The probabilities of being "AA", "BB", "CC", and "DD".
3332
public float[] Scores;
@@ -68,15 +67,15 @@ private static List<NativeExample> GenerateRandomExamples(int count)
6867

6968
// The following three attributes are just placeholder for storing prediction results.
7069
example.LabelIndex = default;
71-
example.PredictedLabel = null;
70+
example.PredictedLabelIndex = default;
7271
example.Scores = new float[4];
7372

7473
examples.Add(example);
7574
}
7675
return examples;
7776
}
7877

79-
public static void MultiClassLightGbmStaticPipelineWithInMemoryData()
78+
public void MultiClassLightGbmStaticPipelineWithInMemoryData()
8079
{
8180
// Create a general context for ML.NET operations. It can be used for exception tracking and logging,
8281
// as a catalog of available operations and as the source of randomness.
@@ -112,11 +111,10 @@ public static void MultiClassLightGbmStaticPipelineWithInMemoryData()
112111
r.Label,
113112
// Labels are converted to keys when training LightGBM so we convert it here again for calling evaluation function.
114113
LabelIndex: r.Label.ToKey(),
115-
// Instance of ClassificationVectorData returned
114+
// Used to compute metrics such as accuracy.
116115
r.Predictions,
117-
// ToValue() is used to get the original out from the class indexes computed by ToKey().
118-
// For example, if label "AA" is maped to index 0 via ToKey(), then ToValue() produces "AA" from 0.
119-
PredictedLabel: r.Predictions.predictedLabel.ToValue(),
116+
// Assign a new name to predicted class index.
117+
PredictedLabelIndex: r.Predictions.predictedLabel,
120118
// Assign a new name to class probabilities.
121119
Scores: r.Predictions.score
122120
));
@@ -135,26 +133,30 @@ public static void MultiClassLightGbmStaticPipelineWithInMemoryData()
135133
var metrics = ctx.Evaluate(prediction, r => r.LabelIndex, r => r.Predictions);
136134

137135
// Check if metrics are resonable.
138-
Console.WriteLine(metrics.AccuracyMacro); // expected value: 0.863482146891263
139-
Console.WriteLine(metrics.AccuracyMicro); // expected value: 0.86309523809523814
136+
Console.WriteLine ("Macro accuracy: {0}, Micro accuracy: {1}.", 0.863482146891263, 0.86309523809523814);
140137

141138
// Convert prediction in ML.NET format to native C# class.
142139
var nativePredictions = new List<NativeExample>(prediction.AsDynamic.AsEnumerable<NativeExample>(mlContext, false));
143140

144-
// Check predicted label and class probabilities of second-first example.
145-
// If you see a label with LabelIndex 1, its means its probability is the 1st element in the Scores field.
146-
// For example, if "AA" is indexed by 1, "BB" indexed by 2, "CC" indexed by 3, and "DD" indexed by 4, Scores is
147-
// ["AA" probability, "BB" probability, "CC" probability, "DD" probability].
141+
// Get cchema object of the prediction. It contains metadata such as the mapping from predicted label index
142+
// (e.g., 1) to its actual label (e.g., "AA").
143+
var schema = prediction.AsDynamic.Schema;
144+
145+
// Retrieve the mapping from labels to label indexes.
146+
var labelBuffer = new VBuffer<ReadOnlyMemory<char>>();
147+
schema[nameof(NativeExample.PredictedLabelIndex)].Metadata.GetValue("KeyValues", ref labelBuffer);
148+
var nativeLabels = labelBuffer.DenseValues().ToList(); // nativeLabels[nativePrediction.PredictedLabelIndex-1] is the original label indexed by nativePrediction.PredictedLabelIndex.
149+
150+
// Show prediction result for the 3rd example.
148151
var nativePrediction = nativePredictions[2];
149-
var probAA = nativePrediction.Scores[0];
150-
var probBB = nativePrediction.Scores[1];
151-
var probCC = nativePrediction.Scores[2];
152-
var probDD = nativePrediction.Scores[3];
153-
154-
Console.WriteLine(probAA); // expected value: 0.922597349
155-
Console.WriteLine(probBB); // expected value: 0.07508608
156-
Console.WriteLine(probCC); // expected value: 0.00221699756
157-
Console.WriteLine(probDD); // expected value: 9.95488E-05
152+
Console.WriteLine("Our predicted label to this example is {0} with probability {1}",
153+
nativeLabels[(int)nativePrediction.PredictedLabelIndex-1],
154+
nativePrediction.Scores[(int)nativePrediction.PredictedLabelIndex-1]);
155+
156+
var expectedProbabilities = new float[] { 0.922597349f, 0.07508608f, 0.00221699756f, 9.95488E-05f };
157+
// Scores and nativeLabels are two parallel attributes; that is, Scores[i] is the probability of being nativeLabels[i].
158+
for (int i = 0; i < labelBuffer.Length; ++i)
159+
Console.WriteLine("The probability of being class {0} is {1}.", nativeLabels[i], nativePrediction.Scores[i]);
158160
}
159161
}
160162
}

test/Microsoft.ML.StaticPipelineTesting/Training.cs

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,11 +1025,9 @@ private class NativeExample
10251025
[VectorType(_featureVectorLength)]
10261026
public float[] Features;
10271027
[ColumnName("Label")]
1028-
// One of "AA", "BB", "CC", and "DD".
10291028
public string Label;
10301029
public uint LabelIndex;
1031-
// One of "AA", "BB", "CC", and "DD".
1032-
public string PredictedLabel;
1030+
public uint PredictedLabelIndex;
10331031
[VectorType(4)]
10341032
// The probabilities of being "AA", "BB", "CC", and "DD".
10351033
public float[] Scores;
@@ -1070,7 +1068,7 @@ private static List<NativeExample> GenerateRandomExamples(int count)
10701068

10711069
// The following three attributes are just placeholder for storing prediction results.
10721070
example.LabelIndex = default;
1073-
example.PredictedLabel = null;
1071+
example.PredictedLabelIndex = default;
10741072
example.Scores = new float[4];
10751073

10761074
examples.Add(example);
@@ -1115,11 +1113,10 @@ public void MultiClassLightGbmStaticPipelineWithInMemoryData()
11151113
r.Label,
11161114
// Labels are converted to keys when training LightGBM so we convert it here again for calling evaluation function.
11171115
LabelIndex: r.Label.ToKey(),
1118-
// Instance of ClassificationVectorData returned
1116+
// Used to compute metrics such as accuracy.
11191117
r.Predictions,
1120-
// ToValue() is used to get the original out from the class indexes computed by ToKey().
1121-
// For example, if label "AA" is maped to index 0 via ToKey(), then ToValue() produces "AA" from 0.
1122-
PredictedLabel: r.Predictions.predictedLabel.ToValue(),
1118+
// Assign a new name to predicted class index.
1119+
PredictedLabelIndex: r.Predictions.predictedLabel,
11231120
// Assign a new name to class probabilities.
11241121
Scores: r.Predictions.score
11251122
));
@@ -1144,20 +1141,21 @@ public void MultiClassLightGbmStaticPipelineWithInMemoryData()
11441141
// Convert prediction in ML.NET format to native C# class.
11451142
var nativePredictions = new List<NativeExample>(prediction.AsDynamic.AsEnumerable<NativeExample>(mlContext, false));
11461143

1147-
// Check predicted label and class probabilities of second-first example.
1148-
// If you see a label with LabelIndex 1, its means its probability is the 1st element in the Scores field.
1149-
// For example, if "AA" is indexed by 1, "BB" indexed by 2, "CC" indexed by 3, and "DD" indexed by 4, Scores is
1150-
// ["AA" probability, "BB" probability, "CC" probability, "DD" probability].
1144+
// Get cchema object of the prediction. It contains metadata such as the mapping from predicted label index
1145+
// (e.g., 1) to its actual label (e.g., "AA").
1146+
var schema = prediction.AsDynamic.Schema;
1147+
1148+
// Retrieve the mapping from labels to label indexes.
1149+
var labelBuffer = new VBuffer<ReadOnlyMemory<char>>();
1150+
schema[nameof(NativeExample.PredictedLabelIndex)].Metadata.GetValue("KeyValues", ref labelBuffer);
1151+
var nativeLabels = labelBuffer.DenseValues().ToList(); // nativeLabels[nativePrediction.PredictedLabelIndex-1] is the original label indexed by nativePrediction.PredictedLabelIndex.
1152+
1153+
// Show prediction result for the 3rd example.
11511154
var nativePrediction = nativePredictions[2];
1152-
var probAA = nativePrediction.Scores[0];
1153-
var probBB = nativePrediction.Scores[1];
1154-
var probCC = nativePrediction.Scores[2];
1155-
var probDD = nativePrediction.Scores[3];
1156-
1157-
Assert.Equal(0.922597349, probAA, 6);
1158-
Assert.Equal(0.07508608, probBB, 6);
1159-
Assert.Equal(0.00221699756, probCC, 6);
1160-
Assert.Equal(9.95488E-05, probDD, 6);
1155+
var expectedProbabilities = new float[] { 0.922597349f, 0.07508608f, 0.00221699756f, 9.95488E-05f };
1156+
// Scores and nativeLabels are two parallel attributes; that is, Scores[i] is the probability of being nativeLabels[i].
1157+
for (int i = 0; i < labelBuffer.Length; ++i)
1158+
Assert.Equal(expectedProbabilities[i], nativePrediction.Scores[i], 6);
11611159
}
11621160
}
11631161
}

0 commit comments

Comments
 (0)