From 92055255af8bee76d812f440c1310ec8f2ee734b Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 28 Jan 2019 03:42:30 +0000 Subject: [PATCH 1/4] lightgbm tests work fine --- .../LightGbmStaticExtensions.cs | 202 ++++++++++++++++-- .../LightGbmBinaryTrainer.cs | 11 +- src/Microsoft.ML.LightGBM/LightGbmCatalog.cs | 85 +++++--- .../LightGbmMulticlassTrainer.cs | 11 +- .../LightGbmRankingTrainer.cs | 11 +- .../LightGbmRegressionTrainer.cs | 11 +- .../LightGbmTrainerBase.cs | 6 +- .../TrainerEstimators/TreeEstimators.cs | 60 +++--- 8 files changed, 287 insertions(+), 110 deletions(-) diff --git a/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs b/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs index f81387d828..c7e6f5ec44 100644 --- a/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs +++ b/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; @@ -26,7 +28,6 @@ public static class LightGbmStaticExtensions /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -45,16 +46,52 @@ public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers c int? minDataPerLeaf = null, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null, Action onFit = null) { - CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { var trainer = new LightGbmRegressorTrainer(env, labelName, featuresName, weightsName, numLeaves, - minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + minDataPerLeaf, learningRate, numBoostRound); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Score; + } + + /// + /// Predict a target using a tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The weights column. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The Score output column indicating the predicted value. + public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers catalog, + Scalar label, Vector features, Scalar weights, + LightGbmArguments advancedSettings, + Action onFit = null) + { + CheckUserValues(label, features, weights, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.Regression( + (env, labelName, featuresName, weightsName) => + { + advancedSettings.LabelColumn = labelName; + advancedSettings.FeatureColumn = featuresName; + advancedSettings.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + + var trainer = new LightGbmRegressorTrainer(env, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -74,7 +111,6 @@ public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers c /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -94,16 +130,55 @@ public static (Scalar score, Scalar probability, Scalar pred int? minDataPerLeaf = null, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null, Action> onFit = null) { - CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { var trainer = new LightGbmBinaryTrainer(env, labelName, featuresName, weightsName, numLeaves, - minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + minDataPerLeaf, learningRate, numBoostRound); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + else + return trainer; + }, label, features, weights); + + return rec.Output; + } + + /// + /// Predict a target using a tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The weights column. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted binary classification score (which will range + /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. + public static (Scalar score, Scalar probability, Scalar predictedLabel) LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, + Scalar label, Vector features, Scalar weights, + LightGbmArguments advancedSettings, + Action> onFit = null) + { + CheckUserValues(label, features, weights, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.BinaryClassifier( + (env, labelName, featuresName, weightsName) => + { + advancedSettings.LabelColumn = labelName; + advancedSettings.FeatureColumn = featuresName; + advancedSettings.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + + var trainer = new LightGbmBinaryTrainer(env, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -126,7 +201,6 @@ public static (Scalar score, Scalar probability, Scalar pred /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -140,17 +214,58 @@ public static Scalar LightGbm(this RankingCatalog.RankingTrainers c int? minDataPerLeaf = null, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null, Action onFit = null) { - CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit); Contracts.CheckValue(groupId, nameof(groupId)); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { var trainer = new LightGbmRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, - minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + minDataPerLeaf, learningRate, numBoostRound); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, groupId, weights); + + return rec.Score; + } + + /// + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The label column. + /// The features column. + /// The groupId column. + /// The weights column. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted binary classification score (which will range + /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. + public static Scalar LightGbm(this RankingCatalog.RankingTrainers catalog, + Scalar label, Vector features, Key groupId, Scalar weights, + LightGbmArguments advancedSettings, + Action onFit = null) + { + CheckUserValues(label, features, weights, advancedSettings, onFit); + Contracts.CheckValue(groupId, nameof(groupId)); + + var rec = new TrainerEstimatorReconciler.Ranker( + (env, labelName, featuresName, groupIdName, weightsName) => + { + advancedSettings.LabelColumn = labelName; + advancedSettings.FeatureColumn = featuresName; + advancedSettings.GroupIdColumn = groupIdName; + advancedSettings.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + + var trainer = new LightGbmRankingTrainer(env, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -171,10 +286,6 @@ public static Scalar LightGbm(this RankingCatalog.RankingTrainers c /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -197,16 +308,56 @@ public static (Vector score, Key predictedLabel) int? minDataPerLeaf = null, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null, Action onFit = null) { - CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit); var rec = new TrainerEstimatorReconciler.MulticlassClassifier( (env, labelName, featuresName, weightsName) => { var trainer = new LightGbmMulticlassTrainer(env, labelName, featuresName, weightsName, numLeaves, - minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + minDataPerLeaf, learningRate, numBoostRound); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Output; + } + + /// + /// Predict a target using a tree multiclass classification model trained with the . + /// + /// The multiclass classification catalog trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The weights column. + /// Advanced options to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) + LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + Key label, + Vector features, + Scalar weights, + LightGbmArguments advancedSettings, + Action onFit = null) + { + CheckUserValues(label, features, weights, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassifier( + (env, labelName, featuresName, weightsName) => + { + advancedSettings.LabelColumn = labelName; + advancedSettings.FeatureColumn = featuresName; + advancedSettings.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + + var trainer = new LightGbmMulticlassTrainer(env, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -221,7 +372,6 @@ private static void CheckUserValues(PipelineColumn label, Vector features int? minDataPerLeaf, double? learningRate, int numBoostRound, - Delegate advancedSettings, Delegate onFit) { Contracts.CheckValue(label, nameof(label)); @@ -231,7 +381,17 @@ private static void CheckUserValues(PipelineColumn label, Vector features Contracts.CheckParam(!(minDataPerLeaf <= 0), nameof(minDataPerLeaf), "Must be positive"); Contracts.CheckParam(!(learningRate <= 0), nameof(learningRate), "Must be positive"); Contracts.CheckParam(numBoostRound > 0, nameof(numBoostRound), "Must be positive"); - Contracts.CheckValueOrNull(advancedSettings); + Contracts.CheckValueOrNull(onFit); + } + + private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, + LightGbmArguments advancedSettings, + Delegate onFit) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckValue(advancedSettings, nameof(advancedSettings)); Contracts.CheckValueOrNull(onFit); } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index f49ddc1e29..149ecccd15 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -111,20 +111,15 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . - public LightGbmBinaryTrainer(IHostEnvironment env, + internal LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weights = null, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs index 855e679ec7..5f0fa25a1a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs @@ -24,10 +24,6 @@ public static class LightGbmExtensions /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -35,12 +31,24 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.Regressio int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null) + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmRegressorTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + return new LightGbmRegressorTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// Advanced options to the algorithm. + public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog, + LightGbmArguments advancedSettings) + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + return new LightGbmRegressorTrainer(env, advancedSettings); } /// @@ -54,10 +62,6 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.Regressio /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -65,13 +69,24 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null) + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmBinaryTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + return new LightGbmBinaryTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound); + } + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// Advanced options to the algorithm. + public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, + LightGbmArguments advancedSettings) + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + return new LightGbmBinaryTrainer(env, advancedSettings); } /// @@ -86,10 +101,6 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -98,13 +109,24 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null) + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmRankingTrainer(env, labelColumn, featureColumn, groupIdColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + return new LightGbmRankingTrainer(env, labelColumn, featureColumn, groupIdColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound); + } + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// Advanced options to the algorithm. + public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog, + LightGbmArguments advancedSettings) + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + return new LightGbmRankingTrainer(env, advancedSettings); } /// @@ -118,10 +140,6 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -129,13 +147,24 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null) + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmMulticlassTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + return new LightGbmMulticlassTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound); + } + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// Advanced options to the algorithm. + public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + LightGbmArguments advancedSettings) + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + return new LightGbmMulticlassTrainer(env, advancedSettings); } } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 681562e42e..5347ece215 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -51,20 +51,15 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . - public LightGbmMulticlassTrainer(IHostEnvironment env, + internal LightGbmMulticlassTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weights = null, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound) { _numClass = -1; } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 1aeca6a725..56fd6b02cf 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -98,11 +98,7 @@ internal LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . - public LightGbmRankingTrainer(IHostEnvironment env, + internal LightGbmRankingTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string groupId = DefaultColumnNames.GroupId, @@ -110,9 +106,8 @@ public LightGbmRankingTrainer(IHostEnvironment env, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weights, groupId, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weights, groupId, numLeaves, minDataPerLeaf, learningRate, numBoostRound) { Host.CheckNonEmpty(groupId, nameof(groupId)); } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 742ce58790..f937c593f2 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -96,20 +96,15 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBaseNumber of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . - public LightGbmRegressorTrainer(IHostEnvironment env, + internal LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weights = null, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, - Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index bc02b59b1b..5ff68989e2 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -67,8 +67,7 @@ private protected LightGbmTrainerBase(IHostEnvironment env, int? numLeaves, int? minDataPerLeaf, double? learningRate, - int numBoostRound, - Action advancedSettings) + int numBoostRound) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new LightGbmArguments(); @@ -78,9 +77,6 @@ private protected LightGbmTrainerBase(IHostEnvironment env, Args.LearningRate = learningRate; Args.NumBoostRound = numBoostRound; - //apply the advanced args, if the user supplied any - advancedSettings?.Invoke(Args); - Args.LabelColumn = label.Name; Args.FeatureColumn = featureColumn; diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 341a3c1ddf..c3b11252ab 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -26,7 +26,8 @@ public void FastTreeBinaryEstimator() var (pipe, dataView) = GetBinaryClassificationPipeline(); var trainer = ML.BinaryClassification.Trainers.FastTree( - new FastTreeBinaryClassificationTrainer.Options { + new FastTreeBinaryClassificationTrainer.Options + { NumThreads = 1, NumTrees = 10, NumLeaves = 5, @@ -45,12 +46,13 @@ public void LightGBMBinaryEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new LightGbmBinaryTrainer(Env, "Label", "Features", advancedSettings: s => + var trainer = new LightGbmBinaryTrainer(Env, new LightGbmArguments { - s.NumLeaves = 10; - s.NThread = 1; - s.MinDataPerLeaf = 2; + NumLeaves = 10, + NThread = 1, + MinDataPerLeaf = 2, }); + var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -84,8 +86,9 @@ public void FastForestClassificationEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = ML.BinaryClassification.Trainers.FastForest( - new FastForestClassification.Options { + var trainer = ML.BinaryClassification.Trainers.FastForest( + new FastForestClassification.Options + { NumLeaves = 10, NumTrees = 20, }); @@ -107,7 +110,8 @@ public void FastTreeRankerEstimator() var (pipe, dataView) = GetRankingPipeline(); var trainer = ML.Ranking.Trainers.FastTree( - new FastTreeRankingTrainer.Options { + new FastTreeRankingTrainer.Options + { FeatureColumn = "NumericFeatures", NumTrees = 10 }); @@ -128,8 +132,8 @@ public void LightGBMRankerEstimator() { var (pipe, dataView) = GetRankingPipeline(); - var trainer = new LightGbmRankingTrainer(Env, "Label0", "NumericFeatures", "Group", - advancedSettings: s => { s.LearningRate = 0.4; }); + var trainer = new LightGbmRankingTrainer(Env, labelColumn: "Label0", featureColumn: "NumericFeatures", groupId: "Group", learningRate: 0.4); + var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -160,11 +164,11 @@ public void FastTreeRegressorEstimator() public void LightGBMRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new LightGbmRegressorTrainer(Env, "Label", "Features", advancedSettings: s => + var trainer = new LightGbmRegressorTrainer(Env, new LightGbmArguments { - s.NThread = 1; - s.NormalizeFeatures = NormalizeOption.Warn; - s.CatL2 = 5; + NThread = 1, + NormalizeFeatures = NormalizeOption.Warn, + CatL2 = 5, }); TestEstimatorCore(trainer, dataView); @@ -198,8 +202,9 @@ public void GAMRegressorEstimator() public void TweedieRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = ML.Regression.Trainers.FastTreeTweedie( - new FastTreeTweedieTrainer.Options { + var trainer = ML.Regression.Trainers.FastTreeTweedie( + new FastTreeTweedieTrainer.Options + { EntropyCoefficient = 0.3, OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent, }); @@ -216,8 +221,9 @@ public void TweedieRegressorEstimator() public void FastForestRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = ML.Regression.Trainers.FastForest( - new FastForestRegression.Options { + var trainer = ML.Regression.Trainers.FastForest( + new FastForestRegression.Options + { BaggingSize = 2, NumTrees = 10, }); @@ -234,7 +240,7 @@ public void FastForestRegressorEstimator() public void LightGbmMultiClassEstimator() { var (pipeline, dataView) = GetMultiClassPipeline(); - var trainer = new LightGbmMulticlassTrainer(Env, "Label", "Features", advancedSettings: s => { s.LearningRate = 0.4; }); + var trainer = new LightGbmMulticlassTrainer(Env, learningRate: 0.4); var pipe = pipeline.Append(trainer) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); TestEstimatorCore(pipe, dataView); @@ -251,7 +257,7 @@ private class GbmExample { [VectorType(_columnNumber)] public float[] Features; - [KeyType(Count =_classNumber)] + [KeyType(Count = _classNumber)] public uint Label; [VectorType(_classNumber)] public float[] Score; @@ -284,8 +290,14 @@ private void LightGbmHelper(bool useSoftmax, out string modelString, out List { s.MinDataPerGroup = 1; s.MinDataPerLeaf = 1; s.UseSoftmax = useSoftmax; }); + var gbmTrainer = new LightGbmMulticlassTrainer(mlContext, new LightGbmArguments + { + NumBoostRound = numberOfTrainingIterations, + MinDataPerGroup = 1, + MinDataPerLeaf = 1, + UseSoftmax = useSoftmax + }); + var gbm = gbmTrainer.Fit(dataView); var predicted = gbm.Transform(dataView); mlnetPredictions = mlContext.CreateEnumerable(predicted, false).ToList(); @@ -337,7 +349,7 @@ private void LightGbmHelper(bool useSoftmax, out string modelString, out List Date: Mon, 28 Jan 2019 03:57:34 +0000 Subject: [PATCH 2/4] Options renaming --- .../LightGbmStaticExtensions.cs | 70 +++++++++---------- .../LightGbmArguments.cs | 20 +++--- .../LightGbmBinaryTrainer.cs | 12 ++-- src/Microsoft.ML.LightGBM/LightGbmCatalog.cs | 32 ++++----- .../LightGbmMulticlassTrainer.cs | 12 ++-- .../LightGbmRankingTrainer.cs | 12 ++-- .../LightGbmRegressionTrainer.cs | 12 ++-- .../LightGbmTrainerBase.cs | 12 ++-- .../Common/EntryPoints/core_ep-list.tsv | 8 +-- .../TestPredictors.cs | 2 +- .../TrainerEstimators/TreeEstimators.cs | 6 +- 11 files changed, 99 insertions(+), 99 deletions(-) diff --git a/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs b/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs index c7e6f5ec44..e2223c4788 100644 --- a/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs +++ b/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs @@ -45,7 +45,7 @@ public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers c int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + int numBoostRound = Options.Defaults.NumBoostRound, Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit); @@ -70,7 +70,7 @@ public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers c /// The label column. /// The features column. /// The weights column. - /// Algorithm advanced settings. + /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -79,19 +79,19 @@ public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers c /// The Score output column indicating the predicted value. public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers catalog, Scalar label, Vector features, Scalar weights, - LightGbmArguments advancedSettings, + Options options, Action onFit = null) { - CheckUserValues(label, features, weights, advancedSettings, onFit); + CheckUserValues(label, features, weights, options, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { - advancedSettings.LabelColumn = labelName; - advancedSettings.FeatureColumn = featuresName; - advancedSettings.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + options.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); - var trainer = new LightGbmRegressorTrainer(env, advancedSettings); + var trainer = new LightGbmRegressorTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -129,7 +129,7 @@ public static (Scalar score, Scalar probability, Scalar pred int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + int numBoostRound = Options.Defaults.NumBoostRound, Action> onFit = null) { CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit); @@ -156,7 +156,7 @@ public static (Scalar score, Scalar probability, Scalar pred /// The label column. /// The features column. /// The weights column. - /// Algorithm advanced settings. + /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -166,19 +166,19 @@ public static (Scalar score, Scalar probability, Scalar pred /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. public static (Scalar score, Scalar probability, Scalar predictedLabel) LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, Scalar label, Vector features, Scalar weights, - LightGbmArguments advancedSettings, + Options options, Action> onFit = null) { - CheckUserValues(label, features, weights, advancedSettings, onFit); + CheckUserValues(label, features, weights, options, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { - advancedSettings.LabelColumn = labelName; - advancedSettings.FeatureColumn = featuresName; - advancedSettings.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + options.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); - var trainer = new LightGbmBinaryTrainer(env, advancedSettings); + var trainer = new LightGbmBinaryTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -213,7 +213,7 @@ public static Scalar LightGbm(this RankingCatalog.RankingTrainers c int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + int numBoostRound = Options.Defaults.NumBoostRound, Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit); @@ -241,7 +241,7 @@ public static Scalar LightGbm(this RankingCatalog.RankingTrainers c /// The features column. /// The groupId column. /// The weights column. - /// Algorithm advanced settings. + /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -251,21 +251,21 @@ public static Scalar LightGbm(this RankingCatalog.RankingTrainers c /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. public static Scalar LightGbm(this RankingCatalog.RankingTrainers catalog, Scalar label, Vector features, Key groupId, Scalar weights, - LightGbmArguments advancedSettings, + Options options, Action onFit = null) { - CheckUserValues(label, features, weights, advancedSettings, onFit); + CheckUserValues(label, features, weights, options, onFit); Contracts.CheckValue(groupId, nameof(groupId)); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { - advancedSettings.LabelColumn = labelName; - advancedSettings.FeatureColumn = featuresName; - advancedSettings.GroupIdColumn = groupIdName; - advancedSettings.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + options.GroupIdColumn = groupIdName; + options.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); - var trainer = new LightGbmRankingTrainer(env, advancedSettings); + var trainer = new LightGbmRankingTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -307,7 +307,7 @@ public static (Vector score, Key predictedLabel) int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + int numBoostRound = Options.Defaults.NumBoostRound, Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit); @@ -333,7 +333,7 @@ public static (Vector score, Key predictedLabel) /// The label, or dependent variable. /// The features, or independent variables. /// The weights column. - /// Advanced options to the algorithm. + /// Advanced options to the algorithm. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -345,19 +345,19 @@ public static (Vector score, Key predictedLabel) Key label, Vector features, Scalar weights, - LightGbmArguments advancedSettings, + Options options, Action onFit = null) { - CheckUserValues(label, features, weights, advancedSettings, onFit); + CheckUserValues(label, features, weights, options, onFit); var rec = new TrainerEstimatorReconciler.MulticlassClassifier( (env, labelName, featuresName, weightsName) => { - advancedSettings.LabelColumn = labelName; - advancedSettings.FeatureColumn = featuresName; - advancedSettings.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + options.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); - var trainer = new LightGbmMulticlassTrainer(env, advancedSettings); + var trainer = new LightGbmMulticlassTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -385,13 +385,13 @@ private static void CheckUserValues(PipelineColumn label, Vector features } private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, - LightGbmArguments advancedSettings, + Options options, Delegate onFit) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); Contracts.CheckValueOrNull(weights); - Contracts.CheckValue(advancedSettings, nameof(advancedSettings)); + Contracts.CheckValue(options, nameof(options)); Contracts.CheckValueOrNull(onFit); } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs index 35dfbbed8f..dbfa868847 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs @@ -11,16 +11,16 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.LightGBM; -[assembly: LoadableClass(typeof(LightGbmArguments.TreeBooster), typeof(LightGbmArguments.TreeBooster.Arguments), - typeof(SignatureLightGBMBooster), LightGbmArguments.TreeBooster.FriendlyName, LightGbmArguments.TreeBooster.Name)] -[assembly: LoadableClass(typeof(LightGbmArguments.DartBooster), typeof(LightGbmArguments.DartBooster.Arguments), - typeof(SignatureLightGBMBooster), LightGbmArguments.DartBooster.FriendlyName, LightGbmArguments.DartBooster.Name)] -[assembly: LoadableClass(typeof(LightGbmArguments.GossBooster), typeof(LightGbmArguments.GossBooster.Arguments), - typeof(SignatureLightGBMBooster), LightGbmArguments.GossBooster.FriendlyName, LightGbmArguments.GossBooster.Name)] +[assembly: LoadableClass(typeof(Options.TreeBooster), typeof(Options.TreeBooster.Arguments), + typeof(SignatureLightGBMBooster), Options.TreeBooster.FriendlyName, Options.TreeBooster.Name)] +[assembly: LoadableClass(typeof(Options.DartBooster), typeof(Options.DartBooster.Arguments), + typeof(SignatureLightGBMBooster), Options.DartBooster.FriendlyName, Options.DartBooster.Name)] +[assembly: LoadableClass(typeof(Options.GossBooster), typeof(Options.GossBooster.Arguments), + typeof(SignatureLightGBMBooster), Options.GossBooster.FriendlyName, Options.GossBooster.Name)] -[assembly: EntryPointModule(typeof(LightGbmArguments.TreeBooster.Arguments))] -[assembly: EntryPointModule(typeof(LightGbmArguments.DartBooster.Arguments))] -[assembly: EntryPointModule(typeof(LightGbmArguments.GossBooster.Arguments))] +[assembly: EntryPointModule(typeof(Options.TreeBooster.Arguments))] +[assembly: EntryPointModule(typeof(Options.DartBooster.Arguments))] +[assembly: EntryPointModule(typeof(Options.GossBooster.Arguments))] namespace Microsoft.ML.LightGBM { @@ -39,7 +39,7 @@ public interface IBoosterParameter /// Parameters names comes from LightGBM library. /// See https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst. /// - public sealed class LightGbmArguments : LearnerInputBaseWithGroupId + public sealed class Options : LearnerInputBaseWithGroupId { public abstract class BoosterParameter : IBoosterParameter where TArgs : class, new() diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index 149ecccd15..0003eb52bf 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -16,7 +16,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(LightGbmArguments), +[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, LightGbmBinaryTrainer.UserName, LightGbmBinaryTrainer.LoadNameValue, LightGbmBinaryTrainer.ShortName, DocName = "trainer/LightGBM.md")] @@ -95,8 +95,8 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase PredictionKind.BinaryClassification; - internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) + internal LightGbmBinaryTrainer(IHostEnvironment env, Options options) + : base(env, LoadNameValue, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { } @@ -118,7 +118,7 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + int numBoostRound = LightGBM.Options.Defaults.NumBoostRound) : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound) { } @@ -181,14 +181,14 @@ public static partial class LightGbm ShortName = LightGbmBinaryTrainer.ShortName, XmlInclude = new[] { @"", @""})] - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, LightGbmArguments input) + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainLightGBM"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new LightGbmBinaryTrainer(host, input), getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs index 5f0fa25a1a..7c610da0b7 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs @@ -31,7 +31,7 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.Regressio int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + int numBoostRound = Options.Defaults.NumBoostRound) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -42,13 +42,13 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.Regressio /// Predict a target using a decision tree regression model trained with the . /// /// The . - /// Advanced options to the algorithm. + /// Advanced options to the algorithm. public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog, - LightGbmArguments advancedSettings) + Options options) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmRegressorTrainer(env, advancedSettings); + return new LightGbmRegressorTrainer(env, options); } /// @@ -69,7 +69,7 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + int numBoostRound = Options.Defaults.NumBoostRound) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -80,13 +80,13 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi /// Predict a target using a decision tree binary classification model trained with the . /// /// The . - /// Advanced options to the algorithm. + /// Advanced options to the algorithm. public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, - LightGbmArguments advancedSettings) + Options options) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmBinaryTrainer(env, advancedSettings); + return new LightGbmBinaryTrainer(env, options); } /// @@ -109,7 +109,7 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + int numBoostRound = Options.Defaults.NumBoostRound) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -120,13 +120,13 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer /// Predict a target using a decision tree binary classification model trained with the . /// /// The . - /// Advanced options to the algorithm. + /// Advanced options to the algorithm. public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog, - LightGbmArguments advancedSettings) + Options options) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmRankingTrainer(env, advancedSettings); + return new LightGbmRankingTrainer(env, options); } /// @@ -147,7 +147,7 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + int numBoostRound = Options.Defaults.NumBoostRound) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -158,13 +158,13 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa /// Predict a target using a decision tree binary classification model trained with the . /// /// The . - /// Advanced options to the algorithm. + /// Advanced options to the algorithm. public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - LightGbmArguments advancedSettings) + Options options) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmMulticlassTrainer(env, advancedSettings); + return new LightGbmMulticlassTrainer(env, options); } } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 5347ece215..bec37c9a54 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -14,7 +14,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(LightGbmArguments), +[assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(Options), new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) }, "LightGBM Multi-class Classifier", LightGbmMulticlassTrainer.LoadNameValue, LightGbmMulticlassTrainer.ShortName, DocName = "trainer/LightGBM.md")] @@ -34,8 +34,8 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase PredictionKind.MultiClassClassification; - internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) + internal LightGbmMulticlassTrainer(IHostEnvironment env, Options options) + : base(env, LoadNameValue, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { _numClass = -1; } @@ -58,7 +58,7 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + int numBoostRound = LightGBM.Options.Defaults.NumBoostRound) : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound) { _numClass = -1; @@ -238,14 +238,14 @@ public static partial class LightGbm ShortName = LightGbmMulticlassTrainer.ShortName, XmlInclude = new[] { @"", @""})] - public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, LightGbmArguments input) + public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainLightGBM"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new LightGbmMulticlassTrainer(host, input), getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 56fd6b02cf..01cee62ad8 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -13,7 +13,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(LightGbmRankingTrainer.UserName, typeof(LightGbmRankingTrainer), typeof(LightGbmArguments), +[assembly: LoadableClass(LightGbmRankingTrainer.UserName, typeof(LightGbmRankingTrainer), typeof(Options), new[] { typeof(SignatureRankerTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, "LightGBM Ranking", LightGbmRankingTrainer.LoadNameValue, LightGbmRankingTrainer.ShortName, DocName = "trainer/LightGBM.md")] @@ -81,8 +81,8 @@ public sealed class LightGbmRankingTrainer : LightGbmTrainerBase PredictionKind.Ranking; - internal LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, LoadNameValue, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) + internal LightGbmRankingTrainer(IHostEnvironment env, Options options) + : base(env, LoadNameValue, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn)) { } @@ -106,7 +106,7 @@ internal LightGbmRankingTrainer(IHostEnvironment env, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + int numBoostRound = LightGBM.Options.Defaults.NumBoostRound) : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weights, groupId, numLeaves, minDataPerLeaf, learningRate, numBoostRound) { Host.CheckNonEmpty(groupId, nameof(groupId)); @@ -193,14 +193,14 @@ public static partial class LightGbm ShortName = LightGbmRankingTrainer.ShortName, XmlInclude = new[] { @"", @""})] - public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, LightGbmArguments input) + public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainLightGBM"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new LightGbmRankingTrainer(host, input), getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn), diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index f937c593f2..4c05e27f4a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -13,7 +13,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(LightGbmRegressorTrainer.Summary, typeof(LightGbmRegressorTrainer), typeof(LightGbmArguments), +[assembly: LoadableClass(LightGbmRegressorTrainer.Summary, typeof(LightGbmRegressorTrainer), typeof(Options), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, LightGbmRegressorTrainer.UserNameValue, LightGbmRegressorTrainer.LoadNameValue, LightGbmRegressorTrainer.ShortName, DocName = "trainer/LightGBM.md")] @@ -103,13 +103,13 @@ internal LightGbmRegressorTrainer(IHostEnvironment env, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, - int numBoostRound = LightGbmArguments.Defaults.NumBoostRound) + int numBoostRound = LightGBM.Options.Defaults.NumBoostRound) : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound) { } - internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, LoadNameValue, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) + internal LightGbmRegressorTrainer(IHostEnvironment env, Options options) + : base(env, LoadNameValue, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn)) { } @@ -167,14 +167,14 @@ public static partial class LightGbm ShortName = LightGbmRegressorTrainer.ShortName, XmlInclude = new[] { @"", @""})] - public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, LightGbmArguments input) + public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainLightGBM"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new LightGbmRegressorTrainer(host, input), getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index 5ff68989e2..d580db4514 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -41,7 +41,7 @@ private sealed class CategoricalMetaData public bool[] IsCategoricalFeature; } - private protected readonly LightGbmArguments Args; + private protected readonly Options Args; /// /// Stores argumments as objects to convert them to invariant string type in the end so that @@ -70,7 +70,7 @@ private protected LightGbmTrainerBase(IHostEnvironment env, int numBoostRound) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { - Args = new LightGbmArguments(); + Args = new Options(); Args.NumLeaves = numLeaves; Args.MinDataPerLeaf = minDataPerLeaf; @@ -89,12 +89,12 @@ private protected LightGbmTrainerBase(IHostEnvironment env, InitParallelTraining(); } - private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGbmArguments args, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + private protected LightGbmTrainerBase(IHostEnvironment env, string name, Options options, SchemaShape.Column label) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn)) { - Host.CheckValue(args, nameof(args)); + Host.CheckValue(options, nameof(options)); - Args = args; + Args = options; InitParallelTraining(); } diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index cbb12ae3ef..91e1731838 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -53,10 +53,10 @@ Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware fact Trainers.GeneralizedAdditiveModelBinaryClassifier Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainBinary Microsoft.ML.Trainers.FastTree.BinaryClassificationGamTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.GeneralizedAdditiveModelRegressor Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainRegression Microsoft.ML.Trainers.FastTree.RegressionGamTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.KMeansPlusPlusClusterer K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. Microsoft.ML.Trainers.KMeans.KMeansPlusPlusTrainer TrainKMeans Microsoft.ML.Trainers.KMeans.KMeansPlusPlusTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+ClusteringOutput -Trainers.LightGbmBinaryClassifier Train a LightGBM binary classification model. Microsoft.ML.LightGBM.LightGbm TrainBinary Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.LightGbmClassifier Train a LightGBM multi class model. Microsoft.ML.LightGBM.LightGbm TrainMultiClass Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput -Trainers.LightGbmRanker Train a LightGBM ranking model. Microsoft.ML.LightGBM.LightGbm TrainRanking Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput -Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.LightGBM.LightGbm TrainRegression Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Trainers.LightGbmBinaryClassifier Train a LightGBM binary classification model. Microsoft.ML.LightGBM.LightGbm TrainBinary Microsoft.ML.LightGBM.Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.LightGbmClassifier Train a LightGBM multi class model. Microsoft.ML.LightGBM.LightGbm TrainMultiClass Microsoft.ML.LightGBM.Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput +Trainers.LightGbmRanker Train a LightGBM ranking model. Microsoft.ML.LightGBM.LightGbm TrainRanking Microsoft.ML.LightGBM.Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput +Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.LightGBM.LightGbm TrainRegression Microsoft.ML.LightGBM.Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.Online.LinearSvmTrainer TrainLinearSvm Microsoft.ML.Trainers.Online.LinearSvmTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Learners.LogisticRegression TrainBinary Microsoft.ML.Learners.LogisticRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Learners.LogisticRegression TrainMultiClass Microsoft.ML.Learners.MulticlassLogisticRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 37ed1745b7..4d4c16e6eb 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -781,7 +781,7 @@ public void TestMultiClassEnsembleCombiner() var predictors = new PredictorModel[] { - LightGbm.TrainMultiClass(Env, new LightGbmArguments + LightGbm.TrainMultiClass(Env, new Options { FeatureColumn = "Features", NumBoostRound = 5, diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index c3b11252ab..7d3c4e05e2 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -46,7 +46,7 @@ public void LightGBMBinaryEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new LightGbmBinaryTrainer(Env, new LightGbmArguments + var trainer = new LightGbmBinaryTrainer(Env, new Options { NumLeaves = 10, NThread = 1, @@ -164,7 +164,7 @@ public void FastTreeRegressorEstimator() public void LightGBMRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new LightGbmRegressorTrainer(Env, new LightGbmArguments + var trainer = new LightGbmRegressorTrainer(Env, new Options { NThread = 1, NormalizeFeatures = NormalizeOption.Warn, @@ -290,7 +290,7 @@ private void LightGbmHelper(bool useSoftmax, out string modelString, out List Date: Mon, 28 Jan 2019 19:02:30 +0000 Subject: [PATCH 3/4] review comments --- src/Microsoft.ML.LightGBM/LightGbmCatalog.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs index 7c610da0b7..392f4e084f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs @@ -90,7 +90,7 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi } /// - /// Predict a target using a decision tree binary classification model trained with the . + /// Predict a target using a decision tree ranking model trained with the . /// /// The . /// The labelColumn column. @@ -117,7 +117,7 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer } /// - /// Predict a target using a decision tree binary classification model trained with the . + /// Predict a target using a decision tree ranking model trained with the . /// /// The . /// Advanced options to the algorithm. @@ -130,9 +130,9 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer } /// - /// Predict a target using a decision tree binary classification model trained with the . + /// Predict a target using a decision tree multiclass classification model trained with the . /// - /// The . + /// The . /// The labelColumn column. /// The features column. /// The weights column. @@ -155,9 +155,9 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa } /// - /// Predict a target using a decision tree binary classification model trained with the . + /// Predict a target using a decision tree multiclass classification model trained with the . /// - /// The . + /// The . /// Advanced options to the algorithm. public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Options options) From 347ab27971477b64c35d61c257e2f5e3619f53e8 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 28 Jan 2019 19:19:27 +0000 Subject: [PATCH 4/4] update tests to exercise the catalog entries --- .../TrainerEstimators/TreeEstimators.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 7d3c4e05e2..e14d1003b9 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -46,7 +46,7 @@ public void LightGBMBinaryEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new LightGbmBinaryTrainer(Env, new Options + var trainer = ML.BinaryClassification.Trainers.LightGbm(new Options { NumLeaves = 10, NThread = 1, @@ -132,7 +132,7 @@ public void LightGBMRankerEstimator() { var (pipe, dataView) = GetRankingPipeline(); - var trainer = new LightGbmRankingTrainer(Env, labelColumn: "Label0", featureColumn: "NumericFeatures", groupId: "Group", learningRate: 0.4); + var trainer = ML.Ranking.Trainers.LightGbm(labelColumn: "Label0", featureColumn: "NumericFeatures", groupIdColumn: "Group", learningRate: 0.4); var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -164,7 +164,7 @@ public void FastTreeRegressorEstimator() public void LightGBMRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new LightGbmRegressorTrainer(Env, new Options + var trainer = ML.Regression.Trainers.LightGbm(new Options { NThread = 1, NormalizeFeatures = NormalizeOption.Warn, @@ -240,7 +240,7 @@ public void FastForestRegressorEstimator() public void LightGbmMultiClassEstimator() { var (pipeline, dataView) = GetMultiClassPipeline(); - var trainer = new LightGbmMulticlassTrainer(Env, learningRate: 0.4); + var trainer = ML.MulticlassClassification.Trainers.LightGbm(learningRate: 0.4); var pipe = pipeline.Append(trainer) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); TestEstimatorCore(pipe, dataView);