Trainer estimator cleanup for FastTrees and LightGBM #1352

Merged
merged 12 commits into from
Oct 27, 2018
15 changes: 13 additions & 2 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
@@ -41,6 +41,11 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
/// </summary>
public readonly SchemaShape.Column WeightColumn;

/// <summary>
/// The optional groupID column that the ranking trainers expects.
/// </summary>
public readonly SchemaShape.Column GroupIdColumn;

Contributor

@Zruty0 Zruty0 Oct 24, 2018

GroupIdColumn

I am not sure I like the idea that every estimator will have the group ID column, but it only makes sense for ranking trainers. #Closed

Member Author

The alternative would be to modify the role mapped data after the estimator constructs it, and inject the groupId, and that feels error-prone. It's just one reference :)


Contributor

So we have a TrainerEstimatorBase, is there anything that prevents us from having another subclass, if we wanted to have additional roles? We may have to change something below to be a virtual method or have some other extension point, but I feel like this would be helpful anyway.

I am not really a fan of having a single utility base class handle everything, since if it is capable of doing everything I suspect it will become too unwieldy.


Member Author

idk what i like less; the GroupId field here, or making MakeRoles protected virtual ... changing it anyways.
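
For readers following this thread, here is a minimal sketch of the extension point being weighed: the base class builds the common roles in a `protected virtual` method, and a ranking subclass adds the group ID. Every type and member name below (`RoleColumn`, `EstimatorBaseSketch`, `RankingEstimatorSketch`, the dictionary of roles) is a hypothetical stand-in for illustration, not the ML.NET API or the code that was finally merged.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for a schema column; not the real SchemaShape.Column.
public sealed class RoleColumn
{
    public string Name { get; }

    public RoleColumn(string name) =>
        Name = name ?? throw new ArgumentNullException(nameof(name));
}

// Hypothetical base class that only knows about feature, label and weight.
public abstract class EstimatorBaseSketch
{
    protected readonly RoleColumn FeatureColumn;
    protected readonly RoleColumn LabelColumn;
    protected readonly RoleColumn WeightColumn;

    protected EstimatorBaseSketch(RoleColumn feature, RoleColumn label, RoleColumn weight = null)
    {
        FeatureColumn = feature ?? throw new ArgumentNullException(nameof(feature));
        LabelColumn = label;
        WeightColumn = weight;
    }

    // Extension point: subclasses can add roles the base class does not know about.
    protected virtual Dictionary<string, string> MakeRoles()
    {
        var roles = new Dictionary<string, string> { ["Feature"] = FeatureColumn.Name };
        if (LabelColumn != null)
            roles["Label"] = LabelColumn.Name;
        if (WeightColumn != null)
            roles["Weight"] = WeightColumn.Name;
        return roles;
    }

    // Exposes the role mapping built by the (possibly overridden) MakeRoles.
    public Dictionary<string, string> Roles => MakeRoles();
}

// Hypothetical ranking estimator: the only place that carries a group ID.
public sealed class RankingEstimatorSketch : EstimatorBaseSketch
{
    private readonly RoleColumn _groupId;

    public RankingEstimatorSketch(RoleColumn feature, RoleColumn label,
        RoleColumn weight = null, RoleColumn groupId = null)
        : base(feature, label, weight)
    {
        _groupId = groupId;
    }

    protected override Dictionary<string, string> MakeRoles()
    {
        var roles = base.MakeRoles();
        if (_groupId != null)
            roles["Group"] = _groupId.Name;
        return roles;
    }
}
```

The trade-off in this shape is the one raised above: the base class stays free of ranking-specific state, at the cost of exposing one more overridable member.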



protected readonly IHost Host;

/// <summary>
@@ -50,17 +55,23 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim

public abstract PredictionKind PredictionKind { get; }

public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null)
public TrainerEstimatorBase(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
SchemaShape.Column weight = null,
SchemaShape.Column groupId = null)
{
Contracts.CheckValue(host, nameof(host));
Host = host;
Host.CheckValue(feature, nameof(feature));
Host.CheckValueOrNull(label);
Host.CheckValueOrNull(weight);
Host.CheckValueOrNull(groupId);

FeatureColumn = feature;
LabelColumn = label;
WeightColumn = weight;
GroupIdColumn = groupId;
}

public TTransformer Fit(IDataView input) => TrainTransformer(input);
@@ -150,7 +161,7 @@ protected TTransformer TrainTransformer(IDataView trainSet,
protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);

private RoleMappedData MakeRoles(IDataView data) =>
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name);

IPredictor ITrainer.Train(TrainContext context) => Train(context);
}
16 changes: 11 additions & 5 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
@@ -362,9 +362,14 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn)
/// <summary>
/// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
/// </summary>
/// <param name="labelColumn">name of the weight column</param>
public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn)
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
/// <param name="columnName">name of the weight column</param>
public static SchemaShape.Column MakeU4ScalarColumn(string columnName)
{
if (columnName == null)

Member

@singlis singlis Oct 24, 2018

Is this check needed? It looks like SchemaShape.Column constructor also checks the name for null or empty string. #Resolved

Member Author

yep, because the check inside the constructor will throw if we pass null.
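
To make the answer concrete, here is a small sketch of the pattern, assuming (as both comments state) that the column constructor rejects a null or empty name. `ColumnSketch` and `MakeOptionalColumn` are hypothetical names used only for illustration; they are not ML.NET types.

```csharp
using System;

// Hypothetical stand-in for a column type whose constructor rejects a
// null or empty name, as the thread says SchemaShape.Column does.
public sealed class ColumnSketch
{
    public string Name { get; }

    public ColumnSketch(string name)
    {
        if (string.IsNullOrEmpty(name))
            throw new ArgumentException("Column name must be non-empty.", nameof(name));
        Name = name;
    }
}

public static class ColumnFactorySketch
{
    // For an optional column, null means "not specified", so the factory
    // returns null instead of letting the constructor throw.
    public static ColumnSketch MakeOptionalColumn(string columnName)
    {
        if (columnName == null)
            return null;

        return new ColumnSketch(columnName);
    }
}
```

Without the guard, an omitted optional column would surface as an exception at construction time rather than as a missing column.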


return null;

return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
}

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the feature column.
@@ -377,9 +382,10 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
/// The <see cref="SchemaShape.Column"/> for the weight column.
/// </summary>
/// <param name="weightColumn">name of the weight column</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
/// <param name="isExplicit">whether the column is implicitly, or explicitly defined</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true)
{
if (weightColumn == null)
if (weightColumn == null || !isExplicit)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.FastTree/FastTree.cs
@@ -94,7 +94,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
Args = new TArgs();

@@ -128,7 +128,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
{
Host.CheckValue(args, nameof(args));
Args = args;
8 changes: 4 additions & 4 deletions src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -35,10 +35,10 @@ namespace Microsoft.ML.Trainers.FastTree
public sealed partial class FastTreeTweedieTrainer
: BoostingFastTreeTrainerBase<FastTreeTweedieTrainer.Arguments, RegressionPredictionTransformer<FastTreeTweediePredictor>, FastTreeTweediePredictor>
{
public const string LoadNameValue = "FastTreeTweedieRegression";
public const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression";
public const string Summary = "Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression.";
public const string ShortName = "fttweedie";
internal const string LoadNameValue = "FastTreeTweedieRegression";
internal const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression";
internal const string Summary = "Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression.";
internal const string ShortName = "fttweedie";

private TestHistory _firstTestSetHistory;
private Test _trainRegressionTest;
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -154,7 +154,7 @@ private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape.

private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
{
Contracts.CheckValue(env, nameof(env));
Host.CheckValue(args, nameof(args));
@@ -12,7 +12,7 @@ namespace Microsoft.ML
/// <summary>
/// FastTree <see cref="TrainContextBase"/> extension methods.
/// </summary>
public static class FastTreeRegressionExtensions
public static class TreeExtensions
{
/// <summary>
/// Predict a target using a decision tree regression model trained with the <see cref="FastTreeRegressionTrainer"/>.
@@ -40,10 +40,6 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi
var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeRegressionTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings);
}
}

public static class FastTreeBinaryClassificationExtensions
{

/// <summary>
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
@@ -71,10 +67,6 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica
var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeBinaryClassificationTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings);
}
}

public static class FastTreeRankingExtensions
{

/// <summary>
/// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the <see cref="FastTreeRankingTrainer"/>.
@@ -96,5 +88,62 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer
var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings);
}

/// <summary>
/// Predict a target using a decision tree regression model trained with the <see cref="FastTreeRegressionTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features colum.</param>
/// <param name="weights">The optional weights column.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx,
string label = DefaultColumnNames.Label,
string features = DefaultColumnNames.Features,
string weights = null,
Action<BinaryClassificationGamTrainer.Arguments> advancedSettings = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
return new BinaryClassificationGamTrainer(env, label, features, weights, advancedSettings);
}

/// <summary>
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features colum.</param>

Member

@singlis singlis Oct 24, 2018

Typo: "colum" should be "column"; same in the other function comments. #Resolved

/// <param name="weights">The optional weights column.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string label = DefaultColumnNames.Label,
string features = DefaultColumnNames.Features,
string weights = null,
Action<RegressionGamTrainer.Arguments> advancedSettings = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
return new RegressionGamTrainer(env, label, features, weights, advancedSettings);
}

/// <summary>
/// Predict a target using a decision tree regression model trained with the <see cref="FastTreeTweedieTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features colum.</param>
/// <param name="weights">The optional weights column.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx,
string label = DefaultColumnNames.Label,
string features = DefaultColumnNames.Features,
string weights = null,
Action<FastTreeTweedieTrainer.Arguments> advancedSettings = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeTweedieTrainer(env, label, features, weights, advancedSettings: advancedSettings);
}
}
}
@@ -14,7 +14,7 @@ namespace Microsoft.ML.StaticPipe
/// <summary>
/// FastTree <see cref="TrainContextBase"/> extension methods.
/// </summary>
public static class FastTreeRegressionExtensions
public static class TreeRegressionExtensions
{
/// <summary>
/// FastTree <see cref="RegressionContext"/> extension method.
@@ -50,7 +50,7 @@ public static Scalar<float> FastTree(this RegressionContext.RegressionTrainers c
Action<FastTreeRegressionTrainer.Arguments> advancedSettings = null,
Action<FastTreeRegressionPredictor> onFit = null)
{
FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit);
CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit);

var rec = new TrainerEstimatorReconciler.Regression(
(env, labelName, featuresName, weightsName) =>
@@ -64,10 +64,6 @@ public static Scalar<float> FastTree(this RegressionContext.RegressionTrainers c

return rec.Score;
}
}

public static class FastTreeBinaryClassificationExtensions
{

/// <summary>
/// FastTree <see cref="BinaryClassificationContext"/> extension method.
@@ -98,7 +94,7 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
Action<FastTreeBinaryClassificationTrainer.Arguments> advancedSettings = null,
Action<IPredictorWithFeatureWeights<float>> onFit = null)
{
FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit);
CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit);

var rec = new TrainerEstimatorReconciler.BinaryClassifier(
(env, labelName, featuresName, weightsName) =>
@@ -114,10 +110,6 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred

return rec.Output;
}
}

public static class FastTreeRankingExtensions
{

/// <summary>
/// FastTree <see cref="RankingContext"/>.
@@ -148,7 +140,7 @@ public static Scalar<float> FastTree<TVal>(this RankingContext.RankingTrainers c
Action<FastTreeRankingTrainer.Arguments> advancedSettings = null,
Action<FastTreeRankingPredictor> onFit = null)
{
FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit);
CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit);

var rec = new TrainerEstimatorReconciler.Ranker<TVal>(
(env, labelName, featuresName, groupIdName, weightsName) =>
@@ -161,17 +153,14 @@ public static Scalar<float> FastTree<TVal>(this RankingContext.RankingTrainers c

return rec.Score;
}
}

internal class FastTreeStaticsUtils
{
internal static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights,

Member

@singlis singlis Oct 27, 2018

indent is off here... #Resolved

int numLeaves,
int numTrees,
int minDatapointsInLeafs,
double learningRate,
Delegate advancedSettings,
Delegate onFit)
int numLeaves,
int numTrees,
int minDatapointsInLeafs,
double learningRate,
Delegate advancedSettings,
Delegate onFit)
{
Contracts.CheckValue(label, nameof(label));
Contracts.CheckValue(features, nameof(features));
2 changes: 1 addition & 1 deletion src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
@@ -85,7 +85,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, st
/// </summary>
internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
{
Host.CheckValue(args, nameof(args));
Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative");