
WIP [Please don't review] : Arguments, Options #2000

Closed
wants to merge 4 commits into from
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

namespace Microsoft.ML.Samples.Dynamic
{
@@ -62,16 +63,12 @@ public static void MatrixFactorizationInMemoryData()
// Create a matrix factorization trainer which may consume "Value" as the training label, "MatrixColumnIndex" as the
// matrix's column index, and "MatrixRowIndex" as the matrix's row index. Here nameof(...) is used to extract field
// names' in MatrixElement class.
var options = new MatrixFactorizationTrainer.Options { NumIterations = 10, NumThreads = 1, K = 32 };
var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(
nameof(MatrixElement.MatrixColumnIndex),
nameof(MatrixElement.MatrixRowIndex),
nameof(MatrixElement.Value),
advancedSettings: s =>
{
s.NumIterations = 10;
s.NumThreads = 1; // To eliminate randomness, # of threads must be 1.
s.K = 32;
});
options);

// Train a matrix factorization model.
var model = pipeline.Fit(dataView);
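For readers skimming the hunk above: the sample change is purely a move from a configuration delegate to an options object. A condensed before/after sketch, using only names already present in the sample (mlContext and MatrixElement come from the surrounding code):

// Before this PR: advanced settings were applied through a delegate.
var pipelineOld = mlContext.Recommendation().Trainers.MatrixFactorization(
    nameof(MatrixElement.MatrixColumnIndex),
    nameof(MatrixElement.MatrixRowIndex),
    nameof(MatrixElement.Value),
    advancedSettings: s =>
    {
        s.NumIterations = 10;
        s.NumThreads = 1; // To eliminate randomness, # of threads must be 1.
        s.K = 32;
    });

// With this PR: the same settings travel as a MatrixFactorizationTrainer.Options instance.
var options = new MatrixFactorizationTrainer.Options { NumIterations = 10, NumThreads = 1, K = 32 };
var pipelineNew = mlContext.Recommendation().Trainers.MatrixFactorization(
    nameof(MatrixElement.MatrixColumnIndex),
    nameof(MatrixElement.MatrixRowIndex),
    nameof(MatrixElement.Value),
    options);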
4 changes: 2 additions & 2 deletions src/Microsoft.ML.FastTree/BoostingFastTree.cs
@@ -29,8 +29,8 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env,
int numTrees,
int minDatapointsInLeaves,
double learningRate,
Action<TArgs> advancedSettings)
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, advancedSettings)
TArgs options)
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, options)
{

if (Args.LearningRates != learningRate)
7 changes: 2 additions & 5 deletions src/Microsoft.ML.FastTree/FastTree.cs
@@ -115,20 +115,17 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
int numLeaves,
int numTrees,
int minDatapointsInLeaves,
Action<TArgs> advancedSettings)
TArgs options)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
Args = new TArgs();
Args = options ?? new TArgs();

// set up the directly provided values
// override with the directly provided values.
Args.NumLeaves = numLeaves;
Args.NumTrees = numTrees;
Args.MinDocumentsInLeafs = minDatapointsInLeaves;

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);

Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

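The FastTree.cs hunk above also changes how the argument object is created: instead of always constructing defaults and then invoking the advancedSettings delegate, the caller-supplied options object is used when present and the direct constructor values are applied on top. A minimal, self-contained sketch of that pattern (TrainerOptions and TrainerBase are illustrative stand-ins, not ML.NET types):

// Illustrative only: the "coalesce options, then apply direct constructor values" pattern.
public class TrainerOptions
{
    public int NumLeaves;
    public int NumTrees;
    public int MinDocumentsInLeafs;
}

public abstract class TrainerBase<TOptions> where TOptions : TrainerOptions, new()
{
    protected readonly TOptions Args;

    protected TrainerBase(TOptions options, int numLeaves, int numTrees, int minDatapointsInLeaves)
    {
        // Replaces: Args = new TArgs(); ... advancedSettings?.Invoke(Args);
        Args = options ?? new TOptions();

        // Directly provided constructor values still override whatever the options carry.
        Args.NumLeaves = numLeaves;
        Args.NumTrees = numTrees;
        Args.MinDocumentsInLeafs = minDatapointsInLeaves;
    }
}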
20 changes: 10 additions & 10 deletions src/Microsoft.ML.FastTree/FastTreeArguments.cs
@@ -8,10 +8,10 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Trainers.FastTree;

[assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Arguments))]
[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Arguments))]
[assembly: EntryPointModule(typeof(FastTreeTweedieTrainer.Arguments))]
[assembly: EntryPointModule(typeof(FastTreeRankingTrainer.Arguments))]
[assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeTweedieTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeRankingTrainer.Options))]

namespace Microsoft.ML.Trainers.FastTree
{
@@ -24,7 +24,7 @@ internal interface IFastTreeTrainerFactory : IComponentFactory<ITrainer>
public sealed partial class FastTreeBinaryClassificationTrainer
{
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")]
[TGUI(Label = "Optimize for unbalanced")]
@@ -37,9 +37,9 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed partial class FastTreeRegressionTrainer
{
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory
{
public Arguments()
public Options()
{
EarlyStoppingMetrics = 1; // Use L1 by default.
}
@@ -51,7 +51,7 @@ public Arguments()
public sealed partial class FastTreeTweedieTrainer
{
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory
{
// REVIEW: It is possible to estimate this index parameter from the distribution of data, using
// a combination of univariate optimization and grid search, following section 4.2 of the paper. However
@@ -68,7 +68,7 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed partial class FastTreeRankingTrainer
{
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma seperated list of gains associated to each relevance label.", ShortName = "gains")]
[TGUI(NoSweep = true)]
@@ -105,7 +105,7 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
[TGUI(NotGui = true)]
public bool NormalizeQueryLambdas;

public Arguments()
public Options()
{
EarlyStoppingMetrics = 1;
}
18 changes: 9 additions & 9 deletions src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -17,7 +17,7 @@
using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;

[assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Arguments),
[assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
FastTreeBinaryClassificationTrainer.UserNameValue,
FastTreeBinaryClassificationTrainer.LoadNameValue,
@@ -103,7 +103,7 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad

/// <include file = 'doc.xml' path='doc/members/member[@name="FastTree"]/*' />
public sealed partial class FastTreeBinaryClassificationTrainer :
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Arguments, BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>, IPredictorWithFeatureWeights<float>>
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Options, BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>, IPredictorWithFeatureWeights<float>>
{
/// <summary>
/// The LoadName for the assembly containing the trainer.
@@ -127,7 +127,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer :
/// <param name="minDatapointsInLeaves">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="options">Advanced arguments for the algorithm.</param>
public FastTreeBinaryClassificationTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
@@ -136,17 +136,17 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env,
int numTrees = Defaults.NumTrees,
int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves,
double learningRate = Defaults.LearningRates,
Action<Arguments> advancedSettings = null)
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings)
Options options = null)
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, options)
{
// Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
_sigmoidParameter = 2.0 * Args.LearningRates;
}

/// <summary>
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/> by using the legacy <see cref="Arguments"/> class.
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/> by using the legacy <see cref="Options"/> class.
/// </summary>
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args)
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options args)
: base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
{
// Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
@@ -402,14 +402,14 @@ public static partial class FastTree
ShortName = FastTreeBinaryClassificationTrainer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.FastTree/doc.xml' path='doc/members/member[@name=""FastTree""]/*' />",
@"<include file='../Microsoft.ML.FastTree/doc.xml' path='doc/members/example[@name=""FastTreeBinaryClassifier""]/*' />" })]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryClassificationTrainer.Arguments input)
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryClassificationTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainFastTree");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<FastTreeBinaryClassificationTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<FastTreeBinaryClassificationTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new FastTreeBinaryClassificationTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn),
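Assuming the public constructor shown above, calling code would now look roughly like the sketch below. Hypothetical pieces: env stands for an IHostEnvironment already in scope, the numeric values are placeholders, and EarlyStoppingMetrics is assumed to live on the shared BoostedTreeArgs base (as the regression and ranking Options constructors in FastTreeArguments.cs suggest):

// Sketch only: constructing the binary classification trainer with the new Options object.
var ftOptions = new FastTreeBinaryClassificationTrainer.Options
{
    EarlyStoppingMetrics = 1 // e.g. early-stop on L1, mirroring the regression/ranking defaults
};

var trainer = new FastTreeBinaryClassificationTrainer(
    env,                                     // IHostEnvironment assumed to be available
    labelColumn: DefaultColumnNames.Label,
    featureColumn: DefaultColumnNames.Features,
    numTrees: 100,                           // direct constructor values still win over the options object
    options: ftOptions);                     // previously: advancedSettings: s => { ... }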
22 changes: 11 additions & 11 deletions src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -20,7 +20,7 @@
using Microsoft.ML.Training;

// REVIEW: Do we really need all these names?
[assembly: LoadableClass(FastTreeRankingTrainer.Summary, typeof(FastTreeRankingTrainer), typeof(FastTreeRankingTrainer.Arguments),
[assembly: LoadableClass(FastTreeRankingTrainer.Summary, typeof(FastTreeRankingTrainer), typeof(FastTreeRankingTrainer.Options),
new[] { typeof(SignatureRankerTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
FastTreeRankingTrainer.UserNameValue,
FastTreeRankingTrainer.LoadNameValue,
@@ -43,7 +43,7 @@ namespace Microsoft.ML.Trainers.FastTree
{
/// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
public sealed partial class FastTreeRankingTrainer
: BoostingFastTreeTrainerBase<FastTreeRankingTrainer.Arguments, RankingPredictionTransformer<FastTreeRankingModelParameters>, FastTreeRankingModelParameters>
: BoostingFastTreeTrainerBase<FastTreeRankingTrainer.Options, RankingPredictionTransformer<FastTreeRankingModelParameters>, FastTreeRankingModelParameters>
{
internal const string LoadNameValue = "FastTreeRanking";
internal const string UserNameValue = "FastTree (Boosted Trees) Ranking";
@@ -71,7 +71,7 @@ public sealed partial class FastTreeRankingTrainer
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="minDatapointsInLeaves">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="options">Advanced arguments to the algorithm.</param>
public FastTreeRankingTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
@@ -81,16 +81,16 @@ public FastTreeRankingTrainer(IHostEnvironment env,
int numTrees = Defaults.NumTrees,
int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves,
double learningRate = Defaults.LearningRates,
Action<Arguments> advancedSettings = null)
: base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings)
Options options = null)
: base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, learningRate, options)
{
Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn));
}

/// <summary>
/// Initializes a new instance of <see cref="FastTreeRankingTrainer"/> by using the legacy <see cref="Arguments"/> class.
/// Initializes a new instance of <see cref="FastTreeRankingTrainer"/> by using the legacy <see cref="Options"/> class.
/// </summary>
internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args)
internal FastTreeRankingTrainer(IHostEnvironment env, Options args)
: base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn))
{
}
@@ -548,7 +548,7 @@ private enum DupeIdInfo
// Keeps track of labels of top 3 documents per query
public short[][] TrainQueriesTopLabels;

public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Arguments args, IParallelTraining parallelTraining)
public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options args, IParallelTraining parallelTraining)
: base(trainset,
args.LearningRates,
args.Shrinkage,
@@ -646,7 +646,7 @@ private void SetupSecondaryGains(Arguments args)
}
#endif

private void SetupBaselineRisk(Arguments args)
private void SetupBaselineRisk(Options args)
{
double[] scores = Dataset.Skeleton.GetData<double>("BaselineScores");
if (scores == null)
@@ -1162,14 +1162,14 @@ public static partial class FastTree
ShortName = FastTreeRankingTrainer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.FastTree/doc.xml' path='doc/members/member[@name=""FastTree""]/*' />",
@"<include file='../Microsoft.ML.FastTree/doc.xml' path='doc/members/example[@name=""FastTreeRanker""]/*' />"})]
public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, FastTreeRankingTrainer.Arguments input)
public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, FastTreeRankingTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainFastTree");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<FastTreeRankingTrainer.Arguments, CommonOutputs.RankingOutput>(host, input,
return LearnerEntryPointsUtils.Train<FastTreeRankingTrainer.Options, CommonOutputs.RankingOutput>(host, input,
() => new FastTreeRankingTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn),
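The ranking trainer mirrors the classification change; the one extra constraint visible above is that groupIdColumn must be non-empty (Host.CheckNonEmpty). A brief sketch under the same assumptions as the classification example (env and the "GroupId" column name are placeholders):

// Sketch only: the ranking trainer's options-based constructor.
var rankOptions = new FastTreeRankingTrainer.Options(); // its ctor already sets EarlyStoppingMetrics = 1

var ranker = new FastTreeRankingTrainer(
    env,                                    // IHostEnvironment assumed to be available
    labelColumn: DefaultColumnNames.Label,
    featureColumn: DefaultColumnNames.Features,
    groupIdColumn: "GroupId",               // required; checked with Host.CheckNonEmpty
    options: rankOptions);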