Add an example for static pipeline with in-memory data and show how to get class probabilities #1953
@@ -0,0 +1,105 @@
using Microsoft.ML.Data;
using Microsoft.ML.LightGBM.StaticPipe;
using Microsoft.ML.SamplesUtils;
using Microsoft.ML.StaticPipe;
using System;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.Samples.Static
{
    class LightGBMMulticlassWithInMemoryData
    {
        public void MultiClassLightGbmStaticPipelineWithInMemoryData()
I think most of the emphasis in 1.0 is on the dynamic API. Would this example add value as a dynamic sample? #WontFix
        {
            // Create a general context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations, and as the source of randomness.
            var mlContext = new MLContext();

            // Create in-memory examples as native C# objects.
            var examples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000);

            // Convert the native C# objects to an IDataView, a format consumable by ML.NET functions.
            var dataView = ComponentCreation.CreateDataView(mlContext, examples);

            // IDataView is the data format used in dynamically-typed pipelines. To use a statically-typed pipeline, we need
            // to convert the IDataView to a statically-typed DataView by calling AssertStatic(...). The basic idea is to
            // specify the static type of each column in the IDataView in a lambda function.
            var staticDataView = dataView.AssertStatic(mlContext, c => (
                Features: c.R4.Vector,
                Label: c.Text.Scalar));

            // Create the static pipeline. First, we make an estimator out of the static DataView as the start of the pipeline.
            // Then, we append the necessary transforms and a classifier to that starting estimator.
            var pipe = staticDataView.MakeNewEstimator()
                .Append(mapper: r => (
                    r.Label,
                    // Train multi-class LightGBM. The trained model maps Features to Label and the probability of each class.
                    // The call to ToKey() is needed to convert string labels to integer indexes.
                    Predictions: mlContext.MulticlassClassification.Trainers.LightGbm(r.Label.ToKey(), r.Features)
                ))
                .Append(r => (
                    // Actual label.
                    r.Label,
                    // Labels are converted to keys when training LightGBM, so we convert them again here for the evaluation function.
                    LabelIndex: r.Label.ToKey(),
                    // Used to compute metrics such as accuracy.
                    r.Predictions,
                    // Assign a new name to the predicted class index.
                    PredictedLabelIndex: r.Predictions.predictedLabel,
                    // Assign a new name to the class probabilities.
                    Scores: r.Predictions.score
                ));

            // Split the statically-typed data into training and test sets. Only the training set is used to fit
            // the created pipeline. Metrics are computed on the test set.
            var (trainingData, testingData) = mlContext.MulticlassClassification.TrainTestSplit(staticDataView, testFraction: 0.5);

            // Train the model.
            var model = pipe.Fit(trainingData);

            // Do prediction on the test set.
            var prediction = model.Transform(testingData);

            // Evaluate the trained model on the test set.
            var metrics = mlContext.MulticlassClassification.Evaluate(prediction, r => r.LabelIndex, r => r.Predictions);

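            // Check if metrics are reasonable. The reference values for this sample are
            // macro accuracy 0.863482146891263 and micro accuracy 0.86309523809523814.
            Console.WriteLine("Macro accuracy: {0}, Micro accuracy: {1}.", metrics.AccuracyMacro, metrics.AccuracyMicro);
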
            // Convert the predictions from ML.NET format to native C# objects.
            var nativePredictions = new List<DatasetUtils.MulticlassClassificationExample>(prediction.AsDynamic.AsEnumerable<DatasetUtils.MulticlassClassificationExample>(mlContext, false));

            // Get the schema object out of the prediction. It contains metadata such as the mapping from a predicted label index
            // (e.g., 1) to its actual label (e.g., "AA"). The call to "AsDynamic" converts our statically-typed pipeline into
            // a dynamically-typed one only for extracting metadata. In the future, metadata in a statically-typed pipeline should
            // be accessible without going through the dynamically-typed API.
            var schema = prediction.AsDynamic.Schema;

            // Retrieve the mapping from label indexes to labels.
            var labelBuffer = new VBuffer<ReadOnlyMemory<char>>();
            schema[nameof(DatasetUtils.MulticlassClassificationExample.PredictedLabelIndex)].Metadata.GetValue("KeyValues", ref labelBuffer);
            // nativeLabels is { "AA", "BB", "CC", "DD" }. nativeLabels[nativePrediction.PredictedLabelIndex - 1] is the
            // original label indexed by nativePrediction.PredictedLabelIndex.
            var nativeLabels = labelBuffer.DenseValues().ToArray();

            // Show prediction result for the 3rd example.
            var nativePrediction = nativePredictions[2];
            // Console output:
            //   Our predicted label to this example is "AA" with probability 0.922597349.
            Console.WriteLine("Our predicted label to this example is {0} with probability {1}",
Have a comment for the WriteLines showing here directly what the output would be. #Closed
                nativeLabels[(int)nativePrediction.PredictedLabelIndex - 1],
                nativePrediction.Scores[(int)nativePrediction.PredictedLabelIndex - 1]);

            var expectedProbabilities = new float[] { 0.922597349f, 0.07508608f, 0.00221699756f, 9.95488E-05f };
            // Scores and nativeLabels are two parallel attributes; that is, Scores[i] is the probability of being nativeLabels[i].
            // Console output:
            //   The probability of being class "AA" is 0.922597349.
            //   The probability of being class "BB" is 0.07508608.
            //   The probability of being class "CC" is 0.00221699756.
            //   The probability of being class "DD" is 9.95488E-05.
            for (int i = 0; i < labelBuffer.Length; ++i)
                Console.WriteLine("The probability of being class {0} is {1}.", nativeLabels[i], nativePrediction.Scores[i]);
        }
    }
}
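The index arithmetic at the end of the sample (`PredictedLabelIndex - 1`) relies on ML.NET key values being 1-based, with 0 reserved for a missing value. The following is a minimal standalone sketch of that lookup with no ML.NET dependency. The names `KeyLabelLookup`, `LabelFromKey`, and `KeyOfBestScore` are hypothetical helpers, not part of ML.NET; the labels and scores are the ones quoted in the sample's comments.

```csharp
using System;

// Hypothetical helper, not an ML.NET type: illustrates the 1-based key lookup used in the sample.
public static class KeyLabelLookup
{
    // Key k (1-based) maps to labels[k - 1]; key 0 stands for "missing" and has no label.
    public static string LabelFromKey(string[] labels, uint key)
    {
        if (key == 0 || key > labels.Length)
            throw new ArgumentOutOfRangeException(nameof(key), "Key must be between 1 and labels.Length.");
        return labels[key - 1];
    }

    // Returns the 1-based key of the highest score. Scores and labels are parallel arrays.
    public static uint KeyOfBestScore(float[] scores)
    {
        if (scores.Length == 0)
            throw new ArgumentException("Scores must not be empty.", nameof(scores));
        int best = 0;
        for (int i = 1; i < scores.Length; ++i)
            if (scores[i] > scores[best])
                best = i;
        return (uint)(best + 1);
    }

    public static void Main()
    {
        // Values copied from the comments in the sample above.
        string[] labels = { "AA", "BB", "CC", "DD" };
        float[] scores = { 0.922597349f, 0.07508608f, 0.00221699756f, 9.95488E-05f };

        uint key = KeyOfBestScore(scores);
        Console.WriteLine("Predicted key {0} maps to label {1}.", key, LabelFromKey(labels, key));
        for (int i = 0; i < labels.Length; ++i)
            Console.WriteLine("The probability of being class {0} is {1}.", labels[i], scores[i]);
    }
}
```

With these inputs the best score sits at array position 0, so the key is 1 and the label is AA, which mirrors what the sample prints for its third example.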
This needs to be referenced from some other file, otherwise it won't display in the documentation.
Maybe from the LightGBMStatics catalog, if we have one already? #Closed
I will do it following the pattern for MF. Please take a look at LightGbmStaticExtensions.cs in the next iteration.
In reply to: 244876344
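The sample reports both macro and micro accuracy. As a rough illustration of how the two differ, here is a standalone sketch using the usual definitions, which are assumed here to match what the ML.NET evaluator computes: micro accuracy is the fraction of all examples predicted correctly, and macro accuracy is the unweighted mean of the per-class accuracies. `AccuracySketch` and its methods are hypothetical names, and the tiny label lists are made up for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper, not an ML.NET type: computes accuracy from parallel label lists.
public static class AccuracySketch
{
    // Fraction of all examples whose predicted label equals the actual label.
    public static double MicroAccuracy(IReadOnlyList<string> actual, IReadOnlyList<string> predicted)
    {
        int correct = 0;
        for (int i = 0; i < actual.Count; ++i)
            if (actual[i] == predicted[i])
                ++correct;
        return (double)correct / actual.Count;
    }

    // Unweighted mean over classes of (correct examples in class) / (examples in class).
    public static double MacroAccuracy(IReadOnlyList<string> actual, IReadOnlyList<string> predicted)
    {
        var total = new Dictionary<string, int>();
        var correct = new Dictionary<string, int>();
        for (int i = 0; i < actual.Count; ++i)
        {
            total.TryGetValue(actual[i], out int t);
            total[actual[i]] = t + 1;
            correct.TryGetValue(actual[i], out int c);
            correct[actual[i]] = c + (actual[i] == predicted[i] ? 1 : 0);
        }
        return total.Keys.Average(k => (double)correct[k] / total[k]);
    }

    public static void Main()
    {
        // Made-up, imbalanced data: four "AA" examples and two "BB" examples.
        var actual = new[] { "AA", "AA", "AA", "AA", "BB", "BB" };
        var predicted = new[] { "AA", "AA", "AA", "AA", "AA", "BB" };

        Console.WriteLine("Macro accuracy: {0}, Micro accuracy: {1}.",
            MacroAccuracy(actual, predicted), MicroAccuracy(actual, predicted));
    }
}
```

On this made-up data, five of six examples are correct, so micro accuracy is 5/6, while the per-class accuracies are 1 for AA and 0.5 for BB, giving a macro accuracy of 0.75. The two numbers diverge when classes are imbalanced, so close agreement between them, as in the sample's reference values, suggests the classes are predicted about equally well.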