Skip to content

Commit 73dc2d4

Browse files
committed
multiclasssamples
1 parent 738e5d5 commit 73dc2d4

25 files changed

+369
-330
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ErrorGeneratingOutput
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
<#@ include file="MulticlassClassification.ttinclude"#>
2+
<#+
3+
string ClassName = "LbfgsMaximumEntropy";
4+
string Trainer = "LbfgsMaximumEntropy";
5+
string MetaTrainer = null;
6+
string TrainerOptions = null;
7+
8+
string OptionsInclude = "";
9+
string Comments = "";
10+
bool CacheData = false;
11+
12+
string ExpectedOutputPerInstance = @"// Expected output:
13+
// Label: 1, Prediction: 1
14+
// Label: 2, Prediction: 2
15+
// Label: 3, Prediction: 2
16+
// Label: 2, Prediction: 2
17+
// Label: 3, Prediction: 3";
18+
19+
string ExpectedOutput = @"// Expected output:
20+
// Micro Accuracy: 0.91
21+
// Macro Accuracy: 0.91
22+
// Log Loss: 0.24
23+
// Log Loss Reduction: 0.78";
24+
#>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ErrorGeneratingOutput
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<#@ include file="MulticlassClassification.ttinclude"#>
2+
<#+
3+
string ClassName = "LbfgsMaximumEntropyWithOptions";
4+
string Trainer = "LbfgsMaximumEntropy";
5+
string MetaTrainer = null;
6+
string TrainerOptions = @"LbfgsMaximumEntropyMulticlassTrainer.Options
7+
{
8+
HistorySize = 50,
9+
L1Regularization = 0.1f,
10+
NumberOfThreads = 1
11+
}";
12+
13+
string OptionsInclude = "using Microsoft.ML.Trainers;";
14+
string Comments = "";
15+
16+
bool CacheData = false;
17+
18+
string ExpectedOutputPerInstance = @"// Expected output:
19+
// Label: 1, Prediction: 1
20+
// Label: 2, Prediction: 2
21+
// Label: 3, Prediction: 2
22+
// Label: 2, Prediction: 2
23+
// Label: 3, Prediction: 3";
24+
25+
string ExpectedOutput = @"// Expected output:
26+
// Micro Accuracy: 0.91
27+
// Macro Accuracy: 0.91
28+
// Log Loss: 0.22
29+
// Log Loss Reduction: 0.80";
30+
#>
Original file line numberDiff line numberDiff line change
@@ -1,86 +1 @@
1-
using System;
2-
using System.Linq;
3-
using Microsoft.ML;
4-
using Microsoft.ML.Data;
5-
using Microsoft.ML.SamplesUtils;
6-
7-
namespace Samples.Dynamic.Trainers.MulticlassClassification
8-
{
9-
public static class LightGbm
10-
{
11-
// This example requires installation of additional nuget package <a href="https://www.nuget.org/packages/Microsoft.ML.LightGbm/">Microsoft.ML.LightGbm</a>.
12-
public static void Example()
13-
{
14-
// Create a general context for ML.NET operations. It can be used for exception tracking and logging,
15-
// as a catalog of available operations and as the source of randomness.
16-
var mlContext = new MLContext();
17-
18-
// Create a list of data examples.
19-
var examples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000);
20-
21-
// Convert the examples list to an IDataView object, which is consumable by ML.NET API.
22-
var dataView = mlContext.Data.LoadFromEnumerable(examples);
23-
24-
//////////////////// Data Preview ////////////////////
25-
// Label Features
26-
// AA 0.7262433,0.8173254,0.7680227,0.5581612,0.2060332,0.5588848,0.9060271,0.4421779,0.9775497,0.2737045
27-
// BB 0.4919063,0.6673147,0.8326591,0.6695119,1.182151,0.230367,1.06237,1.195347,0.8771811,0.5145918
28-
// CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099
29-
// DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455
30-
31-
// Create a pipeline.
32-
// - Convert the string labels into key types.
33-
// - Apply LightGbm multiclass trainer.
34-
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label")
35-
.Append(mlContext.MulticlassClassification.Trainers.LightGbm(labelColumnName: "LabelIndex"))
36-
.Append(mlContext.Transforms.Conversion.MapValueToKey("PredictedLabelIndex", "PredictedLabel"))
37-
.Append(mlContext.Transforms.CopyColumns("Scores", "Score"));
38-
39-
// Split the static-typed data into training and test sets. Only training set is used in fitting
40-
// the created pipeline. Metrics are computed on the test.
41-
var split = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.5);
42-
43-
// Train the model.
44-
var model = pipeline.Fit(split.TrainSet);
45-
46-
// Do prediction on the test set.
47-
var dataWithPredictions = model.Transform(split.TestSet);
48-
49-
// Evaluate the trained model using the test set.
50-
var metrics = mlContext.MulticlassClassification.Evaluate(dataWithPredictions, labelColumnName: "LabelIndex");
51-
52-
// Check if metrics are reasonable.
53-
Console.WriteLine($"Macro accuracy: {metrics.MacroAccuracy:F4}, Micro accuracy: {metrics.MicroAccuracy:F4}.");
54-
// Console output:
55-
// Macro accuracy: 0.8655, Micro accuracy: 0.8651.
56-
57-
// IDataView with predictions, to an IEnumerable<DatasetUtils.MulticlassClassificationExample>.
58-
var nativePredictions = mlContext.Data.CreateEnumerable<DatasetUtils.MulticlassClassificationExample>(dataWithPredictions, false).ToList();
59-
60-
// Get schema object out of the prediction. It contains annotations such as the mapping from predicted label index
61-
// (e.g., 1) to its actual label (e.g., "AA").
62-
// The annotations can be used to get all the unique labels used during training.
63-
var labelBuffer = new VBuffer<ReadOnlyMemory<char>>();
64-
dataWithPredictions.Schema["PredictedLabelIndex"].GetKeyValues(ref labelBuffer);
65-
// nativeLabels is { "AA" , "BB", "CC", "DD" }
66-
var nativeLabels = labelBuffer.DenseValues().ToArray(); // nativeLabels[nativePrediction.PredictedLabelIndex - 1] is the original label indexed by nativePrediction.PredictedLabelIndex.
67-
68-
69-
// Show prediction result for the 3rd example.
70-
var nativePrediction = nativePredictions[2];
71-
// Console output:
72-
// Our predicted label to this example is "AA" with probability 0.9257.
73-
Console.WriteLine($"Our predicted label to this example is {nativeLabels[(int)nativePrediction.PredictedLabelIndex - 1]} " +
74-
$"with probability {nativePrediction.Scores[(int)nativePrediction.PredictedLabelIndex - 1]:F4}.");
75-
76-
// Scores and nativeLabels are two parallel attributes; that is, Scores[i] is the probability of being nativeLabels[i].
77-
// Console output:
78-
// The probability of being class "AA" is 0.9257.
79-
// The probability of being class "BB" is 0.0739.
80-
// The probability of being class "CC" is 0.0002.
81-
// The probability of being class "DD" is 0.0001.
82-
for (int i = 0; i < nativeLabels.Length; ++i)
83-
Console.WriteLine($"The probability of being class {nativeLabels[i]} is {nativePrediction.Scores[i]:F4}.");
84-
}
85-
}
86-
}
1+
ErrorGeneratingOutput
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
<#@ include file="MulticlassClassification.ttinclude"#>
2+
<#+
3+
string ClassName = "LightGbm";
4+
string Trainer = "LightGbm";
5+
string MetaTrainer = null;
6+
string TrainerOptions = null;
7+
8+
string OptionsInclude = "";
9+
string Comments = @"
10+
// This example requires installation of additional NuGet package
11+
// <a href=""https://www.nuget.org/packages/Microsoft.ML.FastTree/"">Microsoft.ML.FastTree</a>.";
12+
bool CacheData = false;
13+
14+
string ExpectedOutputPerInstance = @"// Expected output:
15+
// Label: 1, Prediction: 1
16+
// Label: 2, Prediction: 2
17+
// Label: 3, Prediction: 3
18+
// Label: 2, Prediction: 2
19+
// Label: 3, Prediction: 3";
20+
21+
string ExpectedOutput = @"// Expected output:
22+
// Micro Accuracy: 0.99
23+
// Macro Accuracy: 0.99
24+
// Log Loss: 0.05
25+
// Log Loss Reduction: 0.96";
26+
#>
Original file line numberDiff line numberDiff line change
@@ -1,96 +1 @@
1-
using System;
2-
using System.Linq;
3-
using Microsoft.ML;
4-
using Microsoft.ML.Data;
5-
using Microsoft.ML.SamplesUtils;
6-
using Microsoft.ML.Trainers.LightGbm;
7-
8-
namespace Samples.Dynamic.Trainers.MulticlassClassification
9-
{
10-
public static class LightGbmWithOptions
11-
{
12-
// This example requires installation of additional nuget package <a href="https://www.nuget.org/packages/Microsoft.ML.LightGbm/">Microsoft.ML.LightGbm</a>.
13-
public static void Example()
14-
{
15-
// Create a general context for ML.NET operations. It can be used for exception tracking and logging,
16-
// as a catalog of available operations and as the source of randomness.
17-
var mlContext = new MLContext(seed: 0);
18-
19-
// Create a list of data examples.
20-
var examples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000);
21-
22-
// Convert the examples list to an IDataView object, which is consumable by ML.NET API.
23-
var dataView = mlContext.Data.LoadFromEnumerable(examples);
24-
25-
//////////////////// Data Preview ////////////////////
26-
// Label Features
27-
// AA 0.7262433,0.8173254,0.7680227,0.5581612,0.2060332,0.5588848,0.9060271,0.4421779,0.9775497,0.2737045
28-
// BB 0.4919063,0.6673147,0.8326591,0.6695119,1.182151,0.230367,1.06237,1.195347,0.8771811,0.5145918
29-
// CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099
30-
// DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455
31-
32-
// Create a pipeline.
33-
// - Convert the string labels into key types.
34-
// - Apply LightGbm multiclass trainer with advanced options.
35-
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label")
36-
.Append(mlContext.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
37-
{
38-
LabelColumnName = "LabelIndex",
39-
FeatureColumnName = "Features",
40-
Booster = new DartBooster.Options()
41-
{
42-
TreeDropFraction = 0.15,
43-
XgboostDartMode = false
44-
}
45-
}))
46-
.Append(mlContext.Transforms.Conversion.MapValueToKey("PredictedLabelIndex", "PredictedLabel"))
47-
.Append(mlContext.Transforms.CopyColumns("Scores", "Score"));
48-
49-
// Split the static-typed data into training and test sets. Only training set is used in fitting
50-
// the created pipeline. Metrics are computed on the test.
51-
var split = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.5);
52-
53-
// Train the model.
54-
var model = pipeline.Fit(split.TrainSet);
55-
56-
// Do prediction on the test set.
57-
var dataWithPredictions = model.Transform(split.TestSet);
58-
59-
// Evaluate the trained model using the test set.
60-
var metrics = mlContext.MulticlassClassification.Evaluate(dataWithPredictions, labelColumnName: "LabelIndex");
61-
62-
// Check if metrics are reasonable.
63-
Console.WriteLine($"Macro accuracy: {metrics.MacroAccuracy:F4}, Micro accuracy: {metrics.MicroAccuracy:F4}.");
64-
// Console output:
65-
// Macro accuracy: 0.8619, Micro accuracy: 0.8611.
66-
67-
// IDataView with predictions, to an IEnumerable<DatasetUtils.MulticlassClassificationExample>.
68-
var nativePredictions = mlContext.Data.CreateEnumerable<DatasetUtils.MulticlassClassificationExample>(dataWithPredictions, false).ToList();
69-
70-
// Get schema object out of the prediction. It contains metadata such as the mapping from predicted label index
71-
// (e.g., 1) to its actual label (e.g., "AA").
72-
// The metadata can be used to get all the unique labels used during training.
73-
var labelBuffer = new VBuffer<ReadOnlyMemory<char>>();
74-
dataWithPredictions.Schema["PredictedLabelIndex"].GetKeyValues(ref labelBuffer);
75-
// nativeLabels is { "AA" , "BB", "CC", "DD" }
76-
var nativeLabels = labelBuffer.DenseValues().ToArray(); // nativeLabels[nativePrediction.PredictedLabelIndex - 1] is the original label indexed by nativePrediction.PredictedLabelIndex.
77-
78-
79-
// Show prediction result for the 3rd example.
80-
var nativePrediction = nativePredictions[2];
81-
// Console output:
82-
// Our predicted label to this example is AA with probability 0.8986.
83-
Console.WriteLine($"Our predicted label to this example is {nativeLabels[(int)nativePrediction.PredictedLabelIndex - 1]} " +
84-
$"with probability {nativePrediction.Scores[(int)nativePrediction.PredictedLabelIndex - 1]:F4}.");
85-
86-
// Scores and nativeLabels are two parallel attributes; that is, Scores[i] is the probability of being nativeLabels[i].
87-
// Console output:
88-
// The probability of being class AA is 0.8986.
89-
// The probability of being class BB is 0.0961.
90-
// The probability of being class CC is 0.0050.
91-
// The probability of being class DD is 0.0003.
92-
for (int i = 0; i < nativeLabels.Length; ++i)
93-
Console.WriteLine($"The probability of being class {nativeLabels[i]} is {nativePrediction.Scores[i]:F4}.");
94-
}
95-
}
96-
}
1+
ErrorGeneratingOutput
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
<#@ include file="MulticlassClassification.ttinclude"#>
2+
<#+
3+
string ClassName = "LightGbmWithOptions";
4+
string Trainer = "LightGbm";
5+
string MetaTrainer = null;
6+
string TrainerOptions = @"LightGbmMulticlassTrainer.Options
7+
{
8+
Booster = new DartBooster.Options()
9+
{
10+
TreeDropFraction = 0.15,
11+
XgboostDartMode = false
12+
}
13+
}";
14+
15+
string OptionsInclude = "using Microsoft.ML.Trainers.LightGbm;";
16+
string Comments = @"
17+
// This example requires installation of additional NuGet package
18+
// <a href=""https://www.nuget.org/packages/Microsoft.ML.FastTree/"">Microsoft.ML.FastTree</a>.";
19+
20+
bool CacheData = false;
21+
22+
string ExpectedOutputPerInstance = @"// Expected output:
23+
// Label: 1, Prediction: 1
24+
// Label: 2, Prediction: 2
25+
// Label: 3, Prediction: 3
26+
// Label: 2, Prediction: 2
27+
// Label: 3, Prediction: 3";
28+
29+
string ExpectedOutput = @"// Expected output:
30+
// Micro Accuracy: 0.98
31+
// Macro Accuracy: 0.98
32+
// Log Loss: 0.06
33+
// Log Loss Reduction: 0.80";
34+
#>

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/MulticlassClassification.ttinclude

+24-6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ namespace Samples.Dynamic.Trainers.MulticlassClassification
2323

2424
// Convert the list of data points to an IDataView object, which is consumable by ML.NET API.
2525
var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);
26+
<# if (CacheData) { #>
27+
28+
// ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times,
29+
// it can be slow due to expensive featurization and disk operations. When the considered data can fit into memory,
30+
// a solution is to cache the data in memory. Caching is especially helpful when working with iterative algorithms
31+
// which needs many data passes.
32+
trainingData = mlContext.Data.Cache(trainingData);
33+
<# } #>
2634

2735
<# if (MetaTrainer != null) { #>
2836
// Define the trainer.
@@ -43,7 +51,12 @@ namespace Samples.Dynamic.Trainers.MulticlassClassification
4351
var options = new <#=TrainerOptions#>;
4452

4553
// Define the trainer.
46-
var pipeline = mlContext.MulticlassClassification.Trainers.<#=Trainer#>(options);
54+
var pipeline =
55+
// Convert the string labels into key types.
56+
mlContext.Transforms.Conversion.MapValueToKey("Label")
57+
// Apply <#=Trainer#> multiclass trainer.
58+
.Append(mlContext.MulticlassClassification.Trainers.<#=Trainer#>(options));
59+
4760
<# } #>
4861

4962
// Train the model.
@@ -66,11 +79,7 @@ namespace Samples.Dynamic.Trainers.MulticlassClassification
6679

6780
// Evaluate the overall metrics
6881
var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
69-
Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
70-
Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
71-
Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
72-
Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");
73-
82+
PrintMetrics(metrics);
7483

7584
<#=ExpectedOutput#>
7685
}
@@ -110,5 +119,14 @@ namespace Samples.Dynamic.Trainers.MulticlassClassification
110119
// Predicted label from the trainer.
111120
public uint PredictedLabel { get; set; }
112121
}
122+
123+
// Pretty-print MulticlassClassificationMetrics objects.
124+
public static void PrintMetrics(MulticlassClassificationMetrics metrics)
125+
{
126+
Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
127+
Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
128+
Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
129+
Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");
130+
}
113131
}
114132
}

0 commit comments

Comments
 (0)