Skip to content

More trainer extensions, bug fixes and consistency across trainer extensions #1524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 10, 2018
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/code/MlNetCookBook.md
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ var learningPipeline = reader.MakeNewEstimator()
IEstimator<ITransformer> dynamicPipe = learningPipeline.AsDynamic;

// Create a binary classification trainer.
var binaryTrainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron();
var binaryTrainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron("Label", "Features");

// Append the OVA learner to the pipeline.
dynamicPipe = dynamicPipe.Append(new Ova(mlContext, binaryTrainer));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static NormalizingEstimator Normalize(this TransformsCatalog catalog,
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[ConcatWith] (](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/MinMaxNormalizer.cs?range =6-11,16-86)]
/// [!code-csharp[ConcatWith] (](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/MinMaxNormalizer.cs?range=6-11,16-86)]
/// ]]>
/// </format>
/// </example>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public Arguments()
BasePredictors = new[]
{
ComponentFactoryUtils.CreateFromFunction(
env => new MulticlassLogisticRegression(env, FeatureColumn, LabelColumn))
env => new MulticlassLogisticRegression(env, LabelColumn, FeatureColumn))
};
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
Args.FeatureColumn = featureColumn;

if (weightColumn != null)
Args.WeightColumn = Optional<string>.Explicit(weightColumn); ;
Args.WeightColumn = Optional<string>.Explicit(weightColumn);

if (groupIdColumn != null)
Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn); ;
Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn);

// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ public sealed class Arguments : FastForestArgumentsBase
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="weightColumn">The optional name for the column containing the initial weight.</param>
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public FastForestRegression(IHostEnvironment env,
string labelColumn,
Expand Down
98 changes: 72 additions & 26 deletions src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace Microsoft.ML
{
/// <summary>
/// FastTree <see cref="TrainContextBase"/> extension methods.
/// Tree <see cref="TrainContextBase"/> extension methods.
/// </summary>
public static class TreeExtensions
{
Expand All @@ -27,8 +27,8 @@ public static class TreeExtensions
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx,
string label = DefaultColumnNames.Label,
string features = DefaultColumnNames.Features,
string label,
Copy link
Contributor

@artidoro artidoro Nov 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have seen in other files that features comes before label (SymSgdClassificationTrainer for example). We should pick one order and stick to it, otherwise it is very confusing. I would suggest features, label, optional weight. But either ways could work! #Resolved

Copy link
Contributor

@artidoro artidoro Nov 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think you already made the decision, label, features, optional weight, which is consistent with all other files in this PR. Maybe we should consider opening an issue and updating the order the arguments in other files in another PR?


In reply to: 230870480 [](ancestors = 230870480)

string features,
string weights = null,
int numLeaves = Defaults.NumLeaves,
int numTrees = Defaults.NumTrees,
Expand All @@ -54,8 +54,8 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string label = DefaultColumnNames.Label,
string features = DefaultColumnNames.Features,
string label,
string features,
string weights = null,
int numLeaves = Defaults.NumLeaves,
int numTrees = Defaults.NumTrees,
Expand All @@ -71,7 +71,7 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica
/// <summary>
/// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the <see cref="FastTreeRankingTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
/// <param name="ctx">The <see cref="RankingContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="groupId">The groupId column.</param>
Expand All @@ -82,9 +82,9 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx,
string label = DefaultColumnNames.Label,
string groupId = DefaultColumnNames.GroupId,
string features = DefaultColumnNames.Features,
string label,
string features,
string groupId ,
string weights = null,
int numLeaves = Defaults.NumLeaves,
int numTrees = Defaults.NumTrees,
Expand All @@ -100,21 +100,17 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer
/// <summary>
/// Predict a target using a decision tree regression model trained with the <see cref="FastTreeRegressionTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="weights">The optional weights column.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
/// <param name="minDatapointsInLeafs">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx,
public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string label = DefaultColumnNames.Label,
string features = DefaultColumnNames.Features,
string weights = null,
int numLeaves = Defaults.NumLeaves,
int numTrees = Defaults.NumTrees,
int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs,
double learningRate = Defaults.LearningRates,
Action<BinaryClassificationGamTrainer.Arguments> advancedSettings = null)
Expand All @@ -125,23 +121,19 @@ public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this Reg
}

/// <summary>
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
/// Predict a target using a decision tree binary classification model trained with the <see cref="RegressionGamTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="weights">The optional weights column.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
/// <param name="minDatapointsInLeafs">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data.</param>
Copy link
Contributor

@Zruty0 Zruty0 Nov 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minDatapointsInLeafs [](start = 25, length = 20)

I sense inconsistent naming #Resolved

Copy link
Member Author

@sfilipi sfilipi Nov 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't see what you are refering to..


In reply to: 231226770 [](ancestors = 231226770)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MinDocsPerLeaf, MinDocumentsPerLeaf, MinDatapointsInLeafs

Also 'leaves', not 'leafs'


In reply to: 231232309 [](ancestors = 231232309,231226770)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh lol. Apparently it's been like this for a long time.


In reply to: 231367379 [](ancestors = 231367379,231232309,231226770)

/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string label = DefaultColumnNames.Label,
string features = DefaultColumnNames.Features,
public static RegressionGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx,
string label,
string features,
string weights = null,
int numLeaves = Defaults.NumLeaves,
int numTrees = Defaults.NumTrees,
int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs,
double learningRate = Defaults.LearningRates,
Action<RegressionGamTrainer.Arguments> advancedSettings = null)
Expand All @@ -164,8 +156,8 @@ public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassif
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx,
string label = DefaultColumnNames.Label,
string features = DefaultColumnNames.Features,
string label,
string features,
string weights = null,
int numLeaves = Defaults.NumLeaves,
int numTrees = Defaults.NumTrees,
Expand All @@ -177,5 +169,59 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr
var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeTweedieTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings);
}

/// <summary>
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestRegression"/>.
/// </summary>
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="weights">The optional weights column.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
/// <param name="minDatapointsInLeafs">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static FastForestRegression FastForest(this RegressionContext.RegressionTrainers ctx,
string label,
string features,
string weights = null,
int numLeaves = Defaults.NumLeaves,
int numTrees = Defaults.NumTrees,
int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs,
double learningRate = Defaults.LearningRates,
Action<FastForestRegression.Arguments> advancedSettings = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
return new FastForestRegression(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings);
}

/// <summary>
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestClassification"/>.
/// </summary>
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="weights">The optional weights column.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
/// <param name="minDatapointsInLeafs">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static FastForestClassification FastForest(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string label,
string features,
string weights = null,
int numLeaves = Defaults.NumLeaves,
int numTrees = Defaults.NumTrees,
int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs,
double learningRate = Defaults.LearningRates,
Action<FastForestClassification.Arguments> advancedSettings = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
return new FastForestClassification(env, label, features, weights,numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings);
}
}
}
15 changes: 9 additions & 6 deletions src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,16 @@ public sealed class Arguments : LearnerInputBaseWithWeight
/// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weightColumn">The name for the example weight column.</param>
/// <param name="label">The name of the label column.</param>
/// <param name="features">The name of the feature column.</param>
/// <param name="weights">The name for the optional example weight column.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, string labelColumn,
string weightColumn = null, Action<Arguments> advancedSettings = null)
: this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings))
public OlsLinearRegressionTrainer(IHostEnvironment env,
string label,
string features,
string weights = null,
Action<Arguments> advancedSettings = null)
: this(env, ArgsInit(features, label, weights, advancedSettings))
{
}

Expand Down
17 changes: 10 additions & 7 deletions src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,23 @@ protected override TPredictor TrainModelCore(TrainContext context)
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="label">The name of the label column.</param>
/// <param name="features">The name of the feature column.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public SymSgdClassificationTrainer(IHostEnvironment env, string featureColumn, string labelColumn, Action<Arguments> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn),
TrainerUtils.MakeBoolScalarLabel(labelColumn))
public SymSgdClassificationTrainer(IHostEnvironment env,
string label,
string features,
Action<Arguments> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(features),
TrainerUtils.MakeBoolScalarLabel(label))
{
_args = new Arguments();

// Apply the advanced args, if the user supplied any.
_args.Check(Host);
advancedSettings?.Invoke(_args);
_args.FeatureColumn = featureColumn;
_args.LabelColumn = labelColumn;
_args.FeatureColumn = features;
_args.LabelColumn = label;

Info = new TrainerInfo();
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static class KMeansClusteringExtensions
/// <param name="clustersCount">The number of clusters to use for KMeans.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static KMeansPlusPlusTrainer KMeans(this ClusteringContext.ClusteringTrainers ctx,
string features = DefaultColumnNames.Features,
string features,
string weights = null,
int clustersCount = KMeansPlusPlusTrainer.Defaults.K,
Action<KMeansPlusPlusTrainer.Arguments> advancedSettings = null)
Expand Down
Loading