Skip to content

LightGbm extensions #1020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,16 @@ public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerIn
checkArgColName(DefaultColumnNames.Features, args.FeatureColumn);
checkArgColName(DefaultColumnNames.Weight, args.WeightColumn);

if(args.GroupIdColumn != null)
if (args.GroupIdColumn != null)
checkArgColName(DefaultColumnNames.GroupId, args.GroupIdColumn);
}

/// <summary>
/// If, after applying the advancedArgs delegate, the args are different than the default value
/// and are also different than the value supplied directly to the extension method, warn the user.
/// </summary>
/// <param name="channel">The channel used to surface the warning.</param>
/// <param name="methodParam">The value the user supplied directly to the method.</param>
/// <param name="defaultVal">The default value of the argument.</param>
/// <param name="setting">The value of the argument after the advanced-settings delegate has run.</param>
/// <param name="argName">The argument name, echoed in the warning message.</param>
public static void CheckArgsAndAdvancedSettingMismatch<T>(IChannel channel, T methodParam, T defaultVal, T setting, string argName)
{
    // Null-safe static Equals: a null 'setting' must not throw, it simply compares equal to null.
    if (!Equals(setting, defaultVal) && !Equals(setting, methodParam))
        channel.Warning($"The value supplied to advanced settings is different than the value supplied directly. Using value {setting} for {argName}");
}
Expand Down
127 changes: 127 additions & 0 deletions src/Microsoft.ML.LightGBM/LightGBMStatics.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.LightGBM;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.StaticPipe;
using Microsoft.ML.StaticPipe.Runtime;
using System;

namespace Microsoft.ML.Trainers
{
/// <summary>
/// LightGbm <see cref="TrainContextBase"/> extension methods.
/// </summary>
public static class LightGbmStatics
{
/// <summary>
/// LightGbm <see cref="RegressionContext"/> extension method.
/// </summary>
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="weights">The weights column.</param>
/// <param name="numLeaves">The number of leaves to use.</param>
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
/// <param name="onFit">A delegate that is called every time the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
/// the linear model that was trained. Note that this action cannot change the result in any way;
/// it is only a way for the caller to be informed about what was learnt.</param>
/// <returns>The Score output column indicating the predicted value.</returns>
public static Scalar<float> LightGbm(this RegressionContext.RegressionTrainers ctx,
    Scalar<float> label, Vector<float> features, Scalar<float> weights = null,
    int? numLeaves = null,
    int? minDataPerLeaf = null,
    double? learningRate = null,
    int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
    Action<LightGbmArguments> advancedSettings = null,
    Action<LightGbmRegressionPredictor> onFit = null)
{
    // Validate the directly supplied values before building the estimator.
    CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit);

    var rec = new TrainerEstimatorReconciler.Regression(
        (env, labelName, featuresName, weightsName) =>
        {
            var trainer = new LightGbmRegressorTrainer(env, labelName, featuresName, weightsName, numLeaves,
                minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
            if (onFit != null)
                return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
            return trainer;
        }, label, features, weights);

    return rec.Score;
}

/// <summary>
/// LightGbm <see cref="BinaryClassificationContext"/> extension method.
/// </summary>
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="weights">The weights column.</param>
/// <param name="numLeaves">The number of leaves to use.</param>
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
/// <param name="onFit">A delegate that is called every time the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
/// the linear model that was trained. Note that this action cannot change the result in any way;
/// it is only a way for the caller to be informed about what was learnt.</param>
/// <returns>The set of output columns including in order the predicted binary classification score (which will range
/// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) LightGbm(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
    Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
    int? numLeaves = null,
    int? minDataPerLeaf = null,
    double? learningRate = null,
    int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
    Action<LightGbmArguments> advancedSettings = null,
    Action<IPredictorWithFeatureWeights<float>> onFit = null)
{
    CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit);

    var reconciler = new TrainerEstimatorReconciler.BinaryClassifier(
        (env, labelName, featuresName, weightsName) =>
        {
            var trainer = new LightGbmBinaryTrainer(env, labelName, featuresName, weightsName, numLeaves,
                minDataPerLeaf, learningRate, numBoostRound, advancedSettings);

            // Wire up the onFit notification only when the caller asked for it.
            if (onFit == null)
                return trainer;
            return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
        }, label, features, weights);

    return reconciler.Output;
}

/// <summary>
/// Validates the values the user supplied directly to the extension methods,
/// throwing for out-of-range hyperparameters and missing mandatory columns.
/// </summary>
private static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights,
    int? numLeaves,
    int? minDataPerLeaf,
    double? learningRate,
    int numBoostRound,
    Delegate advancedSettings,
    Delegate onFit)
{
    Contracts.CheckValue(label, nameof(label));
    Contracts.CheckValue(features, nameof(features));
    Contracts.CheckValueOrNull(weights);
    // A null optional hyperparameter means "use the component default", so null is always accepted.
    Contracts.CheckParam(numLeaves == null || numLeaves >= 2, nameof(numLeaves), "Must be at least 2.");
    Contracts.CheckParam(minDataPerLeaf == null || minDataPerLeaf > 0, nameof(minDataPerLeaf), "Must be positive");
    Contracts.CheckParam(learningRate == null || learningRate > 0, nameof(learningRate), "Must be positive");
    Contracts.CheckParam(numBoostRound > 0, nameof(numBoostRound), "Must be positive");
    Contracts.CheckValueOrNull(advancedSettings);
    Contracts.CheckValueOrNull(onFit);
}
}
}
7 changes: 6 additions & 1 deletion src/Microsoft.ML.LightGBM/LightGbmArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ private static string GetArgName(string name)
return strBuf.ToString();
}

/// <summary>
/// Default values shared between the component arguments and the entry points
/// that reference them (e.g. as default parameter values).
/// </summary>
internal static class Defaults
{
// Default number of boosting iterations.
internal const int NumBoostRound = 100;
}

public sealed class TreeBooster : BoosterParameter<TreeBooster.Arguments>
{
public const string Name = "gbdt";
Expand Down Expand Up @@ -268,7 +273,7 @@ public enum EvalMetricType
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations.", SortOrder = 1, ShortName = "iter")]
[TGUI(Label = "Number of boosting iterations", SuggestedSweeps = "10,20,50,100,150,200")]
[TlcModule.SweepableDiscreteParam("NumBoostRound", new object[] { 10, 20, 50, 100, 150, 200 })]
public int NumBoostRound = 100;
public int NumBoostRound = Defaults.NumBoostRound;

[Argument(ArgumentType.AtMostOnce,
HelpText = "Shrinkage rate for trees, used to prevent over-fitting. Range: (0,1].",
Expand Down
23 changes: 20 additions & 3 deletions src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,32 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="numLeaves">The number of leaves to use.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
    string weightColumn = null,
    int? numLeaves = null,
    int? minDataPerLeaf = null,
    double? learningRate = null,
    int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
    Action<LightGbmArguments> advancedSettings = null)
    : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

    // Warn if the advanced-settings delegate set values that conflict with the ones supplied directly.
    if (advancedSettings != null)
        CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args);

    // Override with the directly provided values; null means "keep the current value".
    Args.NumBoostRound = numBoostRound;
    Args.NumLeaves = numLeaves ?? Args.NumLeaves;
    Args.LearningRate = learningRate ?? Args.LearningRate;
    Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf;
}

private protected override IPredictorWithFeatureWeights<float> CreatePredictor()
Expand Down
23 changes: 20 additions & 3 deletions src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,32 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase<float, Regres
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="numLeaves">The number of leaves to use.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
    string weightColumn = null,
    int? numLeaves = null,
    int? minDataPerLeaf = null,
    double? learningRate = null,
    int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
    Action<LightGbmArguments> advancedSettings = null)
    : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

    // Warn if the advanced-settings delegate set values that conflict with the ones supplied directly.
    if (advancedSettings != null)
        CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args);

    // Override with the directly provided values; null means "keep the current value".
    Args.NumBoostRound = numBoostRound;
    Args.NumLeaves = numLeaves ?? Args.NumLeaves;
    Args.LearningRate = learningRate ?? Args.LearningRate;
    Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf;
}

internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args)
Expand Down
32 changes: 32 additions & 0 deletions src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaS

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);

// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args);

Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

Expand Down Expand Up @@ -158,6 +162,34 @@ protected virtual void CheckDataValid(IChannel ch, RoleMappedData data)
ch.CheckParam(data.Schema.Label != null, nameof(data), "Need a label column");
}

/// <summary>
/// If, after applying the advancedSettings delegate, the args are different than the default value
/// and are also different than the value supplied directly to the extension method, warn the user
/// about which value is being used.
/// The parameters that appear here — numLeaves, minDataPerLeaf, learningRate, numBoostRound — are the ones the users are most likely to tune.
/// This list should follow the one in the constructor, and the extension methods on the <see cref="TrainContextBase"/>.
/// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
/// </summary>
/// <param name="snapshot">A freshly constructed <see cref="LightGbmArguments"/>, used as the source of default values.</param>
/// <param name="currentArgs">The arguments after the advancedSettings delegate has been applied.</param>
protected void CheckArgsAndAdvancedSettingMismatch(int? numLeaves,
int? minDataPerLeaf,
double? learningRate,
int numBoostRound,
LightGbmArguments snapshot,
LightGbmArguments currentArgs)
{
using (var ch = Host.Start("Comparing advanced settings with the directly provided values."))
{

// Check that the user didn't supply different parameters in the args, from what it specified directly.
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numBoostRound, snapshot.NumBoostRound, currentArgs.NumBoostRound, nameof(numBoostRound));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDataPerLeaf, snapshot.MinDataPerLeaf, currentArgs.MinDataPerLeaf, nameof(minDataPerLeaf));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRate, currentArgs.LearningRate, nameof(learningRate));

ch.Done();
}
}

protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg=false)
{
double learningRate = Args.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Analyzer\Microsoft.ML.Analyzer.csproj">
Expand Down
Loading