LightGbm pigstensions #1020

Merged (5 commits) on Sep 26, 2018
Changes from 1 commit
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
@@ -382,6 +382,16 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

/// <summary>
/// Check that the label, feature, and weights columns are not supplied in the args of the constructor.
@TomFinley (Contributor) commented on Sep 25, 2018:

args [](start = 74, length = 4)

Maybe ought to write out arguments. #Resolved

/// Those parameters should be internal if they are not used from the maml help code path.
@TomFinley (Contributor) commented on Sep 25, 2018:

maml help code path [](start = 78, length = 19)

Given the focus we now have I'm a little wary of having language specific to the command line in any new code, that is not somehow directly related to the command line, without plenty of explanatory text. Someone unaware of this software's roots as a tool rather than a library will have absolutely no idea what this means. #Resolved

/// </summary>
public static void CheckArgsDefaultColNames(IHost env, string defaultColName, string argValue)
@TomFinley (Contributor) commented on Sep 25, 2018:

IHost env [](start = 52, length = 9)

IExceptionContext please. #Resolved

{
if (argValue != defaultColName)
throw env.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead instead.");
@TomFinley (Contributor) commented on Sep 25, 2018:

instead instead [](start = 172, length = 15)

duplicate duplicate #Resolved

}
}

/// <summary>
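A minimal sketch of how CheckArgsDefaultColNames might look with the review feedback above applied (IExceptionContext in place of IHost, "arguments" written out, and the duplicated "instead" dropped). This illustrates the requested changes and is not necessarily the exact code that was merged:

```csharp
/// <summary>
/// Check that the label, feature, and weight columns are not supplied in the arguments of the constructor.
/// </summary>
public static void CheckArgsDefaultColNames(IExceptionContext ectx, string defaultColName, string argValue)
{
    if (argValue != defaultColName)
        throw ectx.Except($"Don't supply a value for the {defaultColName} column in the arguments, " +
            "as it will be ignored. Specify it in the loader or in the constructor instead.");
}
```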
86 changes: 86 additions & 0 deletions src/Microsoft.ML.LightGBM/LightGBMStatics.cs
@@ -0,0 +1,86 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.LightGBM;
using Microsoft.ML.Runtime.Training;
using System;

namespace Microsoft.ML.Trainers
{
public static class LightGbmStatics
{
public static Scalar<float> LightGbm(this RegressionContext.RegressionTrainers ctx,
@Zruty0 (Contributor) commented on Sep 25, 2018:

copy the summary comment here #Resolved

Scalar<float> label, Vector<float> features, Scalar<float> weights = null,
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = 100,
Action<LightGbmArguments> advancedSettings = null,
Action<LightGbmRegressionPredictor> onFit = null)
{
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit);

var rec = new TrainerEstimatorReconciler.Regression(
(env, labelName, featuresName, weightsName) =>
{
var trainer = new LightGbmRegressorTrainer(env, labelName, featuresName, weightsName, numLeaves,
minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
return trainer;
}, label, features, weights);

return rec.Score;
}

public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) LightGbm(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = 100,
Action<LightGbmArguments> advancedSettings = null,
Action<IPredictorWithFeatureWeights<float>> onFit = null)
{
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit);

var rec = new TrainerEstimatorReconciler.BinaryClassifier(
(env, labelName, featuresName, weightsName) =>
{
var trainer = new LightGbmBinaryTrainer(env, labelName, featuresName, weightsName, numLeaves,
minDataPerLeaf, learningRate, numBoostRound, advancedSettings);

if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
else
return trainer;
}, label, features, weights);

return rec.Output;
}

private static void CheckUserValues<TVal, TArgs, TPred>(Scalar<TVal> label, Vector<float> features, Scalar<float> weights,
@Zruty0 (Contributor) commented on Sep 25, 2018:

we are just checking the presence of stuff, right? We could make label a 'PipelineColumn', and the delegates to be just 'Delegate', and this method will not need to be generic #Resolved

int? numLeaves,
int? minDataPerLeaf,
double? learningRate,
int numBoostRound,
Action<TArgs> advancedSettings,
Action<TPred> onFit)
{
Contracts.CheckValue(label, nameof(label));
Contracts.CheckValue(features, nameof(features));
Contracts.CheckValueOrNull(weights);
Contracts.CheckParam(!(numLeaves < 2), nameof(numLeaves), "Must be at least 2.");
Contracts.CheckParam(!(minDataPerLeaf <= 0), nameof(minDataPerLeaf), "Must be positive");
Contracts.CheckParam(!(learningRate <= 0), nameof(learningRate), "Must be positive");
Contracts.CheckParam(numBoostRound > 0, nameof(numBoostRound), "Must be positive");
Contracts.CheckValueOrNull(advancedSettings);
Contracts.CheckValueOrNull(onFit);
}
}
}
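Following the suggestion above that CheckUserValues only verifies presence, a possible non-generic shape of the helper (a sketch, assuming PipelineColumn is the common base of Scalar<> and Vector<>; the merged code may differ):

```csharp
private static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights,
    int? numLeaves, int? minDataPerLeaf, double? learningRate, int numBoostRound,
    Delegate advancedSettings, Delegate onFit)
{
    // Only presence and simple ranges are validated; the delegates are never invoked here.
    Contracts.CheckValue(label, nameof(label));
    Contracts.CheckValue(features, nameof(features));
    Contracts.CheckValueOrNull(weights);
    Contracts.CheckParam(!(numLeaves < 2), nameof(numLeaves), "Must be at least 2.");
    Contracts.CheckParam(!(minDataPerLeaf <= 0), nameof(minDataPerLeaf), "Must be positive");
    Contracts.CheckParam(!(learningRate <= 0), nameof(learningRate), "Must be positive");
    Contracts.CheckParam(numBoostRound > 0, nameof(numBoostRound), "Must be positive");
    Contracts.CheckValueOrNull(advancedSettings);
    Contracts.CheckValueOrNull(onFit);
}
```

The existing call sites would not need to change: Scalar<bool> and Scalar<float> convert implicitly to PipelineColumn, and Action<T> to Delegate.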
19 changes: 16 additions & 3 deletions src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
@@ -103,15 +103,28 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="groupIdColumn">The name for the column containing the group ID. </param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="learningRate"></param>
/// <param name="minDataPerLeaf"></param>
/// <param name="numBoostRound"></param>
/// <param name="numLeaves"></param>
public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
string groupIdColumn = null, string weightColumn = null, Action<LightGbmArguments> advancedSettings = null)
: base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
string weightColumn = null,
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = 100,
Action<LightGbmArguments> advancedSettings = null)
: base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

// override with the directly provided values
Args.NumBoostRound = numBoostRound;
Args.NumLeaves = numLeaves ?? Args.NumLeaves;
Args.LearningRate = learningRate ?? Args.LearningRate;
}

private protected override IPredictorWithFeatureWeights<float> CreatePredictor()
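A hypothetical usage sketch of the reworked constructor (the env variable and column names here are placeholders): hyperparameters passed directly take precedence over the LightGbmArguments defaults. Note that in this commit only numBoostRound, numLeaves, and learningRate are copied into Args; minDataPerLeaf is accepted but not yet applied.

```csharp
var trainer = new LightGbmBinaryTrainer(env, "Label", "Features",
    numLeaves: 20,        // overrides Args.NumLeaves
    learningRate: 0.2,    // overrides Args.LearningRate
    numBoostRound: 50);   // overrides Args.NumBoostRound
```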
19 changes: 16 additions & 3 deletions src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
@@ -89,15 +89,28 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase<float, Regres
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="groupIdColumn">The name for the column containing the group ID. </param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="learningRate"></param>
@Zruty0 (Contributor) commented on Sep 25, 2018:

add docs on empty parameters #Resolved

/// <param name="minDataPerLeaf"></param>
/// <param name="numBoostRound"></param>
/// <param name="numLeaves"></param>
public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
string groupIdColumn = null, string weightColumn = null, Action<LightGbmArguments> advancedSettings = null)
: base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
string weightColumn = null,
int? numLeaves = null,
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = 100,
Action<LightGbmArguments> advancedSettings = null)
: base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

// override with the directly provided values
Args.NumBoostRound = numBoostRound;
Args.NumLeaves = numLeaves ?? Args.NumLeaves;
Args.LearningRate = learningRate ?? Args.LearningRate;
}

internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args)
7 changes: 7 additions & 0 deletions src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
@@ -64,6 +64,13 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaS

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);

// Check that the user didn't specify different label, group ID, feature, or weight columns in the arguments from those supplied directly.
TrainerUtils.CheckArgsDefaultColNames(Host, DefaultColumnNames.Label, Args.LabelColumn);
TrainerUtils.CheckArgsDefaultColNames(Host, DefaultColumnNames.Features, Args.FeatureColumn);
TrainerUtils.CheckArgsDefaultColNames(Host, DefaultColumnNames.GroupId, Args.GroupIdColumn);
TrainerUtils.CheckArgsDefaultColNames(Host, DefaultColumnNames.Weight, Args.WeightColumn);

Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

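To make the new guard concrete, a sketch of the case it rejects (the env variable and column names are hypothetical): supplying a non-default column name through advancedSettings, rather than through the loader or the constructor, now throws during construction.

```csharp
// Throws: advancedSettings sets Args.LabelColumn to a non-default value,
// so CheckArgsDefaultColNames(Host, DefaultColumnNames.Label, Args.LabelColumn) fails.
var trainer = new LightGbmRegressorTrainer(env, "Label", "Features",
    advancedSettings: args => args.LabelColumn = "SomeOtherLabel");
```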
@@ -5,6 +5,7 @@
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
<NativeAssemblyReference Include="FactorizationMachineNative" />
84 changes: 84 additions & 0 deletions test/Microsoft.ML.StaticPipelineTesting/Training.cs
@@ -6,7 +6,9 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.FactorizationMachine;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.LightGBM;
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Trainers;
@@ -260,5 +262,87 @@ public void SdcaMulticlass()
Assert.True(metrics.LogLoss > 0);
Assert.True(metrics.TopKAccuracy > 0);
}

[Fact]
public void LightGbmBinaryClassification()
{
var env = new ConsoleEnvironment(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);

var reader = TextLoader.CreateReader(env,
c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9)));

IPredictorWithFeatureWeights<float> pred = null;

var est = reader.MakeNewEstimator()
.Append(r => (r.label, preds: ctx.Trainers.LightGbm(r.label, r.features,
numBoostRound: 10,
numLeaves: 5,
learningRate: 0.01,
onFit: (p) => { pred = p; })));

var pipe = reader.Append(est);

Assert.Null(pred);
var model = pipe.Fit(dataSource);
Assert.NotNull(pred);

// 9 input features, so we ought to have 9 weights.
VBuffer<float> weights = new VBuffer<float>();
pred.GetFeatureWeights(ref weights);
Assert.Equal(9, weights.Length);

var data = model.Read(dataSource);

var metrics = ctx.Evaluate(data, r => r.label, r => r.preds);
// Run a sanity check against a few of the metrics.
Assert.InRange(metrics.Accuracy, 0, 1);
Assert.InRange(metrics.Auc, 0, 1);
Assert.InRange(metrics.Auprc, 0, 1);
}

[Fact]
public void LightGbmRegression()
{
var env = new ConsoleEnvironment(seed: 0);
var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataSource = new MultiFileSource(dataPath);

var ctx = new RegressionContext(env);

var reader = TextLoader.CreateReader(env,
c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
separator: ';', hasHeader: true);

LightGbmRegressionPredictor pred = null;

var est = reader.MakeNewEstimator()
.Append(r => (r.label, score: ctx.Trainers.LightGbm(r.label, r.features,
numBoostRound: 10,
numLeaves: 5,
onFit: (p) => { pred = p; })));

var pipe = reader.Append(est);

Assert.Null(pred);
var model = pipe.Fit(dataSource);
Assert.NotNull(pred);
// 11 input features, so we ought to have 11 weights.
VBuffer<float> weights = new VBuffer<float>();
pred.GetFeatureWeights(ref weights);
Assert.Equal(11, weights.Length);

var data = model.Read(dataSource);

var metrics = ctx.Evaluate(data, r => r.label, r => r.score, new PoissonLoss());
// Run a sanity check against a few of the metrics.
Assert.InRange(metrics.L1, 0, double.PositiveInfinity);
Assert.InRange(metrics.L2, 0, double.PositiveInfinity);
Assert.InRange(metrics.Rms, 0, double.PositiveInfinity);
Assert.Equal(metrics.Rms * metrics.Rms, metrics.L2, 5);
Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity);
}
}
}
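A quick note on the final regression assertions: in these metrics L2 is the mean squared error and Rms its square root, so up to the five-digit tolerance the test allows,

$$\mathrm{Rms} = \sqrt{\tfrac{1}{n}\sum_{i=1}^{n}\bigl(y_i - \hat{y}_i\bigr)^2} = \sqrt{\mathrm{L2}} \quad\Longrightarrow\quad \mathrm{Rms}^2 = \mathrm{L2}.$$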