|
1 | 1 | using System;
|
| 2 | +using System.Collections.Generic; |
2 | 3 | using System.Linq;
|
3 | 4 | using Microsoft.ML;
|
4 | 5 | using Microsoft.ML.Data;
|
5 | 6 |
|
6 |
| -namespace Samples.Dynamic |
| 7 | +namespace Samples.Dynamic.Trainers.BinaryClassification |
7 | 8 | {
|
8 |
| - public static class FFMBinaryClassification |
| 9 | + public static class FieldAwareFactorizationMachine |
9 | 10 | {
|
| 11 | + // This example first train a field-aware factorization to binary classification, measure the trained model's quality, and finally |
| 12 | + // use the trained model to make prediction. |
10 | 13 | public static void Example()
|
11 | 14 | {
|
12 | 15 | // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
|
13 | 16 | // as a catalog of available operations and as the source of randomness.
|
14 |
| - var mlContext = new MLContext(); |
15 |
| - |
16 |
| - // Download and featurize the dataset. |
17 |
| - var dataviews = Microsoft.ML.SamplesUtils.DatasetUtils.LoadFeaturizedSentimentDataset(mlContext); |
18 |
| - var trainData = dataviews[0]; |
19 |
| - var testData = dataviews[1]; |
20 |
| - |
21 |
| - // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to |
22 |
| - // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially |
23 |
| - // helpful when working with iterative algorithms which needs many data passes. Since SDCA is the case, we cache. Inserting a |
24 |
| - // cache step in a pipeline is also possible, please see the construction of pipeline below. |
25 |
| - trainData = mlContext.Data.Cache(trainData); |
26 |
| - |
27 |
| - // Step 2: Pipeline |
28 |
| - // Create the 'FieldAwareFactorizationMachine' binary classifier, setting the "Sentiment" column as the label of the dataset, and |
29 |
| - // the "Features" column as the features column. |
30 |
| - var pipeline = new EstimatorChain<ITransformer>().AppendCacheCheckpoint(mlContext) |
31 |
| - .Append(mlContext.BinaryClassification.Trainers. |
32 |
| - FieldAwareFactorizationMachine(labelColumnName: "Sentiment", featureColumnNames: new[] { "Features" })); |
33 |
| - |
34 |
| - // Fit the model. |
35 |
| - var model = pipeline.Fit(trainData); |
36 |
| - |
37 |
| - // Let's get the model parameters from the model. |
38 |
| - var modelParams = model.LastTransformer.Model; |
39 |
| - |
40 |
| - // Let's inspect the model parameters. |
41 |
| - var featureCount = modelParams.FeatureCount; |
42 |
| - var fieldCount = modelParams.FieldCount; |
43 |
| - var latentDim = modelParams.LatentDimension; |
44 |
| - var linearWeights = modelParams.GetLinearWeights(); |
45 |
| - var latentWeights = modelParams.GetLatentWeights(); |
46 |
| - |
47 |
| - Console.WriteLine("The feature count is: " + featureCount); |
48 |
| - Console.WriteLine("The number of fields is: " + fieldCount); |
49 |
| - Console.WriteLine("The latent dimension is: " + latentDim); |
50 |
| - Console.WriteLine("The linear weights of some of the features are: " + |
51 |
| - string.Concat(Enumerable.Range(1, 10).Select(i => $"{linearWeights[i]:F4} "))); |
52 |
| - Console.WriteLine("The weights of some of the latent features are: " + |
53 |
| - string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} "))); |
54 |
| - |
55 |
| - // The feature count is: 9374 |
56 |
| - // The number of fields is: 1 |
57 |
| - // The latent dimension is: 20 |
58 |
| - // The linear weights of some of the features are: 0.0196 0.0000 -0.0045 -0.0205 0.0000 0.0032 0.0682 0.0091 -0.0151 0.0089 |
59 |
| - // The weights of some of the latent features are: 0.3316 0.2140 0.0752 0.0908 -0.0495 -0.0810 0.0761 0.0966 0.0090 -0.0962 |
60 |
| - |
61 |
| - // Evaluate how the model is doing on the test data. |
62 |
| - var dataWithPredictions = model.Transform(testData); |
63 |
| - |
64 |
| - var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, "Sentiment"); |
65 |
| - Microsoft.ML.SamplesUtils.ConsoleUtils.PrintMetrics(metrics); |
66 |
| - |
67 |
| - // Accuracy: 0.72 |
68 |
| - // AUC: 0.75 |
69 |
| - // F1 Score: 0.74 |
70 |
| - // Negative Precision: 0.75 |
71 |
| - // Negative Recall: 0.67 |
72 |
| - // Positive Precision: 0.70 |
73 |
| - // Positive Recall: 0.78 |
| 17 | + // Setting the seed to a fixed number in this example to make outputs deterministic. |
| 18 | + var mlContext = new MLContext(seed: 0); |
| 19 | + |
| 20 | + // Create a list of training data points. |
| 21 | + IEnumerable<DataPoint> data = GenerateRandomDataPoints(500); |
| 22 | + |
| 23 | + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. |
| 24 | + var trainingData = mlContext.Data.LoadFromEnumerable(data); |
| 25 | + |
| 26 | + // Define the trainer. |
| 27 | + // This trainer trains field-aware factorization (FFM) for binary classification. See https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf |
| 28 | + // for the theory behind and https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf for the training |
| 29 | + // algorithm implemented in ML.NET. |
| 30 | + var pipeline = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine( |
| 31 | + // Specify three feature columns! |
| 32 | + new[] {nameof(DataPoint.Field0), nameof(DataPoint.Field1), nameof(DataPoint.Field2) }, |
| 33 | + // Specify binary label's column name. |
| 34 | + nameof(DataPoint.Label) ); |
| 35 | + |
| 36 | + // Train the model. |
| 37 | + var model = pipeline.Fit(trainingData); |
| 38 | + |
| 39 | + // Run the model on training data set. |
| 40 | + var transformedTrainingData = model.Transform(trainingData); |
| 41 | + |
| 42 | + // Measure the quality of the trained model. |
| 43 | + var metrics = mlContext.BinaryClassification.Evaluate(transformedTrainingData); |
| 44 | + |
| 45 | + // Show the quality metrics. |
| 46 | + PrintMetrics(metrics); |
| 47 | + |
| 48 | + // Expected output: |
| 49 | + // Accuracy: 0.99 |
| 50 | + // AUC: 1.00 |
| 51 | + // F1 Score: 0.99 |
| 52 | + // Negative Precision: 1.00 |
| 53 | + // Negative Recall: 0.98 |
| 54 | + // Positive Precision: 0.98 |
| 55 | + // Positive Recall: 1.00 |
| 56 | + // Log Loss: 0.17 |
| 57 | + // Log Loss Reduction: 0.83 |
| 58 | + // Entropy: 1.00 |
| 59 | + |
| 60 | + // Create prediction function from the trained model. |
| 61 | + var engine = mlContext.Model.CreatePredictionEngine<DataPoint, Result>(model); |
| 62 | + |
| 63 | + // Make some predictions. |
| 64 | + foreach(var dataPoint in data.Take(5)) |
| 65 | + { |
| 66 | + var result = engine.Predict(dataPoint); |
| 67 | + Console.WriteLine($"Actual label: {dataPoint.Label}, predicted label: {result.PredictedLabel}, " + |
| 68 | + $"score of being positive class: {result.Score}, and probability of beling positive class: {result.Probability}."); |
| 69 | + } |
| 70 | + |
| 71 | + // Expected output: |
| 72 | + // Actual label: True, predicted label: True, score of being positive class: 1.115094, and probability of beling positive class: 0.7530775. |
| 73 | + // Actual label: False, predicted label: False, score of being positive class: -3.478797, and probability of beling positive class: 0.02992158. |
| 74 | + // Actual label: True, predicted label: True, score of being positive class: 3.191896, and probability of beling positive class: 0.9605282. |
| 75 | + // Actual label: False, predicted label: False, score of being positive class: -3.400863, and probability of beling positive class: 0.03226851. |
| 76 | + // Actual label: True, predicted label: True, score of being positive class: 4.06056, and probability of beling positive class: 0.9830528. |
| 77 | + } |
| 78 | + |
| 79 | + // Number of features per field. |
| 80 | + const int featureLength = 5; |
| 81 | + |
| 82 | + // This class defines objects fed to the trained model. |
| 83 | + private class DataPoint |
| 84 | + { |
| 85 | + // Label. |
| 86 | + public bool Label { get; set; } |
| 87 | + |
| 88 | + // Features from the first field. Note that different fields can have different numbers of features. |
| 89 | + [VectorType(featureLength)] |
| 90 | + public float[] Field0 { get; set; } |
| 91 | + |
| 92 | + // Features from the second field. |
| 93 | + [VectorType(featureLength)] |
| 94 | + public float[] Field1 { get; set; } |
| 95 | + |
| 96 | + // Features from the thrid field. |
| 97 | + [VectorType(featureLength)] |
| 98 | + public float[] Field2 { get; set; } |
| 99 | + } |
| 100 | + |
| 101 | + // This class defines objects produced by trained model. The trained model maps |
| 102 | + // a DataPoint to a Result. |
| 103 | + public class Result |
| 104 | + { |
| 105 | + // Label. |
| 106 | + public bool Label { get; set; } |
| 107 | + // Predicted label. |
| 108 | + public bool PredictedLabel { get; set; } |
| 109 | + // Predicted score. |
| 110 | + public float Score { get; set; } |
| 111 | + // Probability of belonging to positive class. |
| 112 | + public float Probability { get; set; } |
| 113 | + } |
| 114 | + |
| 115 | + // Function used to create toy data sets. |
| 116 | + private static IEnumerable<DataPoint> GenerateRandomDataPoints(int exampleCount, int seed = 0) |
| 117 | + { |
| 118 | + var rnd = new Random(seed); |
| 119 | + var data = new List<DataPoint>(); |
| 120 | + for (int i = 0; i < exampleCount; ++i) |
| 121 | + { |
| 122 | + // Initialize an example with a random label and an empty feature vector. |
| 123 | + var sample = new DataPoint() |
| 124 | + { |
| 125 | + Label = rnd.Next() % 2 == 0, |
| 126 | + Field0 = new float[featureLength], |
| 127 | + Field1 = new float[featureLength], |
| 128 | + Field2 = new float[featureLength] |
| 129 | + }; |
| 130 | + |
| 131 | + // Fill feature vectors according the assigned label. |
| 132 | + // Notice that features from different fields have different biases and therefore different distributions. |
| 133 | + // In practices such as game recommendation, one may use one field to store features from user profile and |
| 134 | + // another field to store features from game profile. |
| 135 | + for (int j = 0; j < featureLength; ++j) |
| 136 | + { |
| 137 | + var value0 = (float)rnd.NextDouble(); |
| 138 | + // Positive class gets larger feature value. |
| 139 | + if (sample.Label) |
| 140 | + value0 += 0.2f; |
| 141 | + sample.Field0[j] = value0; |
| 142 | + |
| 143 | + var value1 = (float)rnd.NextDouble(); |
| 144 | + // Positive class gets smaller feature value. |
| 145 | + if (sample.Label) |
| 146 | + value1 -= 0.2f; |
| 147 | + sample.Field1[j] = value1; |
| 148 | + |
| 149 | + var value2 = (float)rnd.NextDouble(); |
| 150 | + // Positive class gets larger feature value. |
| 151 | + if (sample.Label) |
| 152 | + value2 += 0.8f; |
| 153 | + sample.Field2[j] = value2; |
| 154 | + } |
| 155 | + |
| 156 | + data.Add(sample); |
| 157 | + } |
| 158 | + return data; |
| 159 | + } |
| 160 | + |
| 161 | + // Function used to show evaluation metrics such as accuracy of predictions. |
| 162 | + private static void PrintMetrics(CalibratedBinaryClassificationMetrics metrics) |
| 163 | + { |
| 164 | + Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}"); |
| 165 | + Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}"); |
| 166 | + Console.WriteLine($"F1 Score: {metrics.F1Score:F2}"); |
| 167 | + Console.WriteLine($"Negative Precision: {metrics.NegativePrecision:F2}"); |
| 168 | + Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}"); |
| 169 | + Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}"); |
| 170 | + Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}"); |
| 171 | + Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}"); |
| 172 | + Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}"); |
| 173 | + Console.WriteLine($"Entropy: {metrics.Entropy:F2}"); |
74 | 174 | }
|
75 | 175 | }
|
76 | 176 | }
|
| 177 | + |
0 commit comments