Skip to content

Commit 19d25e2

Browse files
committed
Scrubbing FieldAwareFactorizationMachine learner.
1 parent 4acf5aa commit 19d25e2

File tree

11 files changed

+206
-85
lines changed

11 files changed

+206
-85
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public static void Example()
1616
// This will create a sentiment.tsv file in the filesystem.
1717
// The string, dataFile, is the path to the downloaded file.
1818
// You can open this file, if you want to see the data.
19-
string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
19+
(string dataFile, _ ) = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
2020

2121
// A preview of the data.
2222
// Sentiment SentimentText

docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs

-71
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using System;
2+
using System.Linq;
3+
using Microsoft.ML.Data;
4+
namespace Microsoft.ML.Samples.Dynamic
5+
{
6+
public static class FFMBinaryClassification
7+
{
8+
public static void Example()
9+
{
10+
// Create the ML.NET MLContext object, which is needed to build the pipeline.
11+
var mlContext = new MLContext();
12+
13+
// Download and featurize the dataset.
14+
(var trainData, var testData) = SamplesUtils.DatasetUtils.LoadFeaturizedSentimentDataset(mlContext);
15+
16+
// ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to
17+
// expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially
18+
// helpful when working with iterative algorithms which need many data passes. Since the field-aware factorization machine trainer is such an algorithm, we cache. Inserting a
19+
// cache step in a pipeline is also possible, please see the construction of pipeline below.
20+
trainData = mlContext.Data.Cache(trainData);
21+
22+
// Step 2: Pipeline
23+
// Create the 'FieldAwareFactorizationMachine' binary classifier, setting the "Sentiment" column as the label of the dataset, and
24+
// the "Features" column as the features column.
25+
var pipeline = new EstimatorChain<ITransformer>().AppendCacheCheckpoint(mlContext)
26+
.Append(mlContext.BinaryClassification.Trainers.
27+
FieldAwareFactorizationMachine(labelColumnName: "Sentiment", featureColumnNames: new[] { "Features" }));
28+
29+
// Fit the model.
30+
var model = pipeline.Fit(trainData);
31+
32+
// Let's get the model parameters from the model.
33+
var modelParams = model.LastTransformer.Model;
34+
35+
// Let's inspect the model parameters.
36+
var featureCount = modelParams.GetFeatureCount();
37+
var fieldCount = modelParams.GetFieldCount();
38+
var latentDim = modelParams.GetLatentDim();
39+
var linearWeights = modelParams.GetLinearWeights();
40+
var latentWeights = modelParams.GetLatentWeights();
41+
42+
Console.WriteLine("The feature count is: " + featureCount);
43+
Console.WriteLine("The number of fields is: " + fieldCount);
44+
Console.WriteLine("The latent dimension is: " + latentDim);
45+
Console.WriteLine("The linear weights of some of the features are: " +
46+
string.Concat(Enumerable.Range(1, 10).Select(i => $"{linearWeights[i]:F4} ")));
47+
Console.WriteLine("The weights of some of the latent features are: " +
48+
string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} ")));
49+
50+
// The feature count is: 9374
51+
// The number of fields is: 1
52+
// The latent dimension is: 20
53+
// The linear weights of some of the features are: 0.0196 0.0000 -0.0045 -0.0205 0.0000 0.0032 0.0682 0.0091 -0.0151 0.0089
54+
// The weights of some of the latent features are: 0.3316 0.2140 0.0752 0.0908 -0.0495 -0.0810 0.0761 0.0966 0.0090 -0.0962
55+
56+
// Evaluate how the model is doing on the test data.
57+
var dataWithPredictions = model.Transform(testData);
58+
59+
var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, "Sentiment");
60+
SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
61+
62+
// Accuracy: 0.72
63+
// AUC: 0.75
64+
// F1 Score: 0.74
65+
// Negative Precision: 0.75
66+
// Negative Recall: 0.67
67+
// Positive Precision: 0.70
68+
// Positive Recall: 0.78
69+
}
70+
}
71+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using System;
2+
using System.Linq;
3+
using Microsoft.ML.Data;
4+
using Microsoft.ML.FactorizationMachine;
5+
6+
namespace Microsoft.ML.Samples.Dynamic
7+
{
8+
public static class FFMBinaryClassificationWithOptions
9+
{
10+
public static void Example()
11+
{
12+
// Create the ML.NET MLContext object, which is needed to build the pipeline.
13+
var mlContext = new MLContext();
14+
15+
// Download and featurize the dataset.
16+
(var trainData, var testData) = SamplesUtils.DatasetUtils.LoadFeaturizedSentimentDataset(mlContext);
17+
18+
// ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to
19+
// expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially
20+
// helpful when working with iterative algorithms which need many data passes. Since the field-aware factorization machine trainer is such an algorithm, we cache. Inserting a
21+
// cache step in a pipeline is also possible, please see the construction of pipeline below.
22+
trainData = mlContext.Data.Cache(trainData);
23+
24+
// Step 2: Pipeline
25+
// Create the 'FieldAwareFactorizationMachine' binary classifier, setting the "Sentiment" column as the label of the dataset, and
26+
// the "Features" column as the features column.
27+
var pipeline = new EstimatorChain<ITransformer>().AppendCacheCheckpoint(mlContext)
28+
.Append(mlContext.BinaryClassification.Trainers.
29+
FieldAwareFactorizationMachine(
30+
new FieldAwareFactorizationMachineTrainer.Options
31+
{
32+
FeatureColumn = "Features",
33+
LabelColumn = "Sentiment",
34+
LearningRate = 0.1f,
35+
Iters = 10
36+
}));
37+
38+
// Fit the model.
39+
var model = pipeline.Fit(trainData);
40+
41+
// Let's get the model parameters from the model.
42+
var modelParams = model.LastTransformer.Model;
43+
44+
// Let's inspect the model parameters.
45+
var featureCount = modelParams.GetFeatureCount();
46+
var fieldCount = modelParams.GetFieldCount();
47+
var latentDim = modelParams.GetLatentDim();
48+
var linearWeights = modelParams.GetLinearWeights();
49+
var latentWeights = modelParams.GetLatentWeights();
50+
51+
Console.WriteLine("The feature count is: " + featureCount);
52+
Console.WriteLine("The number of fields is: " + fieldCount);
53+
Console.WriteLine("The latent dimension is: " + latentDim);
54+
Console.WriteLine("The linear weights of some of the features are: " +
55+
string.Concat(Enumerable.Range(1, 10).Select(i => $"{linearWeights[i]:F4} ")));
56+
Console.WriteLine("The weights of some of the latent features are: " +
57+
string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} ")));
58+
59+
// The feature count is: 9374
60+
// The number of fields is: 1
61+
// The latent dimension is: 20
62+
// The linear weights of some of the features are: 0.0410 0.0000 -0.0078 -0.0285 0.0000 0.0114 0.1313 0.0183 -0.0224 0.0166
63+
// The weights of some of the latent features are: -0.0326 0.1127 0.0621 0.1446 0.2038 0.1608 0.2084 0.0141 0.2458 -0.0625
64+
65+
// Evaluate how the model is doing on the test data.
66+
var dataWithPredictions = model.Transform(testData);
67+
68+
var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, "Sentiment");
69+
SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
70+
71+
// Accuracy: 0.78
72+
// AUC: 0.81
73+
// F1 Score: 0.78
74+
// Negative Precision: 0.78
75+
// Negative Recall: 0.78
76+
// Positive Precision: 0.78
77+
// Positive Recall: 0.78
78+
}
79+
}
80+
}

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SDCALogisticRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public static void Example()
1212
// Downloading the dataset from github.com/dotnet/machinelearning.
1313
// This will create a sentiment.tsv file in the filesystem.
1414
// You can open this file, if you want to see the data.
15-
string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
15+
(var dataFile, _ ) = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
1616

1717
// A preview of the data.
1818
// Sentiment SentimentText

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/PriorTrainerSample.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public static void Example()
1010
// Downloading the dataset from github.com/dotnet/machinelearning.
1111
// This will create a sentiment.tsv file in the filesystem.
1212
// You can open this file, if you want to see the data.
13-
string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
13+
(string dataFile, _ ) = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
1414

1515
// A preview of the data.
1616
// Sentiment SentimentText

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/RandomTrainerSample.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public static void Example()
1010
// Downloading the dataset from github.com/dotnet/machinelearning.
1111
// This will create a sentiment.tsv file in the filesystem.
1212
// You can open this file, if you want to see the data.
13-
string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
13+
(string dataFile, _ ) = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
1414

1515
// A preview of the data.
1616
// Sentiment SentimentText

docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
<NativeAssemblyReference Include="CpuMathNative" />
2323
<NativeAssemblyReference Include="FastTreeNative" />
2424
<NativeAssemblyReference Include="MatrixFactorizationNative" />
25+
<NativeAssemblyReference Include="FactorizationMachineNative" />
2526
<NativeAssemblyReference Include="LdaNative" />
2627
<NativeAssemblyReference Include="SymSgdNative" />
2728
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.10.0" />

src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs

+39-5
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,48 @@ public sealed class HousingRegression
7878
/// <summary>
7979
/// Downloads the train and test files of the Wikipedia detox dataset from the ML.NET repo and returns their local paths.
8080
/// </summary>
81-
public static string DownloadSentimentDataset()
82-
=> Download("https://raw.githubusercontent.com/dotnet/machinelearning/76cb2cdf5cc8b6c88ca44b8969153836e589df04/test/data/wikipedia-detox-250-line-data.tsv", "sentiment.tsv");
81+
public static (string trainFile, string testFile) DownloadSentimentDataset()
82+
{
83+
var trainFile = Download("https://raw.githubusercontent.com/dotnet/machinelearning/76cb2cdf5cc8b6c88ca44b8969153836e589df04/test/data/wikipedia-detox-250-line-data.tsv", "sentiment.tsv");
84+
var testFile = Download("https://raw.githubusercontent.com/dotnet/machinelearning/76cb2cdf5cc8b6c88ca44b8969153836e589df04/test/data/wikipedia-detox-250-line-test.tsv", "sentimenttest.tsv");
85+
return (trainFile, testFile);
86+
}
87+
88+
/// <summary>
89+
/// Downloads the adult dataset from the ML.NET repo.
90+
/// </summary>
91+
public static string DownloadAdultDataset()
92+
=> Download("https://raw.githubusercontent.com/dotnet/machinelearning/244a8c2ac832657af282aa312d568211698790aa/test/data/adult.train", "adult.txt");
8393

8494
/// <summary>
85-
/// Downloads the adult dataset from the ML.NET repo.
95+
/// Downloads the wikipedia detox dataset and featurizes it to be suitable for sentiment classification tasks.
8696
/// </summary>
87-
public static string DownloadAdultDataset()
88-
=> Download("https://raw.githubusercontent.com/dotnet/machinelearning/244a8c2ac832657af282aa312d568211698790aa/test/data/adult.train", "adult.txt");
97+
/// <param name="mlContext"><see cref="MLContext"/> used for data loading and processing.</param>
98+
/// <returns>The featurized training and test datasets.</returns>
99+
public static (IDataView trainData, IDataView testData) LoadFeaturizedSentimentDataset(MLContext mlContext)
100+
{
101+
// Download the train and test files
102+
(string trainFile, string testFile) = DownloadSentimentDataset();
103+
104+
// Define the columns to read
105+
var reader = mlContext.Data.CreateTextLoader(
106+
columns: new[]
107+
{
108+
new TextLoader.Column("Sentiment", DataKind.BL, 0),
109+
new TextLoader.Column("SentimentText", DataKind.Text, 1)
110+
},
111+
hasHeader: true
112+
);
113+
114+
// Create data featurizing pipeline
115+
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText");
116+
117+
var data = reader.Read(trainFile);
118+
var model = pipeline.Fit(data);
119+
var featurizedDataTrain = model.Transform(data);
120+
var featurizedDataTest = model.Transform(reader.Read(testFile));
121+
return (featurizedDataTrain, featurizedDataTest);
122+
}
89123

90124
/// <summary>
91125
/// Downloads the Adult UCI dataset and featurizes it to be suitable for classification tasks.

src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs

+7-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public static class FactorizationMachineExtensions
2323
/// <example>
2424
/// <format type="text/markdown">
2525
/// <![CDATA[
26-
/// [!code-csharp[FieldAwareFactorizationMachine](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs)]
26+
/// [!code-csharp[FieldAwareFactorizationMachine](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachine.cs)]
2727
/// ]]></format>
2828
/// </example>
2929
public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
@@ -41,6 +41,12 @@ public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachi
4141
/// </summary>
4242
/// <param name="catalog">The binary classification catalog trainer object.</param>
4343
/// <param name="options">Advanced arguments to the algorithm.</param>
44+
/// <example>
45+
/// <format type="text/markdown">
46+
/// <![CDATA[
47+
/// [!code-csharp[FieldAwareFactorizationMachine](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithOptions.cs)]
48+
/// ]]></format>
49+
/// </example>
4450
public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
4551
FieldAwareFactorizationMachineTrainer.Options options)
4652
{

0 commit comments

Comments
 (0)