-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Trainer estimator cleanup for FastTrees and LightGBM #1352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
054ed9c
c41eb6c
67c33ae
a5840b2
571b2bb
57f20a7
ad51d71
036f733
e03ea66
cea3b67
577c9c4
2c6d446
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -362,9 +362,14 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) | |
/// <summary> | ||
/// The <see cref="SchemaShape.Column"/> for the label column for regression tasks. | ||
/// </summary> | ||
/// <param name="labelColumn">name of the weight column</param> | ||
public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn) | ||
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); | ||
/// <param name="columnName">name of the weight column</param> | ||
public static SchemaShape.Column MakeU4ScalarColumn(string columnName) | ||
{ | ||
if (columnName == null) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this check needed? It looks like SchemaShape.Column constructor also checks the name for null or empty string. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep, because the check inside the constructor will throw if we pass null. In reply to: 227989462 [](ancestors = 227989462) |
||
return null; | ||
|
||
return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); | ||
} | ||
|
||
/// <summary> | ||
/// The <see cref="SchemaShape.Column"/> for the feature column. | ||
|
@@ -377,9 +382,10 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn) | |
/// The <see cref="SchemaShape.Column"/> for the weight column. | ||
/// </summary> | ||
/// <param name="weightColumn">name of the weight column</param> | ||
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) | ||
/// <param name="isExplicit">whether the column is implicitly, or explicitly defined</param> | ||
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true) | ||
{ | ||
if (weightColumn == null) | ||
if (weightColumn == null || !isExplicit) | ||
return null; | ||
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ namespace Microsoft.ML | |
/// <summary> | ||
/// FastTree <see cref="TrainContextBase"/> extension methods. | ||
/// </summary> | ||
public static class FastTreeRegressionExtensions | ||
public static class TreeExtensions | ||
{ | ||
/// <summary> | ||
/// Predict a target using a decision tree regression model trained with the <see cref="FastTreeRegressionTrainer"/>. | ||
|
@@ -40,10 +40,6 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi | |
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new FastTreeRegressionTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); | ||
} | ||
} | ||
|
||
public static class FastTreeBinaryClassificationExtensions | ||
{ | ||
|
||
/// <summary> | ||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>. | ||
|
@@ -71,10 +67,6 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica | |
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new FastTreeBinaryClassificationTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); | ||
} | ||
} | ||
|
||
public static class FastTreeRankingExtensions | ||
{ | ||
|
||
/// <summary> | ||
/// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the <see cref="FastTreeRankingTrainer"/>. | ||
|
@@ -96,5 +88,62 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer | |
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings); | ||
} | ||
|
||
/// <summary> | ||
/// Predict a target using a decision tree regression model trained with the <see cref="FastTreeRegressionTrainer"/>. | ||
/// </summary> | ||
/// <param name="ctx">The <see cref="RegressionContext"/>.</param> | ||
/// <param name="label">The label column.</param> | ||
/// <param name="features">The features colum.</param> | ||
/// <param name="weights">The optional weights column.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string features = DefaultColumnNames.Features, | ||
string weights = null, | ||
Action<BinaryClassificationGamTrainer.Arguments> advancedSettings = null) | ||
{ | ||
Contracts.CheckValue(ctx, nameof(ctx)); | ||
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new BinaryClassificationGamTrainer(env, label, features, weights, advancedSettings); | ||
} | ||
|
||
/// <summary> | ||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>. | ||
/// </summary> | ||
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param> | ||
/// <param name="label">The label column.</param> | ||
/// <param name="features">The features colum.</param> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. column, same on the other function comments. #Resolved |
||
/// <param name="weights">The optional weights column.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string features = DefaultColumnNames.Features, | ||
string weights = null, | ||
Action<RegressionGamTrainer.Arguments> advancedSettings = null) | ||
{ | ||
Contracts.CheckValue(ctx, nameof(ctx)); | ||
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new RegressionGamTrainer(env, label, features, weights, advancedSettings); | ||
} | ||
|
||
/// <summary> | ||
/// Predict a target using a decision tree regression model trained with the <see cref="FastTreeTweedieTrainer"/>. | ||
/// </summary> | ||
/// <param name="ctx">The <see cref="RegressionContext"/>.</param> | ||
/// <param name="label">The label column.</param> | ||
/// <param name="features">The features colum.</param> | ||
/// <param name="weights">The optional weights column.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, | ||
string label = DefaultColumnNames.Label, | ||
string features = DefaultColumnNames.Features, | ||
string weights = null, | ||
Action<FastTreeTweedieTrainer.Arguments> advancedSettings = null) | ||
{ | ||
Contracts.CheckValue(ctx, nameof(ctx)); | ||
var env = CatalogUtils.GetEnvironment(ctx); | ||
return new FastTreeTweedieTrainer(env, label, features, weights, advancedSettings: advancedSettings); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ namespace Microsoft.ML.StaticPipe | |
/// <summary> | ||
/// FastTree <see cref="TrainContextBase"/> extension methods. | ||
/// </summary> | ||
public static class FastTreeRegressionExtensions | ||
public static class TreeRegressionExtensions | ||
{ | ||
/// <summary> | ||
/// FastTree <see cref="RegressionContext"/> extension method. | ||
|
@@ -50,7 +50,7 @@ public static Scalar<float> FastTree(this RegressionContext.RegressionTrainers c | |
Action<FastTreeRegressionTrainer.Arguments> advancedSettings = null, | ||
Action<FastTreeRegressionPredictor> onFit = null) | ||
{ | ||
FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); | ||
CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); | ||
|
||
var rec = new TrainerEstimatorReconciler.Regression( | ||
(env, labelName, featuresName, weightsName) => | ||
|
@@ -64,10 +64,6 @@ public static Scalar<float> FastTree(this RegressionContext.RegressionTrainers c | |
|
||
return rec.Score; | ||
} | ||
} | ||
|
||
public static class FastTreeBinaryClassificationExtensions | ||
{ | ||
|
||
/// <summary> | ||
/// FastTree <see cref="BinaryClassificationContext"/> extension method. | ||
|
@@ -98,7 +94,7 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred | |
Action<FastTreeBinaryClassificationTrainer.Arguments> advancedSettings = null, | ||
Action<IPredictorWithFeatureWeights<float>> onFit = null) | ||
{ | ||
FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); | ||
CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); | ||
|
||
var rec = new TrainerEstimatorReconciler.BinaryClassifier( | ||
(env, labelName, featuresName, weightsName) => | ||
|
@@ -114,10 +110,6 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred | |
|
||
return rec.Output; | ||
} | ||
} | ||
|
||
public static class FastTreeRankingExtensions | ||
{ | ||
|
||
/// <summary> | ||
/// FastTree <see cref="RankingContext"/>. | ||
|
@@ -148,7 +140,7 @@ public static Scalar<float> FastTree<TVal>(this RankingContext.RankingTrainers c | |
Action<FastTreeRankingTrainer.Arguments> advancedSettings = null, | ||
Action<FastTreeRankingPredictor> onFit = null) | ||
{ | ||
FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); | ||
CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); | ||
|
||
var rec = new TrainerEstimatorReconciler.Ranker<TVal>( | ||
(env, labelName, featuresName, groupIdName, weightsName) => | ||
|
@@ -161,17 +153,14 @@ public static Scalar<float> FastTree<TVal>(this RankingContext.RankingTrainers c | |
|
||
return rec.Score; | ||
} | ||
} | ||
|
||
internal class FastTreeStaticsUtils | ||
{ | ||
internal static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indent is off here... #Resolved |
||
int numLeaves, | ||
int numTrees, | ||
int minDatapointsInLeafs, | ||
double learningRate, | ||
Delegate advancedSettings, | ||
Delegate onFit) | ||
int numLeaves, | ||
int numTrees, | ||
int minDatapointsInLeafs, | ||
double learningRate, | ||
Delegate advancedSettings, | ||
Delegate onFit) | ||
{ | ||
Contracts.CheckValue(label, nameof(label)); | ||
Contracts.CheckValue(features, nameof(features)); | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure I like the idea that every estimator will have the group ID column, but it only makes sense for ranking trainers. #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The alternative would be to modify the role mapped data after the estimator constructs it, and inject the groupId, and that feels error-prone. It's just one reference :)
In reply to: 227866406 [](ancestors = 227866406)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So we have a
TrainerEstimatorBase
, is there anything that prevents us from having another subclass, if we wanted to have additional roles? We may have to change something below to be a virtual method or have some other extension point, but I feel like this would be helpful anyway.I am not really a fan of having a single utility base class handle everything, since if it is capable of doing everything I suspect it will become too unwieldly.
In reply to: 227895594 [](ancestors = 227895594,227866406)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
idk what i like less; the GroupId field here, or making MakeRoles protected virtual ... changing it anyways.
In reply to: 228312151 [](ancestors = 228312151,227895594,227866406)