Skip to content

Modify API for advanced settings (FieldAwareFactorizationMachineTrainer) #2219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,33 @@ public static class FactorizationMachineExtensions
/// <param name="featureColumns">The features, or independent variables.</param>
/// <param name="labelColumn">The label, or dependent variable.</param>
/// <param name="weights">The optional example weights.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct method signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FieldAwareFactorizationMachine](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs)]
/// ]]></format>
/// </example>
public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string[] featureColumns,
string labelColumn = DefaultColumnNames.Label,
string weights = null,
Action<FieldAwareFactorizationMachineTrainer.Arguments> advancedSettings = null)
string weights = null)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new FieldAwareFactorizationMachineTrainer(env, featureColumns, labelColumn, weights, advancedSettings: advancedSettings);
return new FieldAwareFactorizationMachineTrainer(env, featureColumns, labelColumn, weights);
}

/// <summary>
/// Predict a target using a field-aware factorization machine algorithm.
/// </summary>
/// <param name="catalog">The binary classification catalog trainer object.</param>
/// <param name="options">Advanced arguments to the algorithm.</param>
public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
FieldAwareFactorizationMachineTrainer.Options options)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new FieldAwareFactorizationMachineTrainer(env, options);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
using Microsoft.ML.Training;

[assembly: LoadableClass(FieldAwareFactorizationMachineTrainer.Summary, typeof(FieldAwareFactorizationMachineTrainer),
typeof(FieldAwareFactorizationMachineTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }
typeof(FieldAwareFactorizationMachineTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }
, FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName,
FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")]

Expand All @@ -40,7 +40,7 @@ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<FieldAwa
internal const string LoadName = "FieldAwareFactorizationMachine";
internal const string ShortName = "ffm";

public sealed class Arguments : LearnerInputBaseWithWeight
public sealed class Options : LearnerInputBaseWithWeight
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate", ShortName = "lr", SortOrder = 1)]
[TlcModule.SweepableFloatParam(0.001f, 1.0f, isLogScale: true)]
Expand Down Expand Up @@ -90,19 +90,19 @@ public sealed class Arguments : LearnerInputBaseWithWeight
/// <summary>
/// The feature column that the trainer expects.
/// </summary>
public readonly SchemaShape.Column[] FeatureColumns;
internal readonly SchemaShape.Column[] FeatureColumns;

/// <summary>
/// The label column that the trainer expects. Can be <c>null</c>, which indicates that label
/// is not used for training.
/// </summary>
public readonly SchemaShape.Column LabelColumn;
internal readonly SchemaShape.Column LabelColumn;

/// <summary>
/// The weight column that the trainer expects. Can be <c>null</c>, which indicates that weight is
/// not used for training.
/// </summary>
public readonly SchemaShape.Column WeightColumn;
internal readonly SchemaShape.Column WeightColumn;

/// <summary>
/// The <see cref="TrainerInfo"/> containing at least the training data for this trainer.
Expand All @@ -121,48 +121,46 @@ public sealed class Arguments : LearnerInputBaseWithWeight
private float _radius;

/// <summary>
/// Legacy constructor initializing a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/> through the legacy
/// <see cref="Arguments"/> class.
/// Initializes a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/> through the <see cref="Options"/> class.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="args">An instance of the legacy <see cref="Arguments"/> to apply advanced parameters to the algorithm.</param>
public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args)
/// <param name="options">An instance of the legacy <see cref="Options"/> to apply advanced parameters to the algorithm.</param>
[BestFriend]
internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Options options)
: base(env, LoadName)
{
Initialize(env, args);
Initialize(env, options);
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
var extraColumnLength = (args.ExtraFeatureColumns != null ? args.ExtraFeatureColumns.Length : 0);
var extraColumnLength = (options.ExtraFeatureColumns != null ? options.ExtraFeatureColumns.Length : 0);
// There can be multiple feature columns in FFM, jointly specified by args.FeatureColumn and args.ExtraFeatureColumns.
FeatureColumns = new SchemaShape.Column[1 + extraColumnLength];

// Treat the default feature column as the 1st field.
FeatureColumns[0] = new SchemaShape.Column(args.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
FeatureColumns[0] = new SchemaShape.Column(options.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

// Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns.
for (int i = 0; i < extraColumnLength; i++)
FeatureColumns[i + 1] = new SchemaShape.Column(args.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
FeatureColumns[i + 1] = new SchemaShape.Column(options.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

LabelColumn = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
WeightColumn = args.WeightColumn.IsExplicit ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
LabelColumn = new SchemaShape.Column(options.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
WeightColumn = options.WeightColumn.IsExplicit ? new SchemaShape.Column(options.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
}

/// <summary>
/// Initializing a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/>.
/// Initializes a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/>.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="featureColumns">The name of column hosting the features. The i-th element stores feature column of the i-th field.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="weights">The name of the optional weights' column.</param>
public FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
[BestFriend]
internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
string[] featureColumns,
string labelColumn = DefaultColumnNames.Label,
string weights = null,
Action<Arguments> advancedSettings = null)
string weights = null)
: base(env, LoadName)
{
var args = new Arguments();
advancedSettings?.Invoke(args);
var args = new Options();

Initialize(env, args);
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
Expand All @@ -181,24 +179,24 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
/// REVIEW: Once the legacy constructor goes away, this can move to the only constructor and most of the fields can be back to readonly.
/// </summary>
/// <param name="env"></param>
/// <param name="args"></param>
private void Initialize(IHostEnvironment env, Arguments args)
/// <param name="options"></param>
private void Initialize(IHostEnvironment env, Options options)
{
Host.CheckUserArg(args.LatentDim > 0, nameof(args.LatentDim), "Must be positive");
Host.CheckUserArg(args.LambdaLinear >= 0, nameof(args.LambdaLinear), "Must be non-negative");
Host.CheckUserArg(args.LambdaLatent >= 0, nameof(args.LambdaLatent), "Must be non-negative");
Host.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), "Must be positive");
Host.CheckUserArg(args.Iters >= 0, nameof(args.Iters), "Must be non-negative");
_latentDim = args.LatentDim;
Host.CheckUserArg(options.LatentDim > 0, nameof(options.LatentDim), "Must be positive");
Host.CheckUserArg(options.LambdaLinear >= 0, nameof(options.LambdaLinear), "Must be non-negative");
Host.CheckUserArg(options.LambdaLatent >= 0, nameof(options.LambdaLatent), "Must be non-negative");
Host.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), "Must be positive");
Host.CheckUserArg(options.Iters >= 0, nameof(options.Iters), "Must be non-negative");
_latentDim = options.LatentDim;
_latentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(_latentDim);
_lambdaLinear = args.LambdaLinear;
_lambdaLatent = args.LambdaLatent;
_learningRate = args.LearningRate;
_numIterations = args.Iters;
_norm = args.Norm;
_shuffle = args.Shuffle;
_verbose = args.Verbose;
_radius = args.Radius;
_lambdaLinear = options.LambdaLinear;
_lambdaLatent = options.LambdaLatent;
_learningRate = options.LearningRate;
_numIterations = options.Iters;
_norm = options.Norm;
_shuffle = options.Shuffle;
_verbose = options.Verbose;
_radius = options.Radius;
}

private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachineModelParameters predictor, out float[] linearWeights,
Expand Down Expand Up @@ -476,13 +474,13 @@ private protected override FieldAwareFactorizationMachineModelParameters Train(T
ShortName = ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/FactorizationMachine/doc.xml' path='doc/members/member[@name=""FieldAwareFactorizationMachineBinaryClassifier""]/*' />",
@"<include file='../Microsoft.ML.StandardLearners/FactorizationMachine/doc.xml' path='doc/members/example[@name=""FieldAwareFactorizationMachineBinaryClassifier""]/*' />" })]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input)
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("Train a field-aware factorization machine");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input),
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}

Expand Down
67 changes: 41 additions & 26 deletions src/Microsoft.ML.StaticPipe/FactorizationMachineStatic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,46 +24,61 @@ public static class FactorizationMachineExtensions
/// <param name="catalog">The binary classifier catalog trainer object.</param>
/// <param name="label">The label, or dependent variable.</param>
/// <param name="features">The features, or independent variables.</param>
/// <param name="learningRate">Initial learning rate.</param>
/// <param name="numIterations">Number of training iterations.</param>
/// <param name="numLatentDimensions">Latent space dimensions.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct method signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>/// <param name="onFit">A delegate that is called every time the
/// <param name="onFit">A delegate that is called every time the
Copy link
Member

@wschin wschin Jan 24, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reference the type of the trained model? It's impossible for user to write onFit without knowing its input type. Many thanks. #Resolved

/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
/// the model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this.
/// This delegate will receive the model that was trained. The type of the model is <see cref="FieldAwareFactorizationMachineModelParameters"/>.
/// Note that this action cannot change the result in any way; it is only a way for the caller to be informed about what was learnt.</param>
/// <returns>The predicted output.</returns>
public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
Scalar<bool> label, Vector<float>[] features,
Action<FieldAwareFactorizationMachineModelParameters> onFit = null)
{
Contracts.CheckValue(label, nameof(label));
Contracts.CheckNonEmpty(features, nameof(features));

Contracts.CheckValueOrNull(onFit);

var rec = new CustomReconciler((env, labelCol, featureCols) =>
{
var trainer = new FieldAwareFactorizationMachineTrainer(env, featureCols, labelCol);

if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
else
return trainer;
}, label, features);
return rec.Output;
}

/// <summary>
/// Predict a target using a field-aware factorization machine.
/// </summary>
/// <param name="catalog">The binary classifier catalog trainer object.</param>
/// <param name="label">The label, or dependent variable.</param>
/// <param name="features">The features, or independent variables.</param>
/// <param name="options">Advanced arguments to the algorithm.</param>
/// <param name="onFit">A delegate that is called every time the
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this.
/// This delegate will receive the model that was trained. The type of the model is <see cref="FieldAwareFactorizationMachineModelParameters"/>.
/// Note that this action cannot change the result in any way; it is only a way for the caller to
/// be informed about what was learnt.</param>
/// <returns>The predicted output.</returns>
public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
Scalar<bool> label, Vector<float>[] features,
float learningRate = 0.1f,
int numIterations = 5,
int numLatentDimensions = 20,
Action<FieldAwareFactorizationMachineTrainer.Arguments> advancedSettings = null,
FieldAwareFactorizationMachineTrainer.Options options,
Action<FieldAwareFactorizationMachineModelParameters> onFit = null)
{
Contracts.CheckValue(label, nameof(label));
Contracts.CheckNonEmpty(features, nameof(features));

Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive");
Contracts.CheckParam(numIterations > 0, nameof(numIterations), "Must be positive");
Contracts.CheckParam(numLatentDimensions > 0, nameof(numLatentDimensions), "Must be positive");
Contracts.CheckValueOrNull(advancedSettings);
Contracts.CheckValueOrNull(options);
Contracts.CheckValueOrNull(onFit);

var rec = new CustomReconciler((env, labelCol, featureCols) =>
{
var trainer = new FieldAwareFactorizationMachineTrainer(env, featureCols, labelCol, advancedSettings:
args =>
{
args.LearningRate = learningRate;
args.Iters = numIterations;
args.LatentDim = numLatentDimensions;

advancedSettings?.Invoke(args);
});
var trainer = new FieldAwareFactorizationMachineTrainer(env, options);
if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
else
Expand Down
Loading