Skip to content

Commit c871f69

Browse files
committed
review comments
1 parent 73dc2d4 commit c871f69

11 files changed

+948
-12
lines changed
Original file line numberDiff line numberDiff line change
@@ -1 +1,109 @@
1-
ErrorGeneratingOutput
1+
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    public static class LbfgsMaximumEntropy
    {
        public static void Example()
        {
            // MLContext is the entry point for ML.NET: it tracks exceptions,
            // exposes the catalog of operations, and is the source of randomness.
            // A fixed seed keeps this sample's output deterministic.
            var mlContext = new MLContext(seed: 0);

            // Build an IDataView over 1000 synthetic training examples.
            var trainingData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(1000));

            // Pipeline: map the raw labels to a key type, then train the
            // LbfgsMaximumEntropy multiclass trainer with default settings.
            var pipeline = mlContext.Transforms.Conversion
                .MapValueToKey(nameof(DataPoint.Label))
                .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy());

            // Fit the pipeline on the training data.
            var model = pipeline.Fit(trainingData);

            // Score a fresh data set; a different seed makes it distinct from
            // the training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));
            var transformedTestData = model.Transform(testData);

            // Materialize the scored rows into POCOs for inspection.
            var predictions = mlContext.Data
                .CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false)
                .ToList();

            // Show the first five predictions.
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");

            // Expected output:
            //   Label: 1, Prediction: 1
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 2
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 3

            // Evaluate and print the overall metrics on the test set.
            var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   Micro Accuracy: 0.91
            //   Macro Accuracy: 0.91
            //   Log Loss: 0.24
            //   Log Loss Reduction: 0.78
        }

        // Generates `count` random examples whose features correlate with the
        // label, so a trained model can recover the label from the features.
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var rng = new Random(seed);
            for (int i = 0; i < count; i++)
            {
                // Labels are uniformly drawn from {1, 2, 3}.
                int label = rng.Next(1, 4);

                // 20 features, each a uniform [0, 1) draw shifted upward by
                // 0.2 * label so the features carry the label signal.
                var features = new float[20];
                for (int j = 0; j < features.Length; j++)
                    features[j] = (float)rng.NextDouble() + label * 0.2f;

                yield return new DataPoint
                {
                    Label = (uint)label,
                    Features = features
                };
            }
        }

        // One example: a label plus 20 feature values. A data set is a
        // collection of such examples.
        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(20)]
            public float[] Features { get; set; }
        }

        // Shape of a scored row produced by the trained model.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Predicted label from the trainer.
            public uint PredictedLabel { get; set; }
        }

        // Pretty-print MulticlassClassificationMetrics objects.
        public static void PrintMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");
        }
    }
}
Original file line numberDiff line numberDiff line change
@@ -1 +1,119 @@
1-
ErrorGeneratingOutput
1+
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    public static class LbfgsMaximumEntropyWithOptions
    {
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            // Setting the seed to a fixed number in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new LbfgsMaximumEntropyMulticlassTrainer.Options
            {
                HistorySize = 50,
                L1Regularization = 0.1f,
                NumberOfThreads = 1
            };

            // Define the trainer.
            var pipeline =
                // Convert the string labels into key types.
                // nameof keeps the column name refactor-safe and consistent
                // with the other multiclass samples.
                mlContext.Transforms.Conversion.MapValueToKey(nameof(DataPoint.Label))
                // Apply LbfgsMaximumEntropy multiclass trainer with the defined options.
                .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(options));

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different from training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();

            // Look at 5 predictions
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");

            // Expected output:
            //   Label: 1, Prediction: 1
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 2
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 3

            // Evaluate the overall metrics
            var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   Micro Accuracy: 0.91
            //   Macro Accuracy: 0.91
            //   Log Loss: 0.22
            //   Log Loss Reduction: 0.80
        }

        // Generates `count` random examples whose features are correlated with
        // the label, so the trainer has signal to learn from.
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                // Generate Labels that are integers 1, 2 or 3
                var label = random.Next(1, 4);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    // Create random features that are correlated with the label.
                    // The feature values are slightly increased by adding a constant multiple of label.
                    Features = Enumerable.Repeat(label, 20).Select(x => randomFloat() + label * 0.2f).ToArray()
                };
            }
        }

        // Example with label and 20 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(20)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Predicted label from the trainer.
            public uint PredictedLabel { get; set; }
        }

        // Pretty-print MulticlassClassificationMetrics objects.
        public static void PrintMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");
        }
    }
}
Original file line numberDiff line numberDiff line change
@@ -1 +1,111 @@
1-
ErrorGeneratingOutput
1+
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    public static class LightGbm
    {
        // This example requires installation of additional NuGet package
        // <a href="https://www.nuget.org/packages/Microsoft.ML.LightGbm/">Microsoft.ML.LightGbm</a>.
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            // Setting the seed to a fixed number in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline =
                // Convert the string labels into key types.
                mlContext.Transforms.Conversion.MapValueToKey(nameof(DataPoint.Label))
                // Apply LightGbm multiclass trainer.
                .Append(mlContext.MulticlassClassification.Trainers.LightGbm());

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different from training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();

            // Look at 5 predictions
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");

            // Expected output:
            //   Label: 1, Prediction: 1
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 3
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 3

            // Evaluate the overall metrics
            var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   Micro Accuracy: 0.99
            //   Macro Accuracy: 0.99
            //   Log Loss: 0.05
            //   Log Loss Reduction: 0.96
        }

        // Generates `count` random examples whose features are correlated with
        // the label, so the trainer has signal to learn from.
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                // Generate Labels that are integers 1, 2 or 3
                var label = random.Next(1, 4);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    // Create random features that are correlated with the label.
                    // The feature values are slightly increased by adding a constant multiple of label.
                    Features = Enumerable.Repeat(label, 20).Select(x => randomFloat() + label * 0.2f).ToArray()
                };
            }
        }

        // Example with label and 20 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(20)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Predicted label from the trainer.
            public uint PredictedLabel { get; set; }
        }

        // Pretty-print MulticlassClassificationMetrics objects.
        public static void PrintMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");
        }
    }
}

0 commit comments

Comments
 (0)