-
Notifications
You must be signed in to change notification settings - Fork 1.9k
LightGbm pigstensions #1020
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LightGbm pigstensions #1020
Changes from 1 commit
29896c9
6d6f65c
2f3e8df
1a9c83a
c11b0ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -382,6 +382,16 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) | |
return null; | ||
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); | ||
} | ||
|
||
/// <summary> | ||
/// Checks that the label, feature, and weight column names are not supplied through the arguments object passed to the constructor. | ||
/// Those arguments should be internal if they are not used from the maml (command-line) help code path. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Given the focus we now have I'm a little wary of having language specific to the command line in any new code, that is not somehow directly related to the command line, without plenty of explanatory text. Someone unaware of this software's roots as a tool rather than a library will have absolutely no idea what this means. #Resolved |
||
/// </summary> | ||
public static void CheckArgsDefaultColNames(IHost env, string defaultColName, string argValue)
{
    // A value that differs from the default column name means the caller set it through
    // the arguments object. Per the message below it would be ignored, so fail loudly
    // instead of letting the setting be dropped silently.
    if (argValue != defaultColName)
        throw env.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead.");
}
} | ||
|
||
/// <summary> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Data.StaticPipe.Runtime; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Internal.Internallearn; | ||
using Microsoft.ML.Runtime.LightGBM; | ||
using Microsoft.ML.Runtime.Training; | ||
using System; | ||
|
||
namespace Microsoft.ML.Trainers | ||
{ | ||
/// <summary>
/// LightGBM extension methods for the static pipeline (StaticPipe) trainer catalogs.
/// </summary>
public static class LightGbmStatics
{
    /// <summary>
    /// Creates a LightGBM regression trainer for a static pipeline and returns its score column.
    /// </summary>
    /// <param name="ctx">The regression trainer catalog this method extends.</param>
    /// <param name="label">The label column. Must not be null.</param>
    /// <param name="features">The features column. Must not be null.</param>
    /// <param name="weights">The optional example-weights column.</param>
    /// <param name="numLeaves">The number of leaves. Must be at least 2 when specified.</param>
    /// <param name="minDataPerLeaf">The minimum amount of data per leaf. Must be positive when specified.</param>
    /// <param name="learningRate">The learning rate. Must be positive when specified.</param>
    /// <param name="numBoostRound">The number of boosting rounds. Must be positive.</param>
    /// <param name="advancedSettings">An optional delegate that is handed to the trainer to set advanced arguments.</param>
    /// <param name="onFit">An optional delegate invoked with the trained model once the estimator has been fit.</param>
    /// <returns>The score column produced by the trainer.</returns>
    public static Scalar<float> LightGbm(this RegressionContext.RegressionTrainers ctx,
        Scalar<float> label, Vector<float> features, Scalar<float> weights = null,
        int? numLeaves = null,
        int? minDataPerLeaf = null,
        double? learningRate = null,
        int numBoostRound = 100,
        Action<LightGbmArguments> advancedSettings = null,
        Action<LightGbmRegressionPredictor> onFit = null)
    {
        CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit);

        var rec = new TrainerEstimatorReconciler.Regression(
            (env, labelName, featuresName, weightsName) =>
            {
                var trainer = new LightGbmRegressorTrainer(env, labelName, featuresName, weightsName, numLeaves,
                    minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
                if (onFit != null)
                    return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                return trainer;
            }, label, features, weights);

        return rec.Score;
    }

    /// <summary>
    /// Creates a LightGBM binary classification trainer for a static pipeline and returns its output columns.
    /// </summary>
    /// <param name="ctx">The binary classification trainer catalog this method extends.</param>
    /// <param name="label">The label column. Must not be null.</param>
    /// <param name="features">The features column. Must not be null.</param>
    /// <param name="weights">The optional example-weights column.</param>
    /// <param name="numLeaves">The number of leaves. Must be at least 2 when specified.</param>
    /// <param name="minDataPerLeaf">The minimum amount of data per leaf. Must be positive when specified.</param>
    /// <param name="learningRate">The learning rate. Must be positive when specified.</param>
    /// <param name="numBoostRound">The number of boosting rounds. Must be positive.</param>
    /// <param name="advancedSettings">An optional delegate that is handed to the trainer to set advanced arguments.</param>
    /// <param name="onFit">An optional delegate invoked with the trained model once the estimator has been fit.</param>
    /// <returns>The score, probability, and predicted-label columns produced by the trainer.</returns>
    public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) LightGbm(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
        Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
        int? numLeaves = null,
        int? minDataPerLeaf = null,
        double? learningRate = null,
        int numBoostRound = 100,
        Action<LightGbmArguments> advancedSettings = null,
        Action<IPredictorWithFeatureWeights<float>> onFit = null)
    {
        CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit);

        var rec = new TrainerEstimatorReconciler.BinaryClassifier(
            (env, labelName, featuresName, weightsName) =>
            {
                var trainer = new LightGbmBinaryTrainer(env, labelName, featuresName, weightsName, numLeaves,
                    minDataPerLeaf, learningRate, numBoostRound, advancedSettings);

                if (onFit != null)
                    return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                else
                    return trainer;
            }, label, features, weights);

        return rec.Output;
    }

    /// <summary>
    /// Misnamed alias of the binary classification <c>LightGbm</c> overload: despite the name, it builds a
    /// <see cref="LightGbmBinaryTrainer"/>. Kept only so existing callers keep compiling; prefer <c>LightGbm</c>.
    /// </summary>
    public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
        Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
        int? numLeaves = null,
        int? minDataPerLeaf = null,
        double? learningRate = null,
        int numBoostRound = 100,
        Action<LightGbmArguments> advancedSettings = null,
        Action<IPredictorWithFeatureWeights<float>> onFit = null)
    {
        return LightGbm(ctx, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit);
    }

    /// <summary>
    /// Validates the user-supplied values shared by the LightGBM extension methods: required columns are
    /// present and the numeric settings, when specified, are within range.
    /// </summary>
    // NOTE(review): this helper only checks presence and ranges. A reviewer suggested taking the label as a
    // 'PipelineColumn' and the delegates as plain 'Delegate' so the method would not need to be generic.
    private static void CheckUserValues<TVal, TArgs, TPred>(Scalar<TVal> label, Vector<float> features, Scalar<float> weights,
        int? numLeaves,
        int? minDataPerLeaf,
        double? learningRate,
        int numBoostRound,
        Action<TArgs> advancedSettings,
        Action<TPred> onFit)
    {
        Contracts.CheckValue(label, nameof(label));
        Contracts.CheckValue(features, nameof(features));
        Contracts.CheckValueOrNull(weights);
        // The negated comparisons let a null (unspecified) value pass: a lifted comparison
        // against null is false, so the negation is true.
        Contracts.CheckParam(!(numLeaves < 2), nameof(numLeaves), "Must be at least 2.");
        Contracts.CheckParam(!(minDataPerLeaf <= 0), nameof(minDataPerLeaf), "Must be positive");
        Contracts.CheckParam(!(learningRate <= 0), nameof(learningRate), "Must be positive");
        Contracts.CheckParam(numBoostRound > 0, nameof(numBoostRound), "Must be positive");
        Contracts.CheckValueOrNull(advancedSettings);
        Contracts.CheckValueOrNull(onFit);
    }
}
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,15 +89,28 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase<float, Regres | |
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param> | ||
/// <param name="labelColumn">The name of the label column.</param> | ||
/// <param name="featureColumn">The name of the feature column.</param> | ||
/// <param name="groupIdColumn">The name for the column containing the group ID. </param> | ||
/// <param name="weightColumn">The name for the column containing the initial weight.</param> | ||
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param> | ||
/// <param name="learningRate">The learning rate. When null, the value already held by the trainer arguments is kept.</param> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add docs on empty parameters #Resolved |
||
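/// <param name="minDataPerLeaf">The minimum amount of data per leaf.</param> | ||
/// <param name="numBoostRound">The number of boosting rounds. Always overrides the value held by the trainer arguments.</param> | ||
/// <param name="numLeaves">The number of leaves. When null, the value already held by the trainer arguments is kept.</param> | ||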
public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
    string weightColumn = null,
    int? numLeaves = null,
    int? minDataPerLeaf = null,
    double? learningRate = null,
    int numBoostRound = 100,
    Action<LightGbmArguments> advancedSettings = null)
    // No group-id column is exposed by this constructor, so that base argument is passed as null.
    : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

    // Override with the directly provided values. A null nullable argument leaves the
    // value already held by Args untouched; numBoostRound is always applied.
    Args.NumBoostRound = numBoostRound;
    Args.NumLeaves = numLeaves ?? Args.NumLeaves;
    // Previously accepted but never applied, so a caller-supplied minDataPerLeaf was silently ignored.
    Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf;
    Args.LearningRate = learningRate ?? Args.LearningRate;
}
|
||
internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe ought to write out `arguments`. #Resolved