Skip to content

Scrubbing FieldAwareFactorizationMachine learner. #2730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 27, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static void Example()
Console.WriteLine("The linear weights of some of the features are: " +
string.Concat(Enumerable.Range(1, 10).Select(i => $"{linearWeights[i]:F4} ")));
Console.WriteLine("The weights of some of the latent features are: " +
string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} ")));
string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} ")));

// The feature count is: 9374
// The number of fields is: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ public static void Example()
var pipeline = new EstimatorChain<ITransformer>().AppendCacheCheckpoint(mlContext)
.Append(mlContext.BinaryClassification.Trainers.
FieldAwareFactorizationMachine(
new FieldAwareFactorizationMachineTrainer.Options
new FieldAwareFactorizationMachineBinaryClassificationTrainer.Options
{
FeatureColumn = "Features",
LabelColumn = "Sentiment",
LearningRate = 0.1f,
Iterations = 10
NumberOfIterations = 10
}));

// Fit the model.
Expand All @@ -57,7 +57,7 @@ public static void Example()
Console.WriteLine("The linear weights of some of the features are: " +
string.Concat(Enumerable.Range(1, 10).Select(i => $"{linearWeights[i]:F4} ")));
Copy link
Member

@wschin wschin Feb 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

string.Concats are not aligned. #Resolved

Console.WriteLine("The weights of some of the latent features are: " +
string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} ")));
string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} ")));

// The feature count is: 9374
// The number of fields is: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace Microsoft.ML
{
/// <summary>
/// Extension method to create <see cref="FieldAwareFactorizationMachineTrainer"/>
/// Extension method to create <see cref="FieldAwareFactorizationMachineBinaryClassificationTrainer"/>
/// </summary>
public static class FactorizationMachineExtensions
{
Expand All @@ -26,14 +26,14 @@ public static class FactorizationMachineExtensions
/// [!code-csharp[FieldAwareFactorizationMachine](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachine.cs)]
/// ]]></format>
/// </example>
public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
public static FieldAwareFactorizationMachineBinaryClassificationTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string[] featureColumnNames,
string labelColumnName = DefaultColumnNames.Label,
string exampleWeightColumnName = null)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new FieldAwareFactorizationMachineTrainer(env, featureColumnNames, labelColumnName, exampleWeightColumnName);
return new FieldAwareFactorizationMachineBinaryClassificationTrainer(env, featureColumnNames, labelColumnName, exampleWeightColumnName);
}

/// <summary>
Expand All @@ -47,12 +47,12 @@ public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachi
/// [!code-csharp[FieldAwareFactorizationMachine](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithOptions.cs)]
/// ]]></format>
/// </example>
public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
FieldAwareFactorizationMachineTrainer.Options options)
public static FieldAwareFactorizationMachineBinaryClassificationTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
FieldAwareFactorizationMachineBinaryClassificationTrainer.Options options)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new FieldAwareFactorizationMachineTrainer(env, options);
return new FieldAwareFactorizationMachineBinaryClassificationTrainer(env, options);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Trainers;

[assembly: LoadableClass(FieldAwareFactorizationMachineTrainer.Summary, typeof(FieldAwareFactorizationMachineTrainer),
typeof(FieldAwareFactorizationMachineTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }
, FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName,
FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")]
[assembly: LoadableClass(FieldAwareFactorizationMachineBinaryClassificationTrainer.Summary, typeof(FieldAwareFactorizationMachineBinaryClassificationTrainer),
typeof(FieldAwareFactorizationMachineBinaryClassificationTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }
, FieldAwareFactorizationMachineBinaryClassificationTrainer.UserName, FieldAwareFactorizationMachineBinaryClassificationTrainer.LoadName,
FieldAwareFactorizationMachineBinaryClassificationTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")]

[assembly: LoadableClass(typeof(void), typeof(FieldAwareFactorizationMachineTrainer), null, typeof(SignatureEntryPointModule), FieldAwareFactorizationMachineTrainer.LoadName)]
[assembly: LoadableClass(typeof(void), typeof(FieldAwareFactorizationMachineBinaryClassificationTrainer), null, typeof(SignatureEntryPointModule), FieldAwareFactorizationMachineBinaryClassificationTrainer.LoadName)]

namespace Microsoft.ML.FactorizationMachine
{
Expand All @@ -32,7 +32,7 @@ namespace Microsoft.ML.FactorizationMachine
[3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
*/
/// <include file='doc.xml' path='doc/members/member[@name="FieldAwareFactorizationMachineBinaryClassifier"]/*' />
public sealed class FieldAwareFactorizationMachineTrainer : ITrainer<FieldAwareFactorizationMachineModelParameters>,
public sealed class FieldAwareFactorizationMachineBinaryClassificationTrainer : ITrainer<FieldAwareFactorizationMachineModelParameters>,
IEstimator<FieldAwareFactorizationMachinePredictionTransformer>
{
internal const string Summary = "Train a field-aware factorization machine for binary classification";
Expand All @@ -52,9 +52,9 @@ public sealed class Options : LearnerInputBaseWithWeight
/// <summary>
/// Number of training iterations.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations", ShortName = "iters", SortOrder = 2)]
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations", ShortName = "iters,iter", SortOrder = 2)]
[TlcModule.SweepableLongParam(1, 100)]
public int Iterations = 5;
public int NumberOfIterations = 5;

/// <summary>
/// Latent space dimension.
Expand Down Expand Up @@ -151,12 +151,12 @@ public sealed class Options : LearnerInputBaseWithWeight
private float _radius;

/// <summary>
/// Initializes a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/> through the <see cref="Options"/> class.
/// Initializes a new instance of <see cref="FieldAwareFactorizationMachineBinaryClassificationTrainer"/> through the <see cref="Options"/> class.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="options">An instance of the legacy <see cref="Options"/> to apply advanced parameters to the algorithm.</param>
[BestFriend]
internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Options options)
internal FieldAwareFactorizationMachineBinaryClassificationTrainer(IHostEnvironment env, Options options)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(LoadName);
Expand All @@ -178,14 +178,14 @@ internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Options opt
}

/// <summary>
/// Initializes a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/>.
/// Initializes a new instance of <see cref="FieldAwareFactorizationMachineBinaryClassificationTrainer"/>.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="featureColumnNames">The name of column hosting the features. The i-th element stores feature column of the i-th field.</param>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="weightColumnName">The name of the weight column (optional).</param>
[BestFriend]
internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
internal FieldAwareFactorizationMachineBinaryClassificationTrainer(IHostEnvironment env,
string[] featureColumnNames,
string labelColumnName = DefaultColumnNames.Label,
string weightColumnName = null)
Expand Down Expand Up @@ -218,13 +218,13 @@ private void Initialize(IHostEnvironment env, Options options)
_host.CheckUserArg(options.LambdaLinear >= 0, nameof(options.LambdaLinear), "Must be non-negative");
_host.CheckUserArg(options.LambdaLatent >= 0, nameof(options.LambdaLatent), "Must be non-negative");
_host.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), "Must be positive");
_host.CheckUserArg(options.Iterations >= 0, nameof(options.Iterations), "Must be non-negative");
_host.CheckUserArg(options.NumberOfIterations >= 0, nameof(options.NumberOfIterations), "Must be non-negative");
_latentDim = options.LatentDimension;
_latentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(_latentDim);
_lambdaLinear = options.LambdaLinear;
_lambdaLatent = options.LambdaLatent;
_learningRate = options.LearningRate;
_numIterations = options.Iterations;
_numIterations = options.NumberOfIterations;
_norm = options.Normalize;
_shuffle = options.Shuffle;
_verbose = options.Verbose;
Expand Down Expand Up @@ -514,12 +514,12 @@ internal static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnviro
var host = env.Register("Train a field-aware factorization machine");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input),
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input, () => new FieldAwareFactorizationMachineBinaryClassificationTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}

/// <summary>
/// Continues the training of a <see cref="FieldAwareFactorizationMachineTrainer"/> using an already trained <paramref name="modelParameters"/> and/or validation data,
/// Continues the training of a <see cref="FieldAwareFactorizationMachineBinaryClassificationTrainer"/> using an already trained <paramref name="modelParameters"/> and/or validation data,
/// and returns a <see cref="FieldAwareFactorizationMachinePredictionTransformer"/>.
/// </summary>
public FieldAwareFactorizationMachinePredictionTransformer Fit(IDataView trainData,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,9 @@ internal void CopyLatentWeightsTo(AlignedArray latentWeights)
/// <summary>
/// The linear coefficients of the features. It's the symbol `w` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
/// </summary>
public float[] GetLinearWeights()
public IReadOnlyList<float> GetLinearWeights()
Copy link
Member

@wschin wschin Feb 27, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One line function should look like public IReadOnlyList<float> Get() => _linearWeights. #Resolved

{
var linearWeights = new float[_linearWeights.Length];
CopyLinearWeightsTo(linearWeights);
return linearWeights;
return _linearWeights;
}

/// <summary>
Expand All @@ -267,7 +265,7 @@ public float[] GetLinearWeights()
/// The k-th element in v_{j, f} is latentWeights[j * fieldCount * latentDim + f * latentDim + k].
/// The size of the returned value is featureCount x fieldCount x latentDim.
/// </summary>
public float[] GetLatentWeights()
public IReadOnlyList<float> GetLatentWeights()
{
var latentWeights = new float[FeatureCount * FieldCount * LatentDimension];
for (int j = 0; j < FeatureCount; j++)
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFacto

var rec = new CustomReconciler((env, labelCol, featureCols) =>
{
var trainer = new FieldAwareFactorizationMachineTrainer(env, featureCols, labelCol);
var trainer = new FieldAwareFactorizationMachineBinaryClassificationTrainer(env, featureCols, labelCol);

if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
Expand All @@ -66,7 +66,7 @@ public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFacto
/// <returns>The predicted output.</returns>
public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
Scalar<bool> label, Vector<float>[] features,
FieldAwareFactorizationMachineTrainer.Options options,
FieldAwareFactorizationMachineBinaryClassificationTrainer.Options options,
Action<FieldAwareFactorizationMachineModelParameters> onFit = null)
{
Contracts.CheckValue(label, nameof(label));
Expand All @@ -77,7 +77,7 @@ public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFacto

var rec = new CustomReconciler((env, labelCol, featureCols) =>
{
var trainer = new FieldAwareFactorizationMachineTrainer(env, options);
var trainer = new FieldAwareFactorizationMachineBinaryClassificationTrainer(env, options);
if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
else
Expand Down
2 changes: 1 addition & 1 deletion test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Trainers.FastTreeBinaryClassifier Uses a logit-boost boosted tree learner to per
Trainers.FastTreeRanker Trains gradient boosted decision trees to the LambdaRank quasi-gradient. Microsoft.ML.Trainers.FastTree.FastTree TrainRanking Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput
Trainers.FastTreeRegressor Trains gradient boosted decision trees to fit target values using least-squares. Microsoft.ML.Trainers.FastTree.FastTree TrainRegression Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
Trainers.FastTreeTweedieRegressor Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression. Microsoft.ML.Trainers.FastTree.FastTree TrainTweedieRegression Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware factorization machine for binary classification Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer TrainBinary Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware factorization machine for binary classification Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineBinaryClassificationTrainer TrainBinary Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineBinaryClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.GeneralizedAdditiveModelBinaryClassifier Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainBinary Microsoft.ML.Trainers.FastTree.BinaryClassificationGamTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.GeneralizedAdditiveModelRegressor Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainRegression Microsoft.ML.Trainers.FastTree.RegressionGamTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
Trainers.KMeansPlusPlusClusterer K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. Microsoft.ML.Trainers.KMeans.KMeansPlusPlusTrainer TrainKMeans Microsoft.ML.Trainers.KMeans.KMeansPlusPlusTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+ClusteringOutput
Expand Down
5 changes: 3 additions & 2 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -10119,11 +10119,12 @@
"IsNullable": false
},
{
"Name": "Iterations",
"Name": "NumberOfIterations",
"Type": "Int",
"Desc": "Number of training iterations",
"Aliases": [
"iters"
"iters",
"iter"
],
"Required": false,
"SortOrder": 2.0,
Expand Down
6 changes: 3 additions & 3 deletions test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public void FfmBinaryClassificationWithAdvancedArguments()
var data = DatasetUtils.GenerateFfmSamples(500);
var dataView = mlContext.Data.ReadFromEnumerable(data);

var ffmArgs = new FieldAwareFactorizationMachineTrainer.Options();
var ffmArgs = new FieldAwareFactorizationMachineBinaryClassificationTrainer.Options();

// Customized the field names.
ffmArgs.FeatureColumn = nameof(DatasetUtils.FfmExample.Field0); // First field.
Expand All @@ -44,11 +44,11 @@ public void FieldAwareFactorizationMachine_Estimator()
var data = new TextLoader(Env, GetFafmBCLoaderArgs())
.Read(GetDataPath(TestDatasets.breastCancer.trainFilename));

var ffmArgs = new FieldAwareFactorizationMachineTrainer.Options {
var ffmArgs = new FieldAwareFactorizationMachineBinaryClassificationTrainer.Options {
FeatureColumn = "Feature1", // Features from the 1st field.
ExtraFeatureColumns = new[] { "Feature2", "Feature3", "Feature4" }, // 2nd field's feature column, 3rd field's feature column, 4th field's feature column.
Shuffle = false,
Iterations = 3,
NumberOfIterations = 3,
LatentDimension = 7,
};

Expand Down