diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs
index c9e0954a2e..7d085fc965 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs
@@ -20,19 +20,33 @@ public static class FactorizationMachineExtensions
/// The features, or independent variables.
/// The label, or dependent variable.
/// The optional example weights.
- /// A delegate to set more settings.
- /// The settings here will override the ones provided in the direct method signature,
- /// if both are present and have different values.
- /// The columns names, however need to be provided directly, not through the .
+ ///
+ ///
+ ///
+ ///
public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string[] featureColumns,
string labelColumn = DefaultColumnNames.Label,
- string weights = null,
- Action advancedSettings = null)
+ string weights = null)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
- return new FieldAwareFactorizationMachineTrainer(env, featureColumns, labelColumn, weights, advancedSettings: advancedSettings);
+ return new FieldAwareFactorizationMachineTrainer(env, featureColumns, labelColumn, weights);
+ }
+
+ ///
+ /// Predict a target using a field-aware factorization machine algorithm.
+ ///
+ /// The binary classification catalog trainer object.
+ /// Advanced arguments to the algorithm.
+ public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
+ FieldAwareFactorizationMachineTrainer.Options options)
+ {
+ Contracts.CheckValue(catalog, nameof(catalog));
+ var env = CatalogUtils.GetEnvironment(catalog);
+ return new FieldAwareFactorizationMachineTrainer(env, options);
}
}
}
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index a1655114df..270efd291c 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -16,7 +16,7 @@
using Microsoft.ML.Training;
[assembly: LoadableClass(FieldAwareFactorizationMachineTrainer.Summary, typeof(FieldAwareFactorizationMachineTrainer),
- typeof(FieldAwareFactorizationMachineTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }
+ typeof(FieldAwareFactorizationMachineTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }
, FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName,
FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")]
@@ -40,7 +40,7 @@ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase
/// The feature column that the trainer expects.
///
- public readonly SchemaShape.Column[] FeatureColumns;
+ internal readonly SchemaShape.Column[] FeatureColumns;
///
/// The label column that the trainer expects. Can be null, which indicates that label
/// is not used for training.
///
- public readonly SchemaShape.Column LabelColumn;
+ internal readonly SchemaShape.Column LabelColumn;
///
/// The weight column that the trainer expects. Can be null, which indicates that weight is
/// not used for training.
///
- public readonly SchemaShape.Column WeightColumn;
+ internal readonly SchemaShape.Column WeightColumn;
///
/// The containing at least the training data for this trainer.
@@ -121,48 +121,46 @@ public sealed class Arguments : LearnerInputBaseWithWeight
private float _radius;
///
- /// Legacy constructor initializing a new instance of through the legacy
- /// class.
+ /// Initializes a new instance of through the class.
///
/// The private instance of .
- /// An instance of the legacy to apply advanced parameters to the algorithm.
- public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args)
+ /// An instance of the legacy to apply advanced parameters to the algorithm.
+ [BestFriend]
+ internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Options options)
: base(env, LoadName)
{
- Initialize(env, args);
+ Initialize(env, options);
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
- var extraColumnLength = (args.ExtraFeatureColumns != null ? args.ExtraFeatureColumns.Length : 0);
+ var extraColumnLength = (options.ExtraFeatureColumns != null ? options.ExtraFeatureColumns.Length : 0);
// There can be multiple feature columns in FFM, jointly specified by args.FeatureColumn and args.ExtraFeatureColumns.
FeatureColumns = new SchemaShape.Column[1 + extraColumnLength];
// Treat the default feature column as the 1st field.
- FeatureColumns[0] = new SchemaShape.Column(args.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+ FeatureColumns[0] = new SchemaShape.Column(options.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
// Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns.
for (int i = 0; i < extraColumnLength; i++)
- FeatureColumns[i + 1] = new SchemaShape.Column(args.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+ FeatureColumns[i + 1] = new SchemaShape.Column(options.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
- LabelColumn = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
- WeightColumn = args.WeightColumn.IsExplicit ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
+ LabelColumn = new SchemaShape.Column(options.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
+ WeightColumn = options.WeightColumn.IsExplicit ? new SchemaShape.Column(options.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
}
///
- /// Initializing a new instance of .
+ /// Initializes a new instance of .
///
/// The private instance of .
/// The name of column hosting the features. The i-th element stores feature column of the i-th field.
/// The name of the label column.
- /// A delegate to apply all the advanced arguments to the algorithm.
/// The name of the optional weights' column.
- public FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
+ [BestFriend]
+ internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
string[] featureColumns,
string labelColumn = DefaultColumnNames.Label,
- string weights = null,
- Action advancedSettings = null)
+ string weights = null)
: base(env, LoadName)
{
- var args = new Arguments();
- advancedSettings?.Invoke(args);
+ var args = new Options();
Initialize(env, args);
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
@@ -181,24 +179,24 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
/// REVIEW: Once the legacy constructor goes away, this can move to the only constructor and most of the fields can be back to readonly.
///
///
- ///
- private void Initialize(IHostEnvironment env, Arguments args)
+ ///
+ private void Initialize(IHostEnvironment env, Options options)
{
- Host.CheckUserArg(args.LatentDim > 0, nameof(args.LatentDim), "Must be positive");
- Host.CheckUserArg(args.LambdaLinear >= 0, nameof(args.LambdaLinear), "Must be non-negative");
- Host.CheckUserArg(args.LambdaLatent >= 0, nameof(args.LambdaLatent), "Must be non-negative");
- Host.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), "Must be positive");
- Host.CheckUserArg(args.Iters >= 0, nameof(args.Iters), "Must be non-negative");
- _latentDim = args.LatentDim;
+ Host.CheckUserArg(options.LatentDim > 0, nameof(options.LatentDim), "Must be positive");
+ Host.CheckUserArg(options.LambdaLinear >= 0, nameof(options.LambdaLinear), "Must be non-negative");
+ Host.CheckUserArg(options.LambdaLatent >= 0, nameof(options.LambdaLatent), "Must be non-negative");
+ Host.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), "Must be positive");
+ Host.CheckUserArg(options.Iters >= 0, nameof(options.Iters), "Must be non-negative");
+ _latentDim = options.LatentDim;
_latentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(_latentDim);
- _lambdaLinear = args.LambdaLinear;
- _lambdaLatent = args.LambdaLatent;
- _learningRate = args.LearningRate;
- _numIterations = args.Iters;
- _norm = args.Norm;
- _shuffle = args.Shuffle;
- _verbose = args.Verbose;
- _radius = args.Radius;
+ _lambdaLinear = options.LambdaLinear;
+ _lambdaLatent = options.LambdaLatent;
+ _learningRate = options.LearningRate;
+ _numIterations = options.Iters;
+ _norm = options.Norm;
+ _shuffle = options.Shuffle;
+ _verbose = options.Verbose;
+ _radius = options.Radius;
}
private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachineModelParameters predictor, out float[] linearWeights,
@@ -476,13 +474,13 @@ private protected override FieldAwareFactorizationMachineModelParameters Train(T
ShortName = ShortName,
XmlInclude = new[] { @"",
@"" })]
- public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input)
+ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("Train a field-aware factorization machine");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
- return LearnerEntryPointsUtils.Train(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input),
+ return LearnerEntryPointsUtils.Train(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}
diff --git a/src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs b/src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs
index 3bcd8e5c98..b85c3703ce 100644
--- a/src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs
+++ b/src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs
@@ -24,46 +24,61 @@ public static class FactorizationMachineExtensions
/// The binary classifier catalog trainer object.
/// The label, or dependent variable.
/// The features, or independent variables.
- /// Initial learning rate.
- /// Number of training iterations.
- /// Latent space dimensions.
- /// A delegate to set more settings.
- /// The settings here will override the ones provided in the direct method signature,
- /// if both are present and have different values.
- /// The columns names, however need to be provided directly, not through the ./// A delegate that is called every time the
+ /// A delegate that is called every time the
/// method is called on the
- /// instance created out of this. This delegate will receive
- /// the model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to
+ /// instance created out of this.
+ /// This delegate will receive the model that was trained. The type of the model is .
+ /// Note that this action cannot change the result in any way; it is only a way for the caller to be informed about what was learnt.
+ /// The predicted output.
+ public static (Scalar score, Scalar predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
+ Scalar label, Vector[] features,
+ Action onFit = null)
+ {
+ Contracts.CheckValue(label, nameof(label));
+ Contracts.CheckNonEmpty(features, nameof(features));
+
+ Contracts.CheckValueOrNull(onFit);
+
+ var rec = new CustomReconciler((env, labelCol, featureCols) =>
+ {
+ var trainer = new FieldAwareFactorizationMachineTrainer(env, featureCols, labelCol);
+
+ if (onFit != null)
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+ else
+ return trainer;
+ }, label, features);
+ return rec.Output;
+ }
+
+ ///
+ /// Predict a target using a field-aware factorization machine.
+ ///
+ /// The binary classifier catalog trainer object.
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// Advanced arguments to the algorithm.
+ /// A delegate that is called every time the
+ /// method is called on the
+ /// instance created out of this.
+ /// This delegate will receive the model that was trained. The type of the model is .
+ /// Note that this action cannot change the result in any way; it is only a way for the caller to
/// be informed about what was learnt.
/// The predicted output.
public static (Scalar score, Scalar predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
Scalar label, Vector[] features,
- float learningRate = 0.1f,
- int numIterations = 5,
- int numLatentDimensions = 20,
- Action advancedSettings = null,
+ FieldAwareFactorizationMachineTrainer.Options options,
Action onFit = null)
{
Contracts.CheckValue(label, nameof(label));
Contracts.CheckNonEmpty(features, nameof(features));
- Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive");
- Contracts.CheckParam(numIterations > 0, nameof(numIterations), "Must be positive");
- Contracts.CheckParam(numLatentDimensions > 0, nameof(numLatentDimensions), "Must be positive");
- Contracts.CheckValueOrNull(advancedSettings);
+ Contracts.CheckValueOrNull(options);
Contracts.CheckValueOrNull(onFit);
var rec = new CustomReconciler((env, labelCol, featureCols) =>
{
- var trainer = new FieldAwareFactorizationMachineTrainer(env, featureCols, labelCol, advancedSettings:
- args =>
- {
- args.LearningRate = learningRate;
- args.Iters = numIterations;
- args.LatentDim = numLatentDimensions;
-
- advancedSettings?.Invoke(args);
- });
+ var trainer = new FieldAwareFactorizationMachineTrainer(env, options);
if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
else
diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
index aecd8cdd3a..cbb12ae3ef 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
+++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
@@ -49,7 +49,7 @@ Trainers.FastTreeBinaryClassifier Uses a logit-boost boosted tree learner to per
Trainers.FastTreeRanker Trains gradient boosted decision trees to the LambdaRank quasi-gradient. Microsoft.ML.Trainers.FastTree.FastTree TrainRanking Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput
Trainers.FastTreeRegressor Trains gradient boosted decision trees to fit target values using least-squares. Microsoft.ML.Trainers.FastTree.FastTree TrainRegression Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
Trainers.FastTreeTweedieRegressor Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression. Microsoft.ML.Trainers.FastTree.FastTree TrainTweedieRegression Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
-Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware factorization machine for binary classification Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer TrainBinary Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
+Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware factorization machine for binary classification Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer TrainBinary Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.GeneralizedAdditiveModelBinaryClassifier Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainBinary Microsoft.ML.Trainers.FastTree.BinaryClassificationGamTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.GeneralizedAdditiveModelRegressor Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainRegression Microsoft.ML.Trainers.FastTree.RegressionGamTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
Trainers.KMeansPlusPlusClusterer K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. Microsoft.ML.Trainers.KMeans.KMeansPlusPlusTrainer TrainKMeans Microsoft.ML.Trainers.KMeans.KMeansPlusPlusTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+ClusteringOutput
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
index 629c4c15c2..5292d4b339 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
@@ -20,13 +20,13 @@ public void FfmBinaryClassificationWithAdvancedArguments()
var data = DatasetUtils.GenerateFfmSamples(500);
var dataView = mlContext.Data.ReadFromEnumerable(data);
- var ffmArgs = new FieldAwareFactorizationMachineTrainer.Arguments();
+ var ffmArgs = new FieldAwareFactorizationMachineTrainer.Options();
// Customized the field names.
ffmArgs.FeatureColumn = nameof(DatasetUtils.FfmExample.Field0); // First field.
ffmArgs.ExtraFeatureColumns = new[]{ nameof(DatasetUtils.FfmExample.Field1), nameof(DatasetUtils.FfmExample.Field2) };
- var pipeline = new FieldAwareFactorizationMachineTrainer(mlContext, ffmArgs);
+ var pipeline = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(ffmArgs);
var model = pipeline.Fit(dataView);
var prediction = model.Transform(dataView);
@@ -45,13 +45,15 @@ public void FieldAwareFactorizationMachine_Estimator()
var data = new TextLoader(Env, GetFafmBCLoaderArgs())
.Read(GetDataPath(TestDatasets.breastCancer.trainFilename));
- var est = new FieldAwareFactorizationMachineTrainer(Env, new[] { "Feature1", "Feature2", "Feature3", "Feature4" }, "Label",
- advancedSettings: s =>
- {
- s.Shuffle = false;
- s.Iters = 3;
- s.LatentDim = 7;
- });
+ var ffmArgs = new FieldAwareFactorizationMachineTrainer.Options {
+ FeatureColumn = "Feature1", // Features from the 1st field.
+ ExtraFeatureColumns = new[] { "Feature2", "Feature3", "Feature4" }, // 2nd field's feature column, 3rd field's feature column, 4th field's feature column.
+ Shuffle = false,
+ Iters = 3,
+ LatentDim = 7,
+ };
+
+ var est = ML.BinaryClassification.Trainers.FieldAwareFactorizationMachine(ffmArgs);
TestEstimatorCore(est, data);
var model = est.Fit(data);