diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGBMBinaryClassification.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGBMBinaryClassification.cs
new file mode 100644
index 0000000000..edd4e31504
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGBMBinaryClassification.cs
@@ -0,0 +1,41 @@
+using Microsoft.ML.Transforms.Categorical;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ public class LightGbmBinaryClassification
+ {
+ // This example requires installation of additional nuget package Microsoft.ML.LightGBM.
+ public static void Example()
+ {
+ // Creating the ML.Net IHostEnvironment object, needed for the pipeline.
+ var mlContext = new MLContext();
+
+ // Download and featurize the dataset.
+ var dataview = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
+
+ // Leave out 10% of data for testing.
+ var split = mlContext.BinaryClassification.TrainTestSplit(dataview, testFraction: 0.1);
+
+ // Create the Estimator.
+ var pipeline = mlContext.BinaryClassification.Trainers.LightGbm("IsOver50K", "Features");
+
+ // Fit this Pipeline to the Training Data.
+ var model = pipeline.Fit(split.TrainSet);
+
+ // Evaluate how the model is doing on the test data.
+ var dataWithPredictions = model.Transform(split.TestSet);
+
+ var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, "IsOver50K");
+ SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
+
+ // Output:
+ // Accuracy: 0.88
+ // AUC: 0.93
+ // F1 Score: 0.71
+ // Negative Precision: 0.90
+ // Negative Recall: 0.94
+ // Positive Precision: 0.76
+ // Positive Recall: 0.66
+ }
+ }
+}
\ No newline at end of file
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGBMBinaryClassificationWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGBMBinaryClassificationWithOptions.cs
new file mode 100644
index 0000000000..20924bc29f
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGBMBinaryClassificationWithOptions.cs
@@ -0,0 +1,53 @@
+using Microsoft.ML.LightGBM;
+using Microsoft.ML.Transforms.Categorical;
+using static Microsoft.ML.LightGBM.Options;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ class LightGbmBinaryClassificationWithOptions
+ {
+ // This example requires installation of additional nuget package Microsoft.ML.LightGBM.
+ public static void Example()
+ {
+ // Creating the ML.Net IHostEnvironment object, needed for the pipeline
+ var mlContext = new MLContext();
+
+ // Download and featurize the dataset.
+ var dataview = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
+
+ // Leave out 10% of data for testing.
+ var split = mlContext.BinaryClassification.TrainTestSplit(dataview, testFraction: 0.1);
+
+ // Create the pipeline with LightGbm Estimator using advanced options.
+ var pipeline = mlContext.BinaryClassification.Trainers.LightGbm(
+ new Options
+ {
+ LabelColumn = "IsOver50K",
+ FeatureColumn = "Features",
+ Booster = new GossBooster.Options
+ {
+ TopRate = 0.3,
+ OtherRate = 0.2
+ }
+ });
+
+ // Fit this Pipeline to the Training Data.
+ var model = pipeline.Fit(split.TrainSet);
+
+ // Evaluate how the model is doing on the test data.
+ var dataWithPredictions = model.Transform(split.TestSet);
+
+ var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, "IsOver50K");
+ SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
+
+ // Output:
+ // Accuracy: 0.88
+ // AUC: 0.93
+ // F1 Score: 0.71
+ // Negative Precision: 0.90
+ // Negative Recall: 0.94
+ // Positive Precision: 0.76
+ // Positive Recall: 0.67
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGBMMulticlassClassification.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGBMMulticlassClassification.cs
new file mode 100644
index 0000000000..8731c6bc50
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGBMMulticlassClassification.cs
@@ -0,0 +1,85 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.SamplesUtils;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ class LightGbmMulticlassClassification
+ {
+ // This example requires installation of additional nuget package Microsoft.ML.LightGBM.
+ public static void Example()
+ {
+ // Create a general context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create in-memory examples as C# native class.
+ var examples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000);
+
+ // Convert native C# class to IDataView, a consumable format to ML.NET functions.
+ var dataView = mlContext.Data.ReadFromEnumerable(examples);
+
+ //////////////////// Data Preview ////////////////////
+ // Label Features
+ // AA 0.7262433,0.8173254,0.7680227,0.5581612,0.2060332,0.5588848,0.9060271,0.4421779,0.9775497,0.2737045
+ // BB 0.4919063,0.6673147,0.8326591,0.6695119,1.182151,0.230367,1.06237,1.195347,0.8771811,0.5145918
+ // CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099
+ // DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455
+
+ // Create a pipeline.
+ // - Convert the string labels into key types.
+ // - Apply LightGbm multiclass trainer.
+ var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label")
+ .Append(mlContext.MulticlassClassification.Trainers.LightGbm(labelColumn: "LabelIndex"))
+ .Append(mlContext.Transforms.Conversion.MapValueToKey("PredictedLabelIndex", "PredictedLabel"))
+ .Append(mlContext.Transforms.CopyColumns("Scores", "Score"));
+
+ // Split the static-typed data into training and test sets. Only training set is used in fitting
+ // the created pipeline. Metrics are computed on the test.
+ var split = mlContext.MulticlassClassification.TrainTestSplit(dataView, testFraction: 0.5);
+
+ // Train the model.
+ var model = pipeline.Fit(split.TrainSet);
+
+ // Do prediction on the test set.
+ var dataWithPredictions = model.Transform(split.TestSet);
+
+ // Evaluate the trained model using the test set.
+ var metrics = mlContext.MulticlassClassification.Evaluate(dataWithPredictions, label: "LabelIndex");
+
+ // Check if metrics are reasonable.
+ Console.WriteLine($"Macro accuracy: {metrics.AccuracyMacro:F4}, Micro accuracy: {metrics.AccuracyMicro:F4}.");
+ // Console output:
+ // Macro accuracy: 0.8655, Micro accuracy: 0.8651.
+
+ // IDataView with predictions, to an IEnumerable.
+ var nativePredictions = mlContext.CreateEnumerable<DatasetUtils.MulticlassClassificationExample>(dataWithPredictions, false).ToList();
+
+ // Get schema object out of the prediction. It contains metadata such as the mapping from predicted label index
+ // (e.g., 1) to its actual label (e.g., "AA").
+ // The metadata can be used to get all the unique labels used during training.
+ var labelBuffer = new VBuffer<ReadOnlyMemory<char>>();
+ dataWithPredictions.Schema["PredictedLabelIndex"].GetKeyValues(ref labelBuffer);
+ // nativeLabels is { "AA" , "BB", "CC", "DD" }
+ var nativeLabels = labelBuffer.DenseValues().ToArray(); // nativeLabels[nativePrediction.PredictedLabelIndex - 1] is the original label indexed by nativePrediction.PredictedLabelIndex.
+
+
+ // Show prediction result for the 3rd example.
+ var nativePrediction = nativePredictions[2];
+ // Console output:
+ // Our predicted label to this example is "AA" with probability 0.9257.
+ Console.WriteLine($"Our predicted label to this example is {nativeLabels[(int)nativePrediction.PredictedLabelIndex - 1]} " +
+ $"with probability {nativePrediction.Scores[(int)nativePrediction.PredictedLabelIndex - 1]:F4}.");
+
+ // Scores and nativeLabels are two parallel attributes; that is, Scores[i] is the probability of being nativeLabels[i].
+ // Console output:
+ // The probability of being class "AA" is 0.9257.
+ // The probability of being class "BB" is 0.0739.
+ // The probability of being class "CC" is 0.0002.
+ // The probability of being class "DD" is 0.0001.
+ for (int i = 0; i < nativeLabels.Length; ++i)
+ Console.WriteLine($"The probability of being class {nativeLabels[i]} is {nativePrediction.Scores[i]:F4}.");
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGBMMulticlassClassificationWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGBMMulticlassClassificationWithOptions.cs
new file mode 100644
index 0000000000..7d98c9318e
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGBMMulticlassClassificationWithOptions.cs
@@ -0,0 +1,96 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.LightGBM;
+using Microsoft.ML.SamplesUtils;
+using static Microsoft.ML.LightGBM.Options;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ class LightGbmMulticlassClassificationWithOptions
+ {
+ // This example requires installation of additional nuget package Microsoft.ML.LightGBM.
+ public static void Example()
+ {
+ // Create a general context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ var mlContext = new MLContext(seed: 0);
+
+ // Create in-memory examples as C# native class.
+ var examples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000);
+
+ // Convert native C# class to IDataView, a consumable format to ML.NET functions.
+ var dataView = mlContext.Data.ReadFromEnumerable(examples);
+
+ //////////////////// Data Preview ////////////////////
+ // Label Features
+ // AA 0.7262433,0.8173254,0.7680227,0.5581612,0.2060332,0.5588848,0.9060271,0.4421779,0.9775497,0.2737045
+ // BB 0.4919063,0.6673147,0.8326591,0.6695119,1.182151,0.230367,1.06237,1.195347,0.8771811,0.5145918
+ // CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099
+ // DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455
+
+ // Create a pipeline.
+ // - Convert the string labels into key types.
+ // - Apply LightGbm multiclass trainer with advanced options.
+ var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label")
+ .Append(mlContext.MulticlassClassification.Trainers.LightGbm(new Options
+ {
+ LabelColumn = "LabelIndex",
+ FeatureColumn = "Features",
+ Booster = new DartBooster.Options
+ {
+ DropRate = 0.15,
+ XgboostDartMode = false
+ }
+ }))
+ .Append(mlContext.Transforms.Conversion.MapValueToKey("PredictedLabelIndex", "PredictedLabel"))
+ .Append(mlContext.Transforms.CopyColumns("Scores", "Score"));
+
+ // Split the static-typed data into training and test sets. Only training set is used in fitting
+ // the created pipeline. Metrics are computed on the test.
+ var split = mlContext.MulticlassClassification.TrainTestSplit(dataView, testFraction: 0.5);
+
+ // Train the model.
+ var model = pipeline.Fit(split.TrainSet);
+
+ // Do prediction on the test set.
+ var dataWithPredictions = model.Transform(split.TestSet);
+
+ // Evaluate the trained model using the test set.
+ var metrics = mlContext.MulticlassClassification.Evaluate(dataWithPredictions, label: "LabelIndex");
+
+ // Check if metrics are reasonable.
+ Console.WriteLine($"Macro accuracy: {metrics.AccuracyMacro:F4}, Micro accuracy: {metrics.AccuracyMicro:F4}.");
+ // Console output:
+ // Macro accuracy: 0.8619, Micro accuracy: 0.8611.
+
+ // IDataView with predictions, to an IEnumerable.
+ var nativePredictions = mlContext.CreateEnumerable<DatasetUtils.MulticlassClassificationExample>(dataWithPredictions, false).ToList();
+
+ // Get schema object out of the prediction. It contains metadata such as the mapping from predicted label index
+ // (e.g., 1) to its actual label (e.g., "AA").
+ // The metadata can be used to get all the unique labels used during training.
+ var labelBuffer = new VBuffer<ReadOnlyMemory<char>>();
+ dataWithPredictions.Schema["PredictedLabelIndex"].GetKeyValues(ref labelBuffer);
+ // nativeLabels is { "AA" , "BB", "CC", "DD" }
+ var nativeLabels = labelBuffer.DenseValues().ToArray(); // nativeLabels[nativePrediction.PredictedLabelIndex - 1] is the original label indexed by nativePrediction.PredictedLabelIndex.
+
+
+ // Show prediction result for the 3rd example.
+ var nativePrediction = nativePredictions[2];
+ // Console output:
+ // Our predicted label to this example is AA with probability 0.8986.
+ Console.WriteLine($"Our predicted label to this example is {nativeLabels[(int)nativePrediction.PredictedLabelIndex - 1]} " +
+ $"with probability {nativePrediction.Scores[(int)nativePrediction.PredictedLabelIndex - 1]:F4}.");
+
+ // Scores and nativeLabels are two parallel attributes; that is, Scores[i] is the probability of being nativeLabels[i].
+ // Console output:
+ // The probability of being class AA is 0.8986.
+ // The probability of being class BB is 0.0961.
+ // The probability of being class CC is 0.0050.
+ // The probability of being class DD is 0.0003.
+ for (int i = 0; i < nativeLabels.Length; ++i)
+ Console.WriteLine($"The probability of being class {nativeLabels[i]} is {nativePrediction.Scores[i]:F4}.");
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGBMRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGBMRegression.cs
new file mode 100644
index 0000000000..c4b6f9f68c
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGBMRegression.cs
@@ -0,0 +1,65 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ class LightGbmRegression
+ {
+ // This example requires installation of additional nuget package Microsoft.ML.LightGBM.
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
+ // as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Download and load the housing dataset into an IDataView.
+ var dataView = SamplesUtils.DatasetUtils.LoadHousingRegressionDataset(mlContext);
+
+ //////////////////// Data Preview ////////////////////
+ /// Only a subset of the columns are displayed here.
+ // MedianHomeValue CrimesPerCapita PercentResidental PercentNonRetail CharlesRiver NitricOxides RoomsPerDwelling PercentPre40s ...
+ // 24.00 0.00632 18.00 2.310 0 0.5380 6.5750 65.20 ...
+ // 21.60 0.02731 00.00 7.070 0 0.4690 6.4210 78.90 ...
+ // 34.70 0.02729 00.00 7.070 0 0.4690 7.1850 61.10 ...
+
+ var split = mlContext.Regression.TrainTestSplit(dataView, testFraction: 0.1);
+
+ // Create the estimator, here we only need LightGbm trainer
+ // as data is already processed in a form consumable by the trainer.
+ var labelName = "MedianHomeValue";
+ var featureNames = dataView.Schema
+ .Select(column => column.Name) // Get the column names
+ .Where(name => name != labelName) // Drop the Label
+ .ToArray();
+ var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
+ .Append(mlContext.Regression.Trainers.LightGbm(
+ labelColumn: labelName,
+ numLeaves: 4,
+ minDataPerLeaf: 6,
+ learningRate: 0.001));
+
+ // Fit this pipeline to the training data.
+ var model = pipeline.Fit(split.TrainSet);
+
+ // Get the feature importance based on the information gain used during training.
+ VBuffer<float> weights = default;
+ model.LastTransformer.Model.GetFeatureWeights(ref weights);
+ var weightsValues = weights.DenseValues().ToArray();
+ Console.WriteLine($"weight 0 - {weightsValues[0]}"); // CrimesPerCapita (weight 0) = 0.1898361
+ Console.WriteLine($"weight 5 - {weightsValues[5]}"); // RoomsPerDwelling (weight 5) = 1
+
+ // Evaluate how the model is doing on the test data.
+ var dataWithPredictions = model.Transform(split.TestSet);
+ var metrics = mlContext.Regression.Evaluate(dataWithPredictions, label: labelName);
+ SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
+
+ // Output
+ // L1: 4.97
+ // L2: 51.37
+ // LossFunction: 51.37
+ // RMS: 7.17
+ // RSquared: 0.08
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGBMRegressionWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGBMRegressionWithOptions.cs
new file mode 100644
index 0000000000..3f73df053e
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGBMRegressionWithOptions.cs
@@ -0,0 +1,75 @@
+using System;
+using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.LightGBM;
+using static Microsoft.ML.LightGBM.Options;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ class LightGbmRegressionWithOptions
+ {
+ // This example requires installation of additional nuget package Microsoft.ML.LightGBM.
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
+ // as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Download and load the housing dataset into an IDataView.
+ var dataView = SamplesUtils.DatasetUtils.LoadHousingRegressionDataset(mlContext);
+
+ //////////////////// Data Preview ////////////////////
+ /// Only a subset of the columns are displayed here.
+ // MedianHomeValue CrimesPerCapita PercentResidental PercentNonRetail CharlesRiver NitricOxides RoomsPerDwelling PercentPre40s ...
+ // 24.00 0.00632 18.00 2.310 0 0.5380 6.5750 65.20 ...
+ // 21.60 0.02731 00.00 7.070 0 0.4690 6.4210 78.90 ...
+ // 34.70 0.02729 00.00 7.070 0 0.4690 7.1850 61.10 ...
+
+ var split = mlContext.Regression.TrainTestSplit(dataView, testFraction: 0.1);
+
+ // Create a pipeline with LightGbm estimator with advanced options.
+ // Here we only need LightGbm trainer as data is already processed
+ // in a form consumable by the trainer.
+ var labelName = "MedianHomeValue";
+ var featureNames = dataView.Schema
+ .Select(column => column.Name) // Get the column names
+ .Where(name => name != labelName) // Drop the Label
+ .ToArray();
+ var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
+ .Append(mlContext.Regression.Trainers.LightGbm(new Options
+ {
+ LabelColumn = labelName,
+ NumLeaves = 4,
+ MinDataPerLeaf = 6,
+ LearningRate = 0.001,
+ Booster = new GossBooster.Options
+ {
+ TopRate = 0.3,
+ OtherRate = 0.2
+ }
+ }));
+
+ // Fit this pipeline to the training data.
+ var model = pipeline.Fit(split.TrainSet);
+
+ // Get the feature importance based on the information gain used during training.
+ VBuffer<float> weights = default;
+ model.LastTransformer.Model.GetFeatureWeights(ref weights);
+ var weightsValues = weights.DenseValues().ToArray();
+ Console.WriteLine($"weight 0 - {weightsValues[0]}"); // CrimesPerCapita (weight 0) = 0.1898361
+ Console.WriteLine($"weight 5 - {weightsValues[5]}"); // RoomsPerDwelling (weight 5) = 1
+
+ // Evaluate how the model is doing on the test data.
+ var dataWithPredictions = model.Transform(split.TestSet);
+ var metrics = mlContext.Regression.Evaluate(dataWithPredictions, label: labelName);
+ SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
+
+ // Output
+ // L1: 4.97
+ // L2: 51.37
+ // LossFunction: 51.37
+ // RMS: 7.17
+ // RSquared: 0.08
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs b/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs
index fffad9181b..edcfd39efd 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs
@@ -11,8 +11,8 @@ public class LightGbmRegressionExample
public static void LightGbmRegression()
{
// Downloading a regression dataset from github.com/dotnet/machinelearning
- // this will create a housing.txt file in the filsystem this code will run
- // you can open the file to see the data.
+ // this will create a housing.txt file in the filesystem.
+ // You can open the file to see the data.
string dataFile = SamplesUtils.DatasetUtils.DownloadHousingRegressionDataset();
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
diff --git a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
index e142d2d30c..0085ad31aa 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
@@ -11,25 +11,26 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.LightGBM;
-[assembly: LoadableClass(typeof(Options.TreeBooster), typeof(Options.TreeBooster.Arguments),
+[assembly: LoadableClass(typeof(Options.TreeBooster), typeof(Options.TreeBooster.Options),
typeof(SignatureLightGBMBooster), Options.TreeBooster.FriendlyName, Options.TreeBooster.Name)]
-[assembly: LoadableClass(typeof(Options.DartBooster), typeof(Options.DartBooster.Arguments),
+[assembly: LoadableClass(typeof(Options.DartBooster), typeof(Options.DartBooster.Options),
typeof(SignatureLightGBMBooster), Options.DartBooster.FriendlyName, Options.DartBooster.Name)]
-[assembly: LoadableClass(typeof(Options.GossBooster), typeof(Options.GossBooster.Arguments),
+[assembly: LoadableClass(typeof(Options.GossBooster), typeof(Options.GossBooster.Options),
typeof(SignatureLightGBMBooster), Options.GossBooster.FriendlyName, Options.GossBooster.Name)]
-[assembly: EntryPointModule(typeof(Options.TreeBooster.Arguments))]
-[assembly: EntryPointModule(typeof(Options.DartBooster.Arguments))]
-[assembly: EntryPointModule(typeof(Options.GossBooster.Arguments))]
+[assembly: EntryPointModule(typeof(Options.TreeBooster.Options))]
+[assembly: EntryPointModule(typeof(Options.DartBooster.Options))]
+[assembly: EntryPointModule(typeof(Options.GossBooster.Options))]
namespace Microsoft.ML.LightGBM
{
- public delegate void SignatureLightGBMBooster();
+ internal delegate void SignatureLightGBMBooster();
[TlcModule.ComponentKind("BoosterParameterFunction")]
public interface ISupportBoosterParameterFactory : IComponentFactory
{
}
+
public interface IBoosterParameter
{
void UpdateParameters(Dictionary res);
@@ -54,7 +55,7 @@ protected BoosterParameter(TArgs args)
///
/// Update the parameters by specific Booster, will update parameters into "res" directly.
///
- public virtual void UpdateParameters(Dictionary res)
+ internal virtual void UpdateParameters(Dictionary res)
{
FieldInfo[] fields = Args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
foreach (var field in fields)
@@ -67,6 +68,8 @@ public virtual void UpdateParameters(Dictionary res)
res[GetArgName(field.Name)] = field.GetValue(Args);
}
}
+
+ void IBoosterParameter.UpdateParameters(Dictionary res) => UpdateParameters(res);
}
private static string GetArgName(string name)
@@ -92,17 +95,16 @@ private static string GetArgName(string name)
[BestFriend]
internal static class Defaults
{
- [BestFriend]
- internal const int NumBoostRound = 100;
+ public const int NumBoostRound = 100;
}
- public sealed class TreeBooster : BoosterParameter
+ public sealed class TreeBooster : BoosterParameter
{
- public const string Name = "gbdt";
- public const string FriendlyName = "Tree Booster";
+ internal const string Name = "gbdt";
+ internal const string FriendlyName = "Tree Booster";
[TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Traditional Gradient Boosting Decision Tree.")]
- public class Arguments : ISupportBoosterParameterFactory
+ public class Options : ISupportBoosterParameterFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Use for binary classification when classes are not balanced.", ShortName = "us")]
public bool UnbalancedSets = false;
@@ -164,10 +166,12 @@ public class Arguments : ISupportBoosterParameterFactory
" A typical value to consider: sum(negative cases) / sum(positive cases).")]
public double ScalePosWeight = 1;
- public virtual IBoosterParameter CreateComponent(IHostEnvironment env) => new TreeBooster(this);
+ internal virtual IBoosterParameter CreateComponent(IHostEnvironment env) => new TreeBooster(this);
+
+ IBoosterParameter IComponentFactory.CreateComponent(IHostEnvironment env) => CreateComponent(env);
}
- public TreeBooster(Arguments args)
+ internal TreeBooster(Options args)
: base(args)
{
Contracts.CheckUserArg(Args.MinSplitGain >= 0, nameof(Args.MinSplitGain), "must be >= 0.");
@@ -177,20 +181,20 @@ public TreeBooster(Arguments args)
Contracts.CheckUserArg(Args.ScalePosWeight > 0 && Args.ScalePosWeight <= 1, nameof(Args.ScalePosWeight), "must be in (0,1].");
}
- public override void UpdateParameters(Dictionary res)
+ internal override void UpdateParameters(Dictionary res)
{
base.UpdateParameters(res);
res["boosting_type"] = Name;
}
}
- public class DartBooster : BoosterParameter
+ public sealed class DartBooster : BoosterParameter
{
- public const string Name = "dart";
- public const string FriendlyName = "Tree Dropout Tree Booster";
+ internal const string Name = "dart";
+ internal const string FriendlyName = "Tree Dropout Tree Booster";
[TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Dropouts meet Multiple Additive Regresion Trees. See https://arxiv.org/abs/1505.01866")]
- public class Arguments : TreeBooster.Arguments
+ public class Options : TreeBooster.Options
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Drop ratio for trees. Range:(0,1).")]
[TlcModule.Range(Inf = 0.0, Max = 1.0)]
@@ -210,10 +214,10 @@ public class Arguments : TreeBooster.Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "True will enable uniform drop.")]
public bool UniformDrop = false;
- public override IBoosterParameter CreateComponent(IHostEnvironment env) => new DartBooster(this);
+ internal override IBoosterParameter CreateComponent(IHostEnvironment env) => new DartBooster(this);
}
- public DartBooster(Arguments args)
+ internal DartBooster(Options args)
: base(args)
{
Contracts.CheckUserArg(Args.DropRate > 0 && Args.DropRate < 1, nameof(Args.DropRate), "must be in (0,1).");
@@ -221,20 +225,20 @@ public DartBooster(Arguments args)
Contracts.CheckUserArg(Args.SkipDrop >= 0 && Args.SkipDrop < 1, nameof(Args.SkipDrop), "must be in [0,1).");
}
- public override void UpdateParameters(Dictionary res)
+ internal override void UpdateParameters(Dictionary res)
{
base.UpdateParameters(res);
res["boosting_type"] = Name;
}
}
- public class GossBooster : BoosterParameter
+ public sealed class GossBooster : BoosterParameter
{
- public const string Name = "goss";
- public const string FriendlyName = "Gradient-based One-Size Sampling";
+ internal const string Name = "goss";
+ internal const string FriendlyName = "Gradient-based One-Size Sampling";
[TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Gradient-based One-Side Sampling.")]
- public class Arguments : TreeBooster.Arguments
+ public class Options : TreeBooster.Options
{
[Argument(ArgumentType.AtMostOnce,
HelpText = "Retain ratio for large gradient instances.")]
@@ -247,10 +251,10 @@ public class Arguments : TreeBooster.Arguments
[TlcModule.Range(Inf = 0.0, Max = 1.0)]
public double OtherRate = 0.1;
- public override IBoosterParameter CreateComponent(IHostEnvironment env) => new GossBooster(this);
+ internal override IBoosterParameter CreateComponent(IHostEnvironment env) => new GossBooster(this);
}
- public GossBooster(Arguments args)
+ internal GossBooster(Options args)
: base(args)
{
Contracts.CheckUserArg(Args.TopRate > 0 && Args.TopRate < 1, nameof(Args.TopRate), "must be in (0,1).");
@@ -258,7 +262,7 @@ public GossBooster(Arguments args)
Contracts.Check(Args.TopRate + Args.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1.");
}
- public override void UpdateParameters(Dictionary res)
+ internal override void UpdateParameters(Dictionary res)
{
base.UpdateParameters(res);
res["boosting_type"] = Name;
@@ -307,7 +311,7 @@ public enum EvalMetricType
public int MaxBin = 255;
[Argument(ArgumentType.Multiple, HelpText = "Which booster to use, can be gbtree, gblinear or dart. gbtree and dart use tree based model while gblinear uses linear function.", SortOrder = 3)]
- public ISupportBoosterParameterFactory Booster = new TreeBooster.Arguments();
+ public ISupportBoosterParameterFactory Booster = new TreeBooster.Options();
[Argument(ArgumentType.AtMostOnce, HelpText = "Verbose", ShortName = "v")]
public bool VerboseEval = false;
@@ -375,7 +379,7 @@ public enum EvalMetricType
public int? Seed;
[Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")]
- public ISupportParallel ParallelTrainer = new SingleTrainerFactory();
+ internal ISupportParallel ParallelTrainer = new SingleTrainerFactory();
internal Dictionary ToDictionary(IHost host)
{
diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs
index 392f4e084f..08a968ec5d 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs
@@ -24,6 +24,13 @@ public static class LightGbmExtensions
/// Number of iterations.
/// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.
/// The learning rate.
+ ///
+ ///
+ ///
+ ///
+ ///
public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
@@ -43,6 +50,13 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.Regressio
///
/// The .
/// Advanced options to the algorithm.
+ ///
+ ///
+ ///
+ ///
+ ///
public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
Options options)
{
@@ -62,6 +76,13 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.Regressio
/// Number of iterations.
/// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.
/// The learning rate.
+ ///
+ ///
+ ///
+ ///
+ ///
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
@@ -81,6 +102,13 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi
///
/// The .
/// Advanced options to the algorithm.
+ ///
+ ///
+ ///
+ ///
+ ///
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
Options options)
{
@@ -140,6 +168,13 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer
/// Number of iterations.
/// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.
/// The learning rate.
+ ///
+ ///
+ ///
+ ///
+ ///
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
@@ -159,6 +194,13 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa
///
/// The .
/// Advanced options to the algorithm.
+ ///
+ ///
+ ///
+ ///
+ ///
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
Options options)
{
diff --git a/src/Microsoft.ML.LightGBM/Parallel/IParallel.cs b/src/Microsoft.ML.LightGBM/Parallel/IParallel.cs
index 8f3f005741..9948d1bb06 100644
--- a/src/Microsoft.ML.LightGBM/Parallel/IParallel.cs
+++ b/src/Microsoft.ML.LightGBM/Parallel/IParallel.cs
@@ -12,17 +12,17 @@ namespace Microsoft.ML.LightGBM
///
/// Signature of LightGBM IAllreduce
///
- public delegate void SignatureParallelTrainer();
+ internal delegate void SignatureParallelTrainer();
///
/// Reduce function define in LightGBM Cpp side
///
- public unsafe delegate void ReduceFunction(byte* src, byte* output, int typeSize, int arraySize);
+ internal unsafe delegate void ReduceFunction(byte* src, byte* output, int typeSize, int arraySize);
///
/// Definition of ReduceScatter funtion
///
- public delegate void ReduceScatterFunction([MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)]byte[] input, int inputSize, int typeSize,
+ internal delegate void ReduceScatterFunction([MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)]byte[] input, int inputSize, int typeSize,
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 5)]int[] blockStart, [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 5)]int[] blockLen, int numBlock,
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 7)]byte[] output, int outputSize,
IntPtr reducer);
@@ -30,11 +30,11 @@ public delegate void ReduceScatterFunction([MarshalAs(UnmanagedType.LPArray, Siz
///
/// Definition of Allgather funtion
///
- public delegate void AllgatherFunction([MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)]byte[] input, int inputSize,
+ internal delegate void AllgatherFunction([MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)]byte[] input, int inputSize,
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 4)]int[] blockStart, [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 4)]int[] blockLen, int numBlock,
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 6)]byte[] output, int outputSize);
- public interface IParallel
+ internal interface IParallel
{
///
/// Type of parallel
@@ -68,7 +68,7 @@ public interface IParallel
}
[TlcModule.ComponentKind("ParallelLightGBM")]
- public interface ISupportParallel : IComponentFactory
+ internal interface ISupportParallel : IComponentFactory
{
}
}
diff --git a/src/Microsoft.ML.LightGBM/Parallel/SingleTrainer.cs b/src/Microsoft.ML.LightGBM/Parallel/SingleTrainer.cs
index 8041fb3d6f..096e2277ef 100644
--- a/src/Microsoft.ML.LightGBM/Parallel/SingleTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/Parallel/SingleTrainer.cs
@@ -13,41 +13,41 @@
namespace Microsoft.ML.LightGBM
{
- public sealed class SingleTrainer : IParallel
+ internal sealed class SingleTrainer : IParallel
{
- public AllgatherFunction GetAllgatherFunction()
+ AllgatherFunction IParallel.GetAllgatherFunction()
{
return null;
}
- public ReduceScatterFunction GetReduceScatterFunction()
+ ReduceScatterFunction IParallel.GetReduceScatterFunction()
{
return null;
}
- public int NumMachines()
+ int IParallel.NumMachines()
{
return 1;
}
- public string ParallelType()
+ string IParallel.ParallelType()
{
return "serial";
}
- public int Rank()
+ int IParallel.Rank()
{
return 0;
}
- public Dictionary AdditionalParams()
+ Dictionary IParallel.AdditionalParams()
{
return null;
}
}
[TlcModule.Component(Name = "Single", Desc = "Single node machine learning process.")]
- public sealed class SingleTrainerFactory : ISupportParallel
+ internal sealed class SingleTrainerFactory : ISupportParallel
{
public IParallel CreateComponent(IHostEnvironment env) => new SingleTrainer();
}
diff --git a/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs
index 83fafd8658..0e8e1cf714 100644
--- a/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs
+++ b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs
@@ -24,5 +24,18 @@ public static void PrintMetrics(BinaryClassificationMetrics metrics)
Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}");
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}");
}
+
+ ///
+ /// Pretty-print RegressionMetrics objects.
+ ///
+ /// Regression metrics.
+ public static void PrintMetrics(RegressionMetrics metrics)
+ {
+ Console.WriteLine($"L1: {metrics.L1:F2}");
+ Console.WriteLine($"L2: {metrics.L2:F2}");
+ Console.WriteLine($"LossFunction: {metrics.LossFn:F2}");
+ Console.WriteLine($"RMS: {metrics.Rms:F2}");
+ Console.WriteLine($"RSquared: {metrics.RSquared:F2}");
+ }
}
}