diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs index c9e0954a2e..7d085fc965 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs @@ -20,19 +20,33 @@ public static class FactorizationMachineExtensions /// The features, or independent variables. /// The label, or dependent variable. /// The optional example weights. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . + /// + /// + /// + /// public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string[] featureColumns, string labelColumn = DefaultColumnNames.Label, - string weights = null, - Action advancedSettings = null) + string weights = null) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new FieldAwareFactorizationMachineTrainer(env, featureColumns, labelColumn, weights, advancedSettings: advancedSettings); + return new FieldAwareFactorizationMachineTrainer(env, featureColumns, labelColumn, weights); + } + + /// + /// Predict a target using a field-aware factorization machine algorithm. + /// + /// The binary classification catalog trainer object. + /// Advanced arguments to the algorithm. + public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, + FieldAwareFactorizationMachineTrainer.Options options) + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + return new FieldAwareFactorizationMachineTrainer(env, options); } } } diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index a1655114df..270efd291c 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -16,7 +16,7 @@ using Microsoft.ML.Training; [assembly: LoadableClass(FieldAwareFactorizationMachineTrainer.Summary, typeof(FieldAwareFactorizationMachineTrainer), - typeof(FieldAwareFactorizationMachineTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) } + typeof(FieldAwareFactorizationMachineTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) } , FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName, FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")] @@ -40,7 +40,7 @@ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase /// The feature column that the trainer expects. /// - public readonly SchemaShape.Column[] FeatureColumns; + internal readonly SchemaShape.Column[] FeatureColumns; /// /// The label column that the trainer expects. Can be null, which indicates that label /// is not used for training. /// - public readonly SchemaShape.Column LabelColumn; + internal readonly SchemaShape.Column LabelColumn; /// /// The weight column that the trainer expects. Can be null, which indicates that weight is /// not used for training. /// - public readonly SchemaShape.Column WeightColumn; + internal readonly SchemaShape.Column WeightColumn; /// /// The containing at least the training data for this trainer. @@ -121,48 +121,46 @@ public sealed class Arguments : LearnerInputBaseWithWeight private float _radius; /// - /// Legacy constructor initializing a new instance of through the legacy - /// class. + /// Initializes a new instance of through the class. /// /// The private instance of . - /// An instance of the legacy to apply advanced parameters to the algorithm. - public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args) + /// An instance of the legacy to apply advanced parameters to the algorithm. + [BestFriend] + internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Options options) : base(env, LoadName) { - Initialize(env, args); + Initialize(env, options); Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); - var extraColumnLength = (args.ExtraFeatureColumns != null ? args.ExtraFeatureColumns.Length : 0); + var extraColumnLength = (options.ExtraFeatureColumns != null ? options.ExtraFeatureColumns.Length : 0); // There can be multiple feature columns in FFM, jointly specified by args.FeatureColumn and args.ExtraFeatureColumns. FeatureColumns = new SchemaShape.Column[1 + extraColumnLength]; // Treat the default feature column as the 1st field. - FeatureColumns[0] = new SchemaShape.Column(args.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + FeatureColumns[0] = new SchemaShape.Column(options.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); // Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns. for (int i = 0; i < extraColumnLength; i++) - FeatureColumns[i + 1] = new SchemaShape.Column(args.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + FeatureColumns[i + 1] = new SchemaShape.Column(options.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); - LabelColumn = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - WeightColumn = args.WeightColumn.IsExplicit ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default; + LabelColumn = new SchemaShape.Column(options.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + WeightColumn = options.WeightColumn.IsExplicit ? new SchemaShape.Column(options.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default; } /// - /// Initializing a new instance of . + /// Initializes a new instance of . /// /// The private instance of . /// The name of column hosting the features. The i-th element stores feature column of the i-th field. /// The name of the label column. - /// A delegate to apply all the advanced arguments to the algorithm. /// The name of the optional weights' column. - public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, + [BestFriend] + internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string[] featureColumns, string labelColumn = DefaultColumnNames.Label, - string weights = null, - Action advancedSettings = null) + string weights = null) : base(env, LoadName) { - var args = new Arguments(); - advancedSettings?.Invoke(args); + var args = new Options(); Initialize(env, args); Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); @@ -181,24 +179,24 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, /// REVIEW: Once the legacy constructor goes away, this can move to the only constructor and most of the fields can be back to readonly. /// /// - /// - private void Initialize(IHostEnvironment env, Arguments args) + /// + private void Initialize(IHostEnvironment env, Options options) { - Host.CheckUserArg(args.LatentDim > 0, nameof(args.LatentDim), "Must be positive"); - Host.CheckUserArg(args.LambdaLinear >= 0, nameof(args.LambdaLinear), "Must be non-negative"); - Host.CheckUserArg(args.LambdaLatent >= 0, nameof(args.LambdaLatent), "Must be non-negative"); - Host.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), "Must be positive"); - Host.CheckUserArg(args.Iters >= 0, nameof(args.Iters), "Must be non-negative"); - _latentDim = args.LatentDim; + Host.CheckUserArg(options.LatentDim > 0, nameof(options.LatentDim), "Must be positive"); + Host.CheckUserArg(options.LambdaLinear >= 0, nameof(options.LambdaLinear), "Must be non-negative"); + Host.CheckUserArg(options.LambdaLatent >= 0, nameof(options.LambdaLatent), "Must be non-negative"); + Host.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), "Must be positive"); + Host.CheckUserArg(options.Iters >= 0, nameof(options.Iters), "Must be non-negative"); + _latentDim = options.LatentDim; _latentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(_latentDim); - _lambdaLinear = args.LambdaLinear; - _lambdaLatent = args.LambdaLatent; - _learningRate = args.LearningRate; - _numIterations = args.Iters; - _norm = args.Norm; - _shuffle = args.Shuffle; - _verbose = args.Verbose; - _radius = args.Radius; + _lambdaLinear = options.LambdaLinear; + _lambdaLatent = options.LambdaLatent; + _learningRate = options.LearningRate; + _numIterations = options.Iters; + _norm = options.Norm; + _shuffle = options.Shuffle; + _verbose = options.Verbose; + _radius = options.Radius; } private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachineModelParameters predictor, out float[] linearWeights, @@ -476,13 +474,13 @@ private protected override FieldAwareFactorizationMachineModelParameters Train(T ShortName = ShortName, XmlInclude = new[] { @"", @"" })] - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input) + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("Train a field-aware factorization machine"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input), + return LearnerEntryPointsUtils.Train(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } diff --git a/src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs b/src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs index 3bcd8e5c98..b85c3703ce 100644 --- a/src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs +++ b/src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs @@ -24,46 +24,61 @@ public static class FactorizationMachineExtensions /// The binary classifier catalog trainer object. /// The label, or dependent variable. /// The features, or independent variables. - /// Initial learning rate. - /// Number of training iterations. - /// Latent space dimensions. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the ./// A delegate that is called every time the + /// A delegate that is called every time the /// method is called on the - /// instance created out of this. This delegate will receive - /// the model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to + /// instance created out of this. + /// This delegate will receive the model that was trained. The type of the model is . + /// Note that this action cannot change the result in any way; it is only a way for the caller to be informed about what was learnt. + /// The predicted output. + public static (Scalar score, Scalar predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, + Scalar label, Vector[] features, + Action onFit = null) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckNonEmpty(features, nameof(features)); + + Contracts.CheckValueOrNull(onFit); + + var rec = new CustomReconciler((env, labelCol, featureCols) => + { + var trainer = new FieldAwareFactorizationMachineTrainer(env, featureCols, labelCol); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + else + return trainer; + }, label, features); + return rec.Output; + } + + /// + /// Predict a target using a field-aware factorization machine. + /// + /// The binary classifier catalog trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// Advanced arguments to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. + /// This delegate will receive the model that was trained. The type of the model is . + /// Note that this action cannot change the result in any way; it is only a way for the caller to /// be informed about what was learnt. /// The predicted output. public static (Scalar score, Scalar predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, Scalar label, Vector[] features, - float learningRate = 0.1f, - int numIterations = 5, - int numLatentDimensions = 20, - Action advancedSettings = null, + FieldAwareFactorizationMachineTrainer.Options options, Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckNonEmpty(features, nameof(features)); - Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive"); - Contracts.CheckParam(numIterations > 0, nameof(numIterations), "Must be positive"); - Contracts.CheckParam(numLatentDimensions > 0, nameof(numLatentDimensions), "Must be positive"); - Contracts.CheckValueOrNull(advancedSettings); + Contracts.CheckValueOrNull(options); Contracts.CheckValueOrNull(onFit); var rec = new CustomReconciler((env, labelCol, featureCols) => { - var trainer = new FieldAwareFactorizationMachineTrainer(env, featureCols, labelCol, advancedSettings: - args => - { - args.LearningRate = learningRate; - args.Iters = numIterations; - args.LatentDim = numLatentDimensions; - - advancedSettings?.Invoke(args); - }); + var trainer = new FieldAwareFactorizationMachineTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); else diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index aecd8cdd3a..cbb12ae3ef 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -49,7 +49,7 @@ Trainers.FastTreeBinaryClassifier Uses a logit-boost boosted tree learner to per Trainers.FastTreeRanker Trains gradient boosted decision trees to the LambdaRank quasi-gradient. Microsoft.ML.Trainers.FastTree.FastTree TrainRanking Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput Trainers.FastTreeRegressor Trains gradient boosted decision trees to fit target values using least-squares. Microsoft.ML.Trainers.FastTree.FastTree TrainRegression Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.FastTreeTweedieRegressor Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression. Microsoft.ML.Trainers.FastTree.FastTree TrainTweedieRegression Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput -Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware factorization machine for binary classification Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer TrainBinary Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware factorization machine for binary classification Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer TrainBinary Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.GeneralizedAdditiveModelBinaryClassifier Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainBinary Microsoft.ML.Trainers.FastTree.BinaryClassificationGamTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.GeneralizedAdditiveModelRegressor Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainRegression Microsoft.ML.Trainers.FastTree.RegressionGamTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.KMeansPlusPlusClusterer K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. Microsoft.ML.Trainers.KMeans.KMeansPlusPlusTrainer TrainKMeans Microsoft.ML.Trainers.KMeans.KMeansPlusPlusTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+ClusteringOutput diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs index 629c4c15c2..5292d4b339 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -20,13 +20,13 @@ public void FfmBinaryClassificationWithAdvancedArguments() var data = DatasetUtils.GenerateFfmSamples(500); var dataView = mlContext.Data.ReadFromEnumerable(data); - var ffmArgs = new FieldAwareFactorizationMachineTrainer.Arguments(); + var ffmArgs = new FieldAwareFactorizationMachineTrainer.Options(); // Customized the field names. ffmArgs.FeatureColumn = nameof(DatasetUtils.FfmExample.Field0); // First field. ffmArgs.ExtraFeatureColumns = new[]{ nameof(DatasetUtils.FfmExample.Field1), nameof(DatasetUtils.FfmExample.Field2) }; - var pipeline = new FieldAwareFactorizationMachineTrainer(mlContext, ffmArgs); + var pipeline = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(ffmArgs); var model = pipeline.Fit(dataView); var prediction = model.Transform(dataView); @@ -45,13 +45,15 @@ public void FieldAwareFactorizationMachine_Estimator() var data = new TextLoader(Env, GetFafmBCLoaderArgs()) .Read(GetDataPath(TestDatasets.breastCancer.trainFilename)); - var est = new FieldAwareFactorizationMachineTrainer(Env, new[] { "Feature1", "Feature2", "Feature3", "Feature4" }, "Label", - advancedSettings: s => - { - s.Shuffle = false; - s.Iters = 3; - s.LatentDim = 7; - }); + var ffmArgs = new FieldAwareFactorizationMachineTrainer.Options { + FeatureColumn = "Feature1", // Features from the 1st field. + ExtraFeatureColumns = new[] { "Feature2", "Feature3", "Feature4" }, // 2nd field's feature column, 3rd field's feature column, 4th field's feature column. + Shuffle = false, + Iters = 3, + LatentDim = 7, + }; + + var est = ML.BinaryClassification.Trainers.FieldAwareFactorizationMachine(ffmArgs); TestEstimatorCore(est, data); var model = est.Fit(data);