1
+ using System;
2
+ using System.Collections.Generic;
3
+ using System.Linq;
4
+ using Microsoft.ML;
5
+ using Microsoft.ML.Data;
6
+ using Microsoft.ML.SamplesUtils;
7
+ <# if (TrainerOptions != null) { #>
8
+ <#=OptionsInclude#>
9
+ <# } #>
10
+
11
+ namespace Samples.Dynamic.Trainers.MulticlassClassification
12
+ {
13
+ public static class <#=ClassName#>
14
+ {<#=Comments#>
15
+ public static void Example()
16
+ {
17
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
18
+ // as a catalog of available operations and as the source of randomness.
19
+ // Setting the seed to a fixed number in this example to make outputs deterministic.
20
+ var mlContext = new MLContext(seed: 0);
21
+
22
+ // Create a list of training data points.
23
+ var dataPoints = GenerateRandomDataPoints(1000);
24
+
25
+ // Convert the list of data points to an IDataView object, which is consumable by ML.NET API.
26
+ var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);
27
+
28
+ <# if (MetaTrainer != null) { #>
29
+ // Define the trainer.
30
+ var pipeline =
31
+ // Convert the string labels into key types.
32
+ mlContext.Transforms.Conversion.MapValueToKey("Label")
33
+ // Apply <#=MetaTrainer#> multiclass meta trainer on top of binary trainer.
34
+ .Append(mlContext.MulticlassClassification.Trainers.<#=MetaTrainer#>(<#=Trainer#>()));
35
+ <# } else if (TrainerOptions == null) { #>
36
+ // Define the trainer.
37
+ var pipeline =
38
+ // Convert the string labels into key types.
39
+ mlContext.Transforms.Conversion.MapValueToKey("Label")
40
+ // Apply <#=Trainer#> multiclass trainer.
41
+ .Append(mlContext.MulticlassClassification.Trainers.<#=Trainer#>());
42
+ <# } else { #>
43
+ // Define trainer options.
44
+ var options = new <#=TrainerOptions#>;
45
+
46
+ // Define the trainer.
47
+ var pipeline = mlContext.MulticlassClassification.Trainers.<#=Trainer#>(options);
48
+ <# } #>
49
+
50
+ // Train the model.
51
+ var model = pipeline.Fit(trainingData);
52
+
53
+ // Create testing data. Use different random seed to make it different from training data.
54
+ var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123));
55
+
56
+ // Run the model on test data set.
57
+ var transformedTestData = model.Transform(testData);
58
+
59
+ // Convert IDataView object to a list.
60
+ var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();
61
+
62
+ // Look at 5 predictions
63
+ foreach (var p in predictions.Take(5))
64
+ Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");
65
+
66
+ <#=ExpectedOutputPerInstance#>
67
+
68
+ // Evaluate the overall metrics
69
+ var metrics = mlContext.MulticlassClassification.Evaluate(transformedTestData);
70
+ ConsoleUtils.PrintMetrics(metrics);
71
+
72
+ <#=ExpectedOutput#>
73
+ }
74
+
75
+ private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed=0)
76
+ {
77
+ var random = new Random(seed);
78
+ float randomFloat() => (float)random.NextDouble();
79
+ for (int i = 0; i < count; i++)
80
+ {
81
+ // Generate Labels that are integers 1, 2 or 3
82
+ var label = random.Next(1, 4);
83
+ yield return new DataPoint
84
+ {
85
+ Label = (uint)label,
86
+ // Create random features that are correlated with the label.
87
+ // The feature values are slightly increased by adding a constant multiple of label.
88
+ Features = Enumerable.Repeat(label, 20).Select(x => randomFloat() + label * 0.2f).ToArray()
89
+ };
90
+ }
91
+ }
92
+
93
+ // Example with label and 20 feature values. A data set is a collection of such examples.
94
+ private class DataPoint
95
+ {
96
+ public uint Label { get; set; }
97
+ [VectorType(20)]
98
+ public float[] Features { get; set; }
99
+ }
100
+
101
+ // Class used to capture predictions.
102
+ private class Prediction
103
+ {
104
+ // Original label.
105
+ public uint Label { get; set; }
106
+ // Predicted label from the trainer.
107
+ public uint PredictedLabel { get; set; }
108
+ }
109
+ }
110
+ }
0 commit comments