-
Notifications
You must be signed in to change notification settings - Fork 1.9k
More trainer extensions, bug fixes and consistency across trainer extensions #1524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
f698b96
d10ea0a
011de99
fb99d51
b776880
d4b9ce7
191f28c
7d98033
e2a360b
5f2fb8e
d2aaf1f
b18e2f4
27b1842
a84daf0
b2631b7
1bacc1d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
namespace Microsoft.ML | ||
{ | ||
/// <summary> | ||
/// FastTree <see cref="TrainContextBase"/> extension methods. | ||
/// Tree <see cref="TrainContextBase"/> extension methods. | ||
/// </summary> | ||
public static class TreeExtensions | ||
{ | ||
|
@@ -27,8 +27,8 @@ public static class TreeExtensions | |
/// <param name="learningRate">The learning rate.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string features = DefaultColumnNames.Features, | ||
string label, | ||
string features, | ||
string weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
|
@@ -54,8 +54,8 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi | |
/// <param name="learningRate">The learning rate.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string features = DefaultColumnNames.Features, | ||
string label, | ||
string features, | ||
string weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
|
@@ -71,7 +71,7 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica | |
/// <summary> | ||
/// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the <see cref="FastTreeRankingTrainer"/>. | ||
/// </summary> | ||
/// <param name="ctx">The <see cref="RegressionContext"/>.</param> | ||
/// <param name="ctx">The <see cref="RankingContext"/>.</param> | ||
/// <param name="label">The label column.</param> | ||
/// <param name="features">The features column.</param> | ||
/// <param name="groupId">The groupId column.</param> | ||
|
@@ -82,9 +82,9 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica | |
/// <param name="learningRate">The learning rate.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string groupId = DefaultColumnNames.GroupId, | ||
string features = DefaultColumnNames.Features, | ||
string label, | ||
string features, | ||
string groupId , | ||
string weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
|
@@ -100,21 +100,17 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer | |
/// <summary> | ||
/// Predict a target using a generalized additive binary classification model trained with the <see cref="BinaryClassificationGamTrainer"/>. | ||
/// </summary> | ||
/// <param name="ctx">The <see cref="RegressionContext"/>.</param> | ||
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param> | ||
/// <param name="label">The label column.</param> | ||
/// <param name="features">The features column.</param> | ||
/// <param name="weights">The optional weights column.</param> | ||
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param> | ||
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param> | ||
/// <param name="minDatapointsInLeafs">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data.</param> | ||
/// <param name="learningRate">The learning rate.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx, | ||
public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string features = DefaultColumnNames.Features, | ||
string weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, | ||
double learningRate = Defaults.LearningRates, | ||
Action<BinaryClassificationGamTrainer.Arguments> advancedSettings = null) | ||
|
@@ -125,23 +121,19 @@ public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this Reg | |
} | ||
|
||
/// <summary> | ||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>. | ||
/// Predict a target using a generalized additive regression model trained with the <see cref="RegressionGamTrainer"/>. | ||
/// </summary> | ||
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param> | ||
/// <param name="ctx">The <see cref="RegressionContext"/>.</param> | ||
/// <param name="label">The label column.</param> | ||
/// <param name="features">The features column.</param> | ||
/// <param name="weights">The optional weights column.</param> | ||
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param> | ||
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param> | ||
/// <param name="minDatapointsInLeafs">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data.</param> | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I sense inconsistent naming. #Resolved There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. MinDocsPerLeaf, MinDocumentsPerLeaf, MinDatapointsInLeafs. Also, it should be 'leaves', not 'leafs'. In reply to: 231232309 [](ancestors = 231232309,231226770) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh, lol. Apparently it's been like this for a long time. In reply to: 231367379 [](ancestors = 231367379,231232309,231226770) |
||
/// <param name="learningRate">The learning rate.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string features = DefaultColumnNames.Features, | ||
public static RegressionGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx, | ||
string label, | ||
string features, | ||
string weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, | ||
double learningRate = Defaults.LearningRates, | ||
Action<RegressionGamTrainer.Arguments> advancedSettings = null) | ||
|
@@ -164,8 +156,8 @@ public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassif | |
/// <param name="learningRate">The learning rate.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string features = DefaultColumnNames.Features, | ||
string label, | ||
string features, | ||
string weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
|
@@ -177,5 +169,59 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr | |
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new FastTreeTweedieTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); | ||
} | ||
|
||
/// <summary> | ||
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestRegression"/>. | ||
/// Extension method on the regression trainers catalog that creates and returns the trainer. | ||
/// </summary> | ||
/// <param name="ctx">The trainer catalog of the <see cref="RegressionContext"/>; validated with <c>Contracts.CheckValue</c> and used to obtain the environment.</param> | ||
/// <param name="label">The label column. Required; this overload supplies no default name.</param> | ||
/// <param name="features">The features column. Required; this overload supplies no default name.</param> | ||
/// <param name="weights">The optional weights column; defaults to <c>null</c>.</param> | ||
/// <param name="numLeaves">The maximum number of leaves per decision tree; defaults to <c>Defaults.NumLeaves</c>.</param> | ||
/// <param name="numTrees">Total number of decision trees to create in the ensemble; defaults to <c>Defaults.NumTrees</c>.</param> | ||
/// <param name="minDatapointsInLeafs">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data; defaults to <c>Defaults.MinDocumentsInLeafs</c>.</param> | ||
/// <param name="learningRate">The learning rate; defaults to <c>Defaults.LearningRates</c>.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings; an optional callback over <see cref="FastForestRegression.Arguments"/>, passed through to the trainer constructor.</param> | ||
/// <returns>A new <see cref="FastForestRegression"/> constructed from the supplied arguments.</returns> | ||
public static FastForestRegression FastForest(this RegressionContext.RegressionTrainers ctx, | ||
string label, | ||
string features, | ||
string weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, | ||
double learningRate = Defaults.LearningRates, | ||
Action<FastForestRegression.Arguments> advancedSettings = null) | ||
{ | ||
Contracts.CheckValue(ctx, nameof(ctx)); | ||
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new FastForestRegression(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); | ||
} | ||
|
||
/// <summary> | ||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastForestClassification"/>. | ||
/// Extension method on the binary classification trainers catalog that creates and returns the trainer. | ||
/// </summary> | ||
/// <param name="ctx">The trainer catalog of the <see cref="BinaryClassificationContext"/>; validated with <c>Contracts.CheckValue</c> and used to obtain the environment.</param> | ||
/// <param name="label">The label column. Required; this overload supplies no default name.</param> | ||
/// <param name="features">The features column. Required; this overload supplies no default name.</param> | ||
/// <param name="weights">The optional weights column; defaults to <c>null</c>.</param> | ||
/// <param name="numLeaves">The maximum number of leaves per decision tree; defaults to <c>Defaults.NumLeaves</c>.</param> | ||
/// <param name="numTrees">Total number of decision trees to create in the ensemble; defaults to <c>Defaults.NumTrees</c>.</param> | ||
/// <param name="minDatapointsInLeafs">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data; defaults to <c>Defaults.MinDocumentsInLeafs</c>.</param> | ||
/// <param name="learningRate">The learning rate; defaults to <c>Defaults.LearningRates</c>.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings; an optional callback over <see cref="FastForestClassification.Arguments"/>, passed through to the trainer constructor.</param> | ||
/// <returns>A new <see cref="FastForestClassification"/> constructed from the supplied arguments.</returns> | ||
public static FastForestClassification FastForest(this BinaryClassificationContext.BinaryClassificationTrainers ctx, | ||
string label, | ||
string features, | ||
string weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, | ||
double learningRate = Defaults.LearningRates, | ||
Action<FastForestClassification.Arguments> advancedSettings = null) | ||
{ | ||
Contracts.CheckValue(ctx, nameof(ctx)); | ||
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new FastForestClassification(env, label, features, weights,numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have seen in other files that features comes before label (SymSgdClassificationTrainer, for example). We should pick one order and stick to it; otherwise it is very confusing. I would suggest features, label, optional weight, but either way could work! #Resolved
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I think you already made the decision: label, features, optional weight, which is consistent with all other files in this PR. Maybe we should consider opening an issue and updating the order of the arguments in other files in another PR?
In reply to: 230870480 [](ancestors = 230870480)