Skip to content

Commit 738e5d5

Browse files
authored
Add samples in TT for FFM (#3312)
1 parent 5538ccf commit 738e5d5

8 files changed

+629
-226
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,177 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.Linq;
34
using Microsoft.ML;
45
using Microsoft.ML.Data;
56

6-
namespace Samples.Dynamic.Trainers.BinaryClassification
{
    public static class FieldAwareFactorizationMachine
    {
        // This example first trains a field-aware factorization machine (FFM) for binary classification,
        // measures the trained model's quality, and finally uses the trained model to make predictions.
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            // Setting the seed to a fixed number in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            IEnumerable<DataPoint> data = GenerateRandomDataPoints(500);

            // Convert the list of data points to an IDataView object, which is consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(data);

            // Define the trainer.
            // This trainer trains a field-aware factorization machine (FFM) for binary classification.
            // See https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf for the theory behind and
            // https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf for the training
            // algorithm implemented in ML.NET.
            var pipeline = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(
                // Specify three feature columns!
                new[] { nameof(DataPoint.Field0), nameof(DataPoint.Field1), nameof(DataPoint.Field2) },
                // Specify binary label's column name.
                nameof(DataPoint.Label));

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Run the model on the training data set.
            var transformedTrainingData = model.Transform(trainingData);

            // Measure the quality of the trained model.
            var metrics = mlContext.BinaryClassification.Evaluate(transformedTrainingData);

            // Show the quality metrics.
            PrintMetrics(metrics);

            // Expected output:
            // Accuracy: 0.99
            // AUC: 1.00
            // F1 Score: 0.99
            // Negative Precision: 1.00
            // Negative Recall: 0.98
            // Positive Precision: 0.98
            // Positive Recall: 1.00
            // Log Loss: 0.17
            // Log Loss Reduction: 0.83
            // Entropy: 1.00

            // Create a prediction function from the trained model.
            var engine = mlContext.Model.CreatePredictionEngine<DataPoint, Result>(model);

            // Make some predictions.
            foreach (var dataPoint in data.Take(5))
            {
                var result = engine.Predict(dataPoint);
                Console.WriteLine($"Actual label: {dataPoint.Label}, predicted label: {result.PredictedLabel}, " +
                    $"score of being positive class: {result.Score}, and probability of being positive class: {result.Probability}.");
            }

            // Expected output:
            // Actual label: True, predicted label: True, score of being positive class: 1.115094, and probability of being positive class: 0.7530775.
            // Actual label: False, predicted label: False, score of being positive class: -3.478797, and probability of being positive class: 0.02992158.
            // Actual label: True, predicted label: True, score of being positive class: 3.191896, and probability of being positive class: 0.9605282.
            // Actual label: False, predicted label: False, score of being positive class: -3.400863, and probability of being positive class: 0.03226851.
            // Actual label: True, predicted label: True, score of being positive class: 4.06056, and probability of being positive class: 0.9830528.
        }

        // Number of features per field.
        const int featureLength = 5;

        // This class defines objects fed to the trained model.
        private class DataPoint
        {
            // Label.
            public bool Label { get; set; }

            // Features from the first field. Note that different fields can have different numbers of features.
            [VectorType(featureLength)]
            public float[] Field0 { get; set; }

            // Features from the second field.
            [VectorType(featureLength)]
            public float[] Field1 { get; set; }

            // Features from the third field.
            [VectorType(featureLength)]
            public float[] Field2 { get; set; }
        }

        // This class defines objects produced by the trained model. The trained model maps
        // a DataPoint to a Result.
        public class Result
        {
            // Label.
            public bool Label { get; set; }
            // Predicted label.
            public bool PredictedLabel { get; set; }
            // Predicted score.
            public float Score { get; set; }
            // Probability of belonging to the positive class.
            public float Probability { get; set; }
        }

        // Function used to create toy data sets.
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int exampleCount, int seed = 0)
        {
            var rnd = new Random(seed);
            var data = new List<DataPoint>();
            for (int i = 0; i < exampleCount; ++i)
            {
                // Initialize an example with a random label and empty feature vectors.
                var sample = new DataPoint()
                {
                    Label = rnd.Next() % 2 == 0,
                    Field0 = new float[featureLength],
                    Field1 = new float[featureLength],
                    Field2 = new float[featureLength]
                };

                // Fill feature vectors according to the assigned label.
                // Notice that features from different fields have different biases and therefore different distributions.
                // In practice, such as in game recommendation, one may use one field to store features from the user profile
                // and another field to store features from the game profile.
                for (int j = 0; j < featureLength; ++j)
                {
                    var value0 = (float)rnd.NextDouble();
                    // Positive class gets a larger feature value.
                    if (sample.Label)
                        value0 += 0.2f;
                    sample.Field0[j] = value0;

                    var value1 = (float)rnd.NextDouble();
                    // Positive class gets a smaller feature value.
                    if (sample.Label)
                        value1 -= 0.2f;
                    sample.Field1[j] = value1;

                    var value2 = (float)rnd.NextDouble();
                    // Positive class gets a larger feature value.
                    if (sample.Label)
                        value2 += 0.8f;
                    sample.Field2[j] = value2;
                }

                data.Add(sample);
            }
            return data;
        }

        // Function used to show evaluation metrics such as accuracy of predictions.
        private static void PrintMetrics(CalibratedBinaryClassificationMetrics metrics)
        {
            Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
            Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
            Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
            Console.WriteLine($"Negative Precision: {metrics.NegativePrecision:F2}");
            Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
            Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}");
            Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine($"Log Loss Reduction: {metrics.LogLossReduction:F2}");
            Console.WriteLine($"Entropy: {metrics.Entropy:F2}");
        }
    }
}
177+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
<#@ include file="MultipleFeatureColumnsBinaryClassification.ttinclude"#>
<#+
string ClassName="FieldAwareFactorizationMachine";
string Trainer = @"FieldAwareFactorizationMachine(
                // Specify three feature columns!
                new[] {nameof(DataPoint.Field0), nameof(DataPoint.Field1), nameof(DataPoint.Field2) },
                // Specify binary label's column name.
                nameof(DataPoint.Label) )";

// No extra using directives are needed for this trainer.
string OptionsInclude = null;

// Class-level comment emitted above Example() in the generated sample.
string Comments = @"
    // This example first trains a field-aware factorization machine for binary classification, measures the trained model's quality,
    // and finally uses the trained model to make predictions.";

// Comment emitted above the trainer definition in the generated sample.
string TrainerDescription = @"// This trainer trains a field-aware factorization machine (FFM) for binary classification. See https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf
            // for the theory behind and https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf for the training
            // algorithm implemented in ML.NET.";

// This sample does not use a trainer-options object.
string TrainerOptions = null;

// Expected console output for the per-example prediction loop.
string ExpectedOutputPerInstance = @"// Expected output:
            // Actual label: True, predicted label: True, score of being positive class: 1.115094, and probability of being positive class: 0.7530775.
            // Actual label: False, predicted label: False, score of being positive class: -3.478797, and probability of being positive class: 0.02992158.
            // Actual label: True, predicted label: True, score of being positive class: 3.191896, and probability of being positive class: 0.9605282.
            // Actual label: False, predicted label: False, score of being positive class: -3.400863, and probability of being positive class: 0.03226851.
            // Actual label: True, predicted label: True, score of being positive class: 4.06056, and probability of being positive class: 0.9830528.";

// Expected console output for the evaluation-metrics section.
string ExpectedOutput = @"// Expected output:
            // Accuracy: 0.99
            // AUC: 1.00
            // F1 Score: 0.99
            // Negative Precision: 1.00
            // Negative Recall: 0.98
            // Positive Precision: 0.98
            // Positive Recall: 1.00
            // Log Loss: 0.17
            // Log Loss Reduction: 0.83
            // Entropy: 1.00";
#>

0 commit comments

Comments
 (0)