Skip to content

Modify API for advanced settings (LightGBM) #2261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 185 additions & 25 deletions src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions src/Microsoft.ML.LightGBM/LightGbmArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.LightGBM;

[assembly: LoadableClass(typeof(LightGbmArguments.TreeBooster), typeof(LightGbmArguments.TreeBooster.Arguments),
typeof(SignatureLightGBMBooster), LightGbmArguments.TreeBooster.FriendlyName, LightGbmArguments.TreeBooster.Name)]
[assembly: LoadableClass(typeof(LightGbmArguments.DartBooster), typeof(LightGbmArguments.DartBooster.Arguments),
typeof(SignatureLightGBMBooster), LightGbmArguments.DartBooster.FriendlyName, LightGbmArguments.DartBooster.Name)]
[assembly: LoadableClass(typeof(LightGbmArguments.GossBooster), typeof(LightGbmArguments.GossBooster.Arguments),
typeof(SignatureLightGBMBooster), LightGbmArguments.GossBooster.FriendlyName, LightGbmArguments.GossBooster.Name)]
[assembly: LoadableClass(typeof(Options.TreeBooster), typeof(Options.TreeBooster.Arguments),
typeof(SignatureLightGBMBooster), Options.TreeBooster.FriendlyName, Options.TreeBooster.Name)]
[assembly: LoadableClass(typeof(Options.DartBooster), typeof(Options.DartBooster.Arguments),
typeof(SignatureLightGBMBooster), Options.DartBooster.FriendlyName, Options.DartBooster.Name)]
[assembly: LoadableClass(typeof(Options.GossBooster), typeof(Options.GossBooster.Arguments),
typeof(SignatureLightGBMBooster), Options.GossBooster.FriendlyName, Options.GossBooster.Name)]

[assembly: EntryPointModule(typeof(LightGbmArguments.TreeBooster.Arguments))]
[assembly: EntryPointModule(typeof(LightGbmArguments.DartBooster.Arguments))]
[assembly: EntryPointModule(typeof(LightGbmArguments.GossBooster.Arguments))]
[assembly: EntryPointModule(typeof(Options.TreeBooster.Arguments))]
[assembly: EntryPointModule(typeof(Options.DartBooster.Arguments))]
[assembly: EntryPointModule(typeof(Options.GossBooster.Arguments))]

namespace Microsoft.ML.LightGBM
{
Expand All @@ -39,7 +39,7 @@ public interface IBoosterParameter
/// Parameter names come from the LightGBM library.
/// See https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst.
/// </summary>
public sealed class LightGbmArguments : LearnerInputBaseWithGroupId
public sealed class Options : LearnerInputBaseWithGroupId
{
public abstract class BoosterParameter<TArgs> : IBoosterParameter
where TArgs : class, new()
Expand Down
21 changes: 8 additions & 13 deletions src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;

[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(LightGbmArguments),
[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) },
LightGbmBinaryTrainer.UserName, LightGbmBinaryTrainer.LoadNameValue, LightGbmBinaryTrainer.ShortName, DocName = "trainer/LightGBM.md")]

Expand Down Expand Up @@ -95,8 +95,8 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase<float, BinaryPre

public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
: base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
internal LightGbmBinaryTrainer(IHostEnvironment env, Options options)
: base(env, LoadNameValue, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
{
}

Expand All @@ -111,20 +111,15 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public LightGbmBinaryTrainer(IHostEnvironment env,
internal LightGbmBinaryTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
Copy link
Contributor

@artidoro artidoro Jan 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not possible to combine the two constructors?
Or have you added this work to the reconciliation issue #2100? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, exactly. Added to #2100.

I want to focus on cleaning up the public surface first.


In reply to: 251541952 [](ancestors = 251541952)

string featureColumn = DefaultColumnNames.Features,
string weights = null,
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
Action<LightGbmArguments> advancedSettings = null)
: base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings)
int numBoostRound = LightGBM.Options.Defaults.NumBoostRound)
: base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound)
{
}

Expand Down Expand Up @@ -186,14 +181,14 @@ public static partial class LightGbm
ShortName = LightGbmBinaryTrainer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.LightGBM/doc.xml' path='doc/members/member[@name=""LightGBM""]/*' />",
@"<include file='../Microsoft.ML.LightGBM/doc.xml' path='doc/members/example[@name=""LightGbmBinaryClassifier""]/*' />"})]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, LightGbmArguments input)
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainLightGBM");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<LightGbmArguments, CommonOutputs.BinaryClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new LightGbmBinaryTrainer(host, input),
getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
Expand Down
91 changes: 60 additions & 31 deletions src/Microsoft.ML.LightGBM/LightGbmCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,31 @@ public static class LightGbmExtensions
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
Action<LightGbmArguments> advancedSettings = null)
int numBoostRound = Options.Defaults.NumBoostRound)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmRegressorTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
return new LightGbmRegressorTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound);
}

/// <summary>
/// Predict a target using a decision tree regression model trained with the <see cref="LightGbmRegressorTrainer"/>.
/// </summary>
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
/// <param name="options">Advanced options to the algorithm. Must not be <see langword="null"/>.</param>
/// <returns>A <see cref="LightGbmRegressorTrainer"/> constructed from <paramref name="options"/>.</returns>
public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
    Options options)
{
    Contracts.CheckValue(catalog, nameof(catalog));
    // Validate options at the public entry point, the same way catalog is validated,
    // instead of handing a possibly-null object to the trainer constructor.
    Contracts.CheckValue(options, nameof(options));
    var env = CatalogUtils.GetEnvironment(catalog);
    return new LightGbmRegressorTrainer(env, options);
}

/// <summary>
Expand All @@ -54,28 +62,35 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.Regressio
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
Action<LightGbmArguments> advancedSettings = null)
int numBoostRound = Options.Defaults.NumBoostRound)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmBinaryTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
return new LightGbmBinaryTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound);
}

/// <summary>
/// Predict a target using a decision tree binary classification model trained with the <see cref="LightGbmBinaryTrainer"/>.
/// </summary>
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
/// <param name="options">Advanced options to the algorithm. Must not be <see langword="null"/>.</param>
/// <returns>A <see cref="LightGbmBinaryTrainer"/> constructed from <paramref name="options"/>.</returns>
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
    Options options)
{
    Contracts.CheckValue(catalog, nameof(catalog));
    // Reject a null options object here: the trainer constructor reads options.LabelColumn
    // while building its base-constructor arguments, so a null would otherwise be
    // dereferenced there before any contract check runs.
    Contracts.CheckValue(options, nameof(options));
    var env = CatalogUtils.GetEnvironment(catalog);
    return new LightGbmBinaryTrainer(env, options);
}

/// <summary>
Copy link
Contributor

@rogancarr rogancarr Jan 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

binary classification => ranking #Resolved

/// Predict a target using a decision tree binary classification model trained with the <see cref="LightGbmRankingTrainer"/>.
/// Predict a target using a decision tree ranking model trained with the <see cref="LightGbmRankingTrainer"/>.
/// </summary>
/// <param name="catalog">The <see cref="RankingCatalog"/>.</param>
/// <param name="labelColumn">The label column.</param>
Expand All @@ -86,10 +101,6 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
Expand All @@ -98,44 +109,62 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
Action<LightGbmArguments> advancedSettings = null)
int numBoostRound = Options.Defaults.NumBoostRound)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmRankingTrainer(env, labelColumn, featureColumn, groupIdColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings);

return new LightGbmRankingTrainer(env, labelColumn, featureColumn, groupIdColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound);
Copy link
Contributor

@rogancarr rogancarr Jan 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some pretty long lines in this file. Over the permitted length? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see any warnings related to long lines.

I am actually reducing the length by dropping advancedSettings from the line :)


In reply to: 251535698 [](ancestors = 251535698)

}

/// <summary>
/// Predict a target using a decision tree binary classification model trained with the <see cref="LightGbmRankingTrainer"/>.
/// Predict a target using a decision tree ranking model trained with the <see cref="LightGbmRankingTrainer"/>.
/// </summary>
/// <param name="catalog">The <see cref="RankingCatalog"/>.</param>
/// <param name="options">Advanced options to the algorithm.</param>
public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog,
    Options options)
{
    Contracts.CheckValue(catalog, nameof(catalog));
    // Validate options at the public entry point, the same way catalog is validated,
    // instead of handing a possibly-null object to the trainer constructor.
    Contracts.CheckValue(options, nameof(options));
    var env = CatalogUtils.GetEnvironment(catalog);
    return new LightGbmRankingTrainer(env, options);
}

/// <summary>
/// Predict a target using a decision tree multiclass classification model trained with the <see cref="LightGbmMulticlassTrainer"/>.
/// </summary>
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
/// <param name="labelColumn">The label column.</param>
/// <param name="featureColumn">The features column.</param>
/// <param name="weights">The weights column.</param>
/// <param name="numLeaves">The number of leaves to use.</param>
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
Action<LightGbmArguments> advancedSettings = null)
int numBoostRound = Options.Defaults.NumBoostRound)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmMulticlassTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
return new LightGbmMulticlassTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound);
}

/// <summary>
/// Predict a target using a decision tree multiclass classification model trained with the <see cref="LightGbmMulticlassTrainer"/>.
/// </summary>
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
/// <param name="options">Advanced options to the algorithm. Must not be <see langword="null"/>.</param>
/// <returns>A <see cref="LightGbmMulticlassTrainer"/> constructed from <paramref name="options"/>.</returns>
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
    Options options)
{
    Contracts.CheckValue(catalog, nameof(catalog));
    // Validate options at the public entry point, the same way catalog is validated,
    // instead of handing a possibly-null object to the trainer constructor.
    Contracts.CheckValue(options, nameof(options));
    var env = CatalogUtils.GetEnvironment(catalog);
    return new LightGbmMulticlassTrainer(env, options);
}
}
}
Loading