From 47410eaa7a69942cb7c30c2565273c9941af5d8d Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 19 Sep 2018 16:17:24 -0700 Subject: [PATCH 1/8] GAM and LightGBM conversion. Utility methods for the column shape creations. --- .../Training/TrainerUtils.cs | 43 ++++++++- .../FastTreeClassification.cs | 33 +++---- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 40 ++++---- .../FastTreeRegression.cs | 22 ++--- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 17 ++-- .../GamClassification.cs | 33 +++++-- src/Microsoft.ML.FastTree/GamRegression.cs | 23 ++++- src/Microsoft.ML.FastTree/GamTrainer.cs | 93 ++++++++++++------- .../RandomForestClassification.cs | 23 ++--- .../RandomForestRegression.cs | 29 ++---- .../LightGbmBinaryTrainer.cs | 30 +++++- .../LightGbmMulticlassTrainer.cs | 31 ++++++- .../LightGbmRankingTrainer.cs | 26 +++++- .../LightGbmRegressionTrainer.cs | 35 +++++-- .../LightGbmTrainerBase.cs | 42 +++++++-- .../Standard/SdcaMultiClass.cs | 30 ++---- .../Standard/SdcaRegression.cs | 34 ++----- 17 files changed, 361 insertions(+), 223 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 33d3d1490d..2c4bda721f 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - -using System; -using System.Collections.Generic; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; +using System; +using System.Collections.Generic; +using Float = System.Single; namespace Microsoft.ML.Runtime.Training { @@ -348,6 +348,41 @@ public static ValueGetter GetOptGroupGetter(this IRow row, RoleMappedData Contracts.CheckValue(data, nameof(data)); return GetOptGroupGetter(row, data.Schema); } + + public static SchemaShape.Column MakeBoolScalarLabel(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + } + + public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } + + public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); + } + + /// <summary> + /// The <see cref="SchemaShape.Column"/> for the feature column. + /// </summary> + /// <param name="featureColumn">name of the feature column</param> + public static SchemaShape.Column MakeR4VecFeature(string featureColumn) + { + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + } + + /// <summary> + /// The <see cref="SchemaShape.Column"/> for the weight column. 
+ /// </summary> + /// <param name="weightColumn">name of the weight column</param> + public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) + { + if (weightColumn == null) + return null; + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 97de30bf75..dccf57d45a 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -112,7 +112,6 @@ public sealed partial class FastTreeBinaryClassificationTrainer : internal const string ShortName = "ftc"; private bool[] _trainSetLabels; - private readonly SchemaShape.Column[] _outputColumns; /// /// Initializes a new instance of @@ -125,17 +124,10 @@ public sealed partial class FastTreeBinaryClassificationTrainer : /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; - - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), - new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, 
SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) - }; } private double _sigmoidParameter; @@ -143,14 +135,8 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelCol /// Initializes a new instance of by using the legacy class. /// public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args) - : base(env, args, MakeLabelColumn(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), - new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) - }; // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; } @@ -230,11 +216,6 @@ protected override void PrepareLabels(IChannel ch) //Here we set regression labels to what is in bin file if the values were not overriden with floats } - private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - } - protected override Test ConstructTestForTrainingData() { return new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter); @@ -282,7 +263,15 @@ protected override void InitializeTests() protected override BinaryPredictionTransformer> MakeTransformer(IPredictorWithFeatureWeights model, ISchema trainSchema) => new 
BinaryPredictionTransformer>(Host, model, trainSchema, FeatureColumn.Name); - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } internal sealed class ObjectiveImpl : ObjectiveFunctionBase, IStepSearch { diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 4c36e6e4cc..2b14de03ea 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -2,20 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.FastTree.Internal; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Training; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; // REVIEW: Do we really need all these names? [assembly: LoadableClass(FastTreeRankingTrainer.Summary, typeof(FastTreeRankingTrainer), typeof(FastTreeRankingTrainer.Arguments), @@ -58,8 +59,6 @@ public sealed partial class FastTreeRankingTrainer /// public override PredictionKind PredictionKind => PredictionKind.Ranking; - private readonly SchemaShape.Column[] _outputColumns; - /// /// Initializes a new instance of /// @@ -71,24 +70,16 @@ public sealed partial class FastTreeRankingTrainer /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + : base(env, TrainerUtils.MakeU4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) - }; } /// /// Initializes a new instance of by using the legacy class. 
/// public FastTreeRankingTrainer(IHostEnvironment env, Arguments args) - : base(env, args, MakeLabelColumn(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) - }; } protected override float GetMaxLabel() @@ -158,11 +149,6 @@ protected override void CheckArgs(IChannel ch) base.CheckArgs(ch); } - private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); - } - protected override void Initialize(IChannel ch) { base.Initialize(ch); @@ -446,7 +432,13 @@ protected override string GetTestGraphHeader() protected override RankingPredictionTransformer MakeTransformer(FastTreeRankingPredictor model, ISchema trainSchema) => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } public sealed class LambdaRankObjectiveFunction : ObjectiveFunctionBase, IStepSearch { diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 3984582c61..2b224ac8ee 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -51,8 +51,6 @@ public sealed partial class FastTreeRegressionTrainer /// public override PredictionKind PredictionKind => PredictionKind.Regression; - private readonly SchemaShape.Column[] 
_outputColumns; - /// /// Initializes a new instance of /// @@ -64,24 +62,16 @@ public sealed partial class FastTreeRegressionTrainer /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) - }; } /// /// Initializes a new instance of by using the legacy class. /// public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) - : base(env, args, MakeLabelColumn(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) - }; } protected override FastTreeRegressionPredictor TrainModelCore(TrainContext context) @@ -166,7 +156,13 @@ protected override Test ConstructTestForTrainingData() protected override RegressionPredictionTransformer MakeTransformer(FastTreeRegressionPredictor model, ISchema trainSchema) => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new 
SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } private void AddFullRegressionTests() { diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 4511ab7f1d..9dfc9556a4 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -59,7 +59,7 @@ public sealed partial class FastTreeTweedieTrainer /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { Initialize(); } @@ -68,7 +68,7 @@ public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string f /// Initializes a new instance of by using the legacy class. 
/// public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) - : base(env, args, MakeLabelColumn(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { Initialize(); } @@ -302,15 +302,16 @@ protected override void Train(IChannel ch) PrintTestGraph(ch); } - private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); - } - protected override RegressionPredictionTransformer MakeTransformer(FastTreeTweediePredictor model, ISchema trainSchema) => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } private sealed class ObjectiveImpl : ObjectiveFunctionBase, IStepSearch { diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index e8aad81d13..9e18b2936b 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -4,6 +4,7 @@ using System; using System.Threading.Tasks; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -13,7 +14,6 @@ using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; -using Float = System.Single; [assembly: LoadableClass(BinaryClassificationGamTrainer.Summary, typeof(BinaryClassificationGamTrainer), typeof(BinaryClassificationGamTrainer.Arguments), @@ -22,14 +22,14 @@ 
BinaryClassificationGamTrainer.LoadNameValue, BinaryClassificationGamTrainer.ShortName, DocName = "trainer/GAM.md")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(BinaryClassGamPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(BinaryClassGamPredictor), null, typeof(SignatureLoadModel), "GAM Binary Class Predictor", BinaryClassGamPredictor.LoaderSignature)] namespace Microsoft.ML.Runtime.FastTree { public sealed class BinaryClassificationGamTrainer : - GamTrainerBase> + GamTrainerBase>, IPredictorProducing> { public sealed class Arguments : ArgumentsBase { @@ -47,7 +47,13 @@ public sealed class Arguments : ArgumentsBase private protected override bool NeedCalibration => true; public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) - : base(env, args) + : base(env, args, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) + { + _sigmoidParameter = 1; + } + + public BinaryClassificationGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) { _sigmoidParameter = 1; } @@ -77,7 +83,7 @@ private static bool[] ConvertTargetsToBool(double[] targets) return boolArray; } - public override IPredictorProducing Train(TrainContext context) + protected override IPredictorProducing TrainModelCore(TrainContext context) { TrainBase(context); var predictor = new BinaryClassGamPredictor(Host, InputLength, TrainSet, @@ -111,9 +117,22 @@ protected override void DefinePruningTest() PruningLossIndex = Args.UnbalancedSets ? 
3 /*Unbalanced sets loss*/ : 1 /*normal loss*/; PruningTest = new TestHistory(validTest, PruningLossIndex); } + + protected override BinaryPredictionTransformer> MakeTransformer(IPredictorProducing model, ISchema trainSchema) + => new BinaryPredictionTransformer>(Host, model, trainSchema, FeatureColumn.Name); + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } } - public class BinaryClassGamPredictor : GamPredictorBase, IPredictorProducing + public class BinaryClassGamPredictor : GamPredictorBase, IPredictorProducing { public const string LoaderSignature = "BinaryClassGamPredictor"; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; @@ -135,7 +154,7 @@ public static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 9fe4610c8c..107e3af55e 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -10,6 +11,7 @@ using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using System; [assembly: LoadableClass(RegressionGamTrainer.Summary, typeof(RegressionGamTrainer), typeof(RegressionGamTrainer.Arguments), @@ -24,8 +26,7 @@ namespace Microsoft.ML.Runtime.FastTree { - public sealed class RegressionGamTrainer : - GamTrainerBase + public sealed class RegressionGamTrainer : GamTrainerBase, RegressionGamPredictor> { public partial class Arguments : ArgumentsBase { @@ -41,14 +42,17 @@ public partial class Arguments : ArgumentsBase public override PredictionKind PredictionKind => PredictionKind.Regression; public RegressionGamTrainer(IHostEnvironment env, Arguments args) - : base(env, args) { } + : base(env, args, LoadNameValue, TrainerUtils.MakeR4VecLabel(args.LabelColumn)) { } + + public RegressionGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeR4VecLabel(labelColumn), featureColumn, weightColumn, advancedSettings) { } internal override void CheckLabel(RoleMappedData data) { data.CheckRegressionLabel(); } - public override RegressionGamPredictor Train(TrainContext context) + protected override RegressionGamPredictor TrainModelCore(TrainContext context) { TrainBase(context); return new RegressionGamPredictor(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap); @@ -66,6 +70,17 @@ protected override void DefinePruningTest() PruningLossIndex = 0; PruningTest = new TestHistory(validTest, PruningLossIndex); } + + protected override RegressionPredictionTransformer MakeTransformer(RegressionGamPredictor model, ISchema trainSchema) + => new 
RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } } public class RegressionGamPredictor : GamPredictorBase diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 6f1b387ce9..0198e9f3a7 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -7,6 +7,7 @@ using System.IO; using System.Linq; using System.Threading; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; @@ -29,16 +30,16 @@ namespace Microsoft.ML.Runtime.FastTree { - using Float = System.Single; using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo; using AutoResetEvent = System.Threading.AutoResetEvent; /// /// Generalized Additive Model Learner. 
/// - public abstract partial class GamTrainerBase : TrainerBase - where TArgs : GamTrainerBase.ArgumentsBase, new() - where TPredictor : IPredictorProducing + public abstract partial class GamTrainerBase : TrainerEstimatorBase + where TTransformer: IPredictionTransformer + where TArgs : GamTrainerBase.ArgumentsBase, new() + where TPredictor : IPredictorProducing { public abstract class ArgumentsBase : LearnerInputBaseWithWeight { @@ -129,10 +130,31 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight public override TrainerInfo Info { get; } private protected virtual bool NeedCalibration => false; - protected readonly IParallelTraining ParallelTraining; + protected IParallelTraining ParallelTraining; - private protected GamTrainerBase(IHostEnvironment env, TArgs args) - : base(env, RegisterName) + private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, + string weightColumn = null, Action advancedSettings = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + { + Args = new TArgs(); + + //apply the advanced args, if the user supplied any + advancedSettings?.Invoke(Args); + Args.LabelColumn = label.Name; + + if (weightColumn != null) + Args.WeightColumn = weightColumn; + + Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true); + _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2); + _entropyCoefficient = Args.EntropyCoefficient * 1e-6; + + InitializeThreads(); + } + + private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), + label, 
TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) { Contracts.CheckValue(env, nameof(env)); Host.CheckValue(args, nameof(args)); @@ -146,23 +168,12 @@ private protected GamTrainerBase(IHostEnvironment env, TArgs args) Host.CheckParam(0 < args.MinDocuments, nameof(args.MinDocuments), "Must be positive."); Args = args; + Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true); _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2); _entropyCoefficient = Args.EntropyCoefficient * 1e-6; - ParallelTraining = new SingleTrainer(); - - int numThreads = args.NumThreads ?? Environment.ProcessorCount; - if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) - using (var ch = Host.Start("GamTrainer")) - { - numThreads = Host.ConcurrencyFactor; - ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor " - + "setting of the environment. 
Using {0} training threads instead.", numThreads); - ch.Done(); - } - - InitializeThreads(numThreads); + InitializeThreads(); } protected void TrainBase(TrainContext context) @@ -204,7 +215,7 @@ private void ConvertData(RoleMappedData trainData, RoleMappedData validationData CheckLabel(trainData); var useTranspose = UseTranspose(Args.DiskTranspose, trainData); - var instanceConverter = new ExamplesToFastTreeBins(Host, Args.MaxBins, useTranspose, !Args.FeatureFlocks, Args.MinDocuments, Float.PositiveInfinity); + var instanceConverter = new ExamplesToFastTreeBins(Host, Args.MaxBins, useTranspose, !Args.FeatureFlocks, Args.MinDocuments, float.PositiveInfinity); ParallelTraining.InitEnvironment(); TrainSet = instanceConverter.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, null, false); @@ -512,8 +523,20 @@ private void Initialize(IChannel ch) } } - private void InitializeThreads(int numThreads) + private void InitializeThreads() { + ParallelTraining = new SingleTrainer(); + + int numThreads = Args.NumThreads ?? Environment.ProcessorCount; + if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) + using (var ch = Host.Start("GamTrainer")) + { + numThreads = Host.ConcurrencyFactor; + ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor " + + "setting of the environment. 
Using {0} training threads instead.", numThreads); + ch.Done(); + } + ThreadTaskManager.Initialize(numThreads); } @@ -593,7 +616,7 @@ public Stump(uint splitPoint, double lteValue, double gtValue) } } - public abstract class GamPredictorBase : PredictorBase, + public abstract class GamPredictorBase : PredictorBase, IValueMapper, ICanSaveModel, ICanSaveInTextFormat, ICanSaveSummary { private readonly double[][] _binUpperBounds; @@ -763,14 +786,14 @@ public override void Save(ModelSaveContext ctx) public ValueMapper GetMapper() { - Host.Check(typeof(TIn) == typeof(VBuffer)); - Host.Check(typeof(TOut) == typeof(Float)); + Host.Check(typeof(TIn) == typeof(VBuffer)); + Host.Check(typeof(TOut) == typeof(float)); - ValueMapper, Float> del = Map; + ValueMapper, float> del = Map; return (ValueMapper)(Delegate)del; } - private void Map(ref VBuffer features, ref Float response) + private void Map(ref VBuffer features, ref float response) { Host.CheckParam(features.Length == _inputLength, nameof(features), "Bad length of input"); @@ -795,7 +818,7 @@ private void Map(ref VBuffer features, ref Float response) } } - response = (Float)value; + response = (float)value; } /// @@ -803,21 +826,21 @@ private void Map(ref VBuffer features, ref Float response) /// is used as a buffer to accumulate the contributions across trees. /// If is null, it will be created, otherwise it will be reused. 
/// - internal void GetFeatureContributions(ref VBuffer features, ref VBuffer contribs, ref BufferBuilder builder) + internal void GetFeatureContributions(ref VBuffer features, ref VBuffer contribs, ref BufferBuilder builder) { if (builder == null) builder = new BufferBuilder(R4Adder.Instance); // The model is Intercept + Features builder.Reset(features.Length + 1, false); - builder.AddFeature(0, (Float)_intercept); + builder.AddFeature(0, (float)_intercept); if (features.IsDense) { for (int i = 0; i < features.Count; ++i) { if (_inputFeatureToDatasetFeatureMap.TryGetValue(i, out int j)) - builder.AddFeature(i+1, (Float) GetBinEffect(j, features.Values[i])); + builder.AddFeature(i+1, (float) GetBinEffect(j, features.Values[i])); } } else @@ -840,7 +863,7 @@ internal void GetFeatureContributions(ref VBuffer features, ref VBuffer features, ref VBuffer features, int[] bins) + internal double GetFeatureBinsAndScore(ref VBuffer features, int[] bins) { Host.CheckParam(features.Length == _inputLength, nameof(features)); Host.CheckParam(Utils.Size(bins) == _numFeatures, nameof(bins)); @@ -925,7 +948,7 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema) writer.WriteLine(); writer.WriteLine("Per feature binned effects:"); writer.WriteLine("Feature Index\tFeature Value Bin Upper Bound\tOutput (effect on label)"); - writer.WriteLine($"{-1:D}\t{Float.MaxValue:R}\t{_intercept:R}"); + writer.WriteLine($"{-1:D}\t{float.MaxValue:R}\t{_intercept:R}"); for (int internalIndex = 0; internalIndex < _numFeatures; internalIndex++) { int featureIndex = _featureMap[internalIndex]; @@ -1106,7 +1129,7 @@ public long SetEffect(int feat, int bin, double effect) var deltaEffect = effect - effects[bin]; effects[bin] = effect; foreach (var docIndex in _binDocsList[internalIndex][bin]) - _scores[docIndex] += (Float)deltaEffect; + _scores[docIndex] += (float)deltaEffect; return checked(++_version); } } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs 
b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index c5222d4e87..c416e54dc3 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -132,7 +132,6 @@ public sealed class Arguments : FastForestArgumentsBase public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; private protected override bool NeedCalibration => true; - private readonly SchemaShape.Column[] _outputColumns; /// /// Initializes a new instance of @@ -147,12 +146,6 @@ public FastForestClassification(IHostEnvironment env, string labelColumn, string string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) - }; } /// @@ -161,12 +154,6 @@ public FastForestClassification(IHostEnvironment env, string labelColumn, string public FastForestClassification(IHostEnvironment env, Arguments args) : base(env, args, MakeLabelColumn(args.LabelColumn)) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) - }; } protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context) @@ -220,7 +207,15 @@ protected override Test 
ConstructTestForTrainingData() protected override BinaryPredictionTransformer> MakeTransformer(IPredictorWithFeatureWeights model, ISchema trainSchema) => new BinaryPredictionTransformer>(Host, model, trainSchema, FeatureColumn.Name); - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) + }; + } private sealed class ObjectiveFunctionImpl : RandomForestObjectiveFunction { diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 510c3b0ec6..ace1568bc1 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -2,18 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.FastTree.Internal; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Core.Data; +using System; [assembly: LoadableClass(FastForestRegression.Summary, typeof(FastForestRegression), typeof(FastForestRegression.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -154,8 +154,6 @@ public sealed class Arguments : FastForestArgumentsBase internal const string UserNameValue = "Fast Forest Regression"; internal const string ShortName = "ffr"; - private readonly SchemaShape.Column[] _outputColumns; - /// /// Initializes a new instance of /// @@ -167,24 +165,16 @@ public sealed class Arguments : FastForestArgumentsBase /// A delegate to apply all the advanced arguments to the algorithm. public FastForestRegression(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, true, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, true, advancedSettings) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) - }; } /// /// Initializes a new instance of by using the legacy class. 
/// public FastForestRegression(IHostEnvironment env, Arguments args) - : base(env, args, MakeLabelColumn(args.LabelColumn), true) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), true) { - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) - }; } protected override FastForestRegressionPredictor TrainModelCore(TrainContext context) @@ -224,13 +214,14 @@ protected override Test ConstructTestForTrainingData() protected override RegressionPredictionTransformer MakeTransformer(FastForestRegressionPredictor model, ISchema trainSchema) => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - private static SchemaShape.Column MakeLabelColumn(string labelColumn) + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) + }; } - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; - private abstract class ObjectiveFunctionImplBase : RandomForestObjectiveFunction { private readonly float[] _labels; diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index f788b4feab..9b9fb3db7a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -10,6 +11,8 @@ using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Training; +using System; [assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(LightGbmArguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, @@ -23,10 +26,11 @@ namespace Microsoft.ML.Runtime.LightGBM { + /// public sealed class LightGbmBinaryPredictor : FastTreePredictionWrapper { - public const string LoaderSignature = "LightGBMBinaryExec"; - public const string RegistrationName = "LightGBMBinaryPredictor"; + internal const string LoaderSignature = "LightGBMBinaryExec"; + internal const string RegistrationName = "LightGBMBinaryPredictor"; private static VersionInfo GetVersionInfo() { @@ -79,7 +83,7 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC } /// - public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase> + public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase>, IPredictorWithFeatureWeights> { internal const string UserName = "LightGBM Binary Classifier"; internal const string LoadNameValue = "LightGBMBinary"; @@ -89,7 +93,13 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase PredictionKind.BinaryClassification; public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, args, LoadNameValue) + : base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) + { + } + + public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn, + string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, 
weightColumn, groupIdColumn, advancedSettings) { } @@ -121,6 +131,18 @@ protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, Role if (!Options.ContainsKey("metric")) Options["metric"] = "binary_logloss"; } + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } + + protected override BinaryPredictionTransformer> MakeTransformer(IPredictorWithFeatureWeights model, ISchema trainSchema) + => new BinaryPredictionTransformer>(Host, model, trainSchema, FeatureColumn.Name); } /// diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index ff44139877..3a4fdccffd 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -4,12 +4,15 @@ using System; using System.Globalization; +using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; +using Microsoft.ML.Runtime.Training; [assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(LightGbmArguments), new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) }, @@ -19,7 +22,7 @@ namespace 
Microsoft.ML.Runtime.LightGBM { /// - public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase, OvaPredictor> + public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase, MulticlassPredictionTransformer, OvaPredictor> { public const string Summary = "LightGBM Multi Class Classifier"; public const string LoadNameValue = "LightGBMMulticlass"; @@ -32,7 +35,14 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase PredictionKind.MultiClassClassification; public LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, args, LoadNameValue) + : base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) + { + _numClass = -1; + } + + public LightGbmMulticlassTrainer(IHostEnvironment env, string labelColumn, string featureColumn, + string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { _numClass = -1; } @@ -174,6 +184,23 @@ protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, Role if (!Options.ContainsKey("metric")) Options["metric"] = "multi_error"; } + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); + Contracts.Assert(success); + + var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) + .Concat(MetadataUtils.GetTrainerOutputMetadata())); + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, metadata) + }; + } + + protected override MulticlassPredictionTransformer 
MakeTransformer(OvaPredictor model, ISchema trainSchema) + => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); } /// diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 3fe4628182..5e1c232487 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -9,6 +10,8 @@ using Microsoft.ML.Runtime.FastTree.Internal; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Training; +using System; [assembly: LoadableClass(LightGbmRankingTrainer.UserName, typeof(LightGbmRankingTrainer), typeof(LightGbmArguments), new[] { typeof(SignatureRankerTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, @@ -62,14 +65,14 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static LightGbmRankingPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static LightGbmRankingPredictor Create(IHostEnvironment env, ModelLoadContext ctx) { return new LightGbmRankingPredictor(env, ctx); } } /// - public sealed class LightGbmRankingTrainer : LightGbmTrainerBase + public sealed class LightGbmRankingTrainer : LightGbmTrainerBase, LightGbmRankingPredictor> { public const string UserName = "LightGBM Ranking"; public const string LoadNameValue = "LightGBMRanking"; @@ -78,7 +81,13 @@ public sealed class LightGbmRankingTrainer : LightGbmTrainerBase PredictionKind.Ranking; public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, args, LoadNameValue) + : base(env, 
LoadNameValue, args, TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + { + } + + public LightGbmRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, + string groupIdColumn, string weightColumn = null, Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { } @@ -120,6 +129,17 @@ protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, Role // Only output one ndcg score. Options["eval_at"] = "5"; } + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } + + protected override RankingPredictionTransformer MakeTransformer(LightGbmRankingPredictor model, ISchema trainSchema) + => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); } /// diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 0011a8d8e6..64bfb93da5 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -9,6 +10,8 @@ using Microsoft.ML.Runtime.FastTree.Internal; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Training; +using System; [assembly: LoadableClass(LightGbmRegressorTrainer.Summary, typeof(LightGbmRegressorTrainer), typeof(LightGbmArguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, @@ -28,7 +31,7 @@ public sealed class LightGbmRegressionPredictor : FastTreePredictionWrapper private static VersionInfo GetVersionInfo() { - // REVIEW tfinley(guoke): can we decouple the version from FastTree predictor version ? + // REVIEW: can we decouple the version from FastTree predictor version ? return new VersionInfo( modelSignature: "LGBSIREG", // verWrittenCur: 0x00010001, // Initial @@ -71,17 +74,24 @@ public static LightGbmRegressionPredictor Create(IHostEnvironment env, ModelLoad } } - public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase + /// + public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase, LightGbmRegressionPredictor> { - public const string Summary = "LightGBM Regression"; - public const string LoadNameValue = "LightGBMRegression"; - public const string ShortName = "LightGBMR"; - public const string UserNameValue = "LightGBM Regressor"; + internal const string Summary = "LightGBM Regression"; + internal const string LoadNameValue = "LightGBMRegression"; + internal const string ShortName = "LightGBMR"; + internal const string UserNameValue = "LightGBM Regressor"; public override PredictionKind PredictionKind => PredictionKind.Regression; + public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string featureColumn, + string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) + : base(env, LoadNameValue, 
TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + { + } + public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, args, LoadNameValue) + : base(env, LoadNameValue, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } @@ -112,6 +122,17 @@ protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, Role if (!Options.ContainsKey("metric")) Options["metric"] = "l2"; } + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } + + protected override RegressionPredictionTransformer MakeTransformer(LightGbmRegressionPredictor model, ISchema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); } /// diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index eae632eb24..feb7d021cc 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Training; @@ -24,8 +25,9 @@ internal static class LightGbmShared /// /// Base class for all training with LightGBM. /// - public abstract class LightGbmTrainerBase : TrainerBase - where TPredictor : IPredictorProducing + public abstract class LightGbmTrainerBase : TrainerEstimatorBase + where TTransformer : IPredictionTransformer + where TModel : IPredictorProducing { private sealed class CategoricalMetaData { @@ -44,8 +46,8 @@ private sealed class CategoricalMetaData /// the code is culture agnostic. 
When retrieving key value from this dictionary as string /// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]). /// - private protected readonly Dictionary Options; - private protected readonly IParallel ParallelTraining; + private protected Dictionary Options; + private protected IParallel ParallelTraining; // Store _featureCount and _trainedEnsemble to construct predictor. private protected int FeatureCount; @@ -54,18 +56,35 @@ private sealed class CategoricalMetaData private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true); public override TrainerInfo Info => _info; - private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name) - : base(env, name) + private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, + string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + { + Args = new LightGbmArguments(); + + //apply the advanced args, if the user supplied any + advancedSettings?.Invoke(Args); + Args.LabelColumn = label.Name; + + if (weightColumn != null) + Args.WeightColumn = weightColumn; + + if (groupIdColumn != null) + Args.GroupIdColumn = groupIdColumn; + + InitParallelTraining(); + } + + private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGbmArguments args, SchemaShape.Column label) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) { Host.CheckValue(args, nameof(args)); Args = args; - Options = Args.ToDictionary(Host); - ParallelTraining = Args.ParallelTrainer != null ? 
Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer(); InitParallelTraining(); } - public override TPredictor Train(TrainContext context) + protected override TModel TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); @@ -102,6 +121,9 @@ public override TPredictor Train(TrainContext context) private void InitParallelTraining() { + Options = Args.ToDictionary(Host); + ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(Host) : new SingleTrainer(); + if (ParallelTraining.ParallelType() != "serial" && ParallelTraining.NumMachines() > 1) { Options["tree_learner"] = ParallelTraining.ParallelType(); @@ -849,7 +871,7 @@ private static int GetNumSampleRow(int numRow, int numCol) return ret; } - private protected abstract TPredictor CreatePredictor(); + private protected abstract TModel CreatePredictor(); /// /// This function will be called before training. It will check the label/group and add parameters for specific applications. 
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 4941764c35..e1611c46fb 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -49,13 +49,19 @@ public sealed class Arguments : ArgumentsBase public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), args, MakeFeatureColumn(featureColumn), MakeLabelColumn(labelColumn), MakeWeightColumn(weightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), args, TrainerUtils.MakeR4VecFeature(featureColumn), + TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { _loss = args.LossFunction.CreateComponent(env); Loss = _loss; _args = args; } + public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args) + : this(env, args, args.FeatureColumn, args.LabelColumn) + { + } + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); @@ -70,11 +76,6 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc }; } - public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args) - : this(env, args, args.FeatureColumn, args.LabelColumn) - { - } - protected override void CheckLabelCompatible(SchemaShape.Column labelCol) { Contracts.AssertValue(labelCol); @@ -88,23 +89,6 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) error(); } - private static SchemaShape.Column MakeWeightColumn(string weightColumn) - { - if (weightColumn == null) - return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); - } - - private static 
SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); - } - - private static SchemaShape.Column MakeFeatureColumn(string featureColumn) - { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); - } - /// protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, IRandom rand, IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, Float[] biasReg, Float[] invariants, Float lambdaNInv, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 4f6755e528..74ca15ea44 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -52,41 +52,19 @@ public Arguments() public override PredictionKind PredictionKind => PredictionKind.Regression; - private readonly SchemaShape.Column[] _outputColumns; - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; - public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), args, MakeFeatureColumn(featureColumn), MakeLabelColumn(labelColumn), MakeWeightColumn(weightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), args, TrainerUtils.MakeR4VecFeature(featureColumn), + TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { _loss = args.LossFunction.CreateComponent(env); Loss = _loss; _args = args; - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new 
SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) - }; } public SdcaRegressionTrainer(IHostEnvironment env, Arguments args) : this(env, args, args.FeatureColumn, args.LabelColumn) { } - private static SchemaShape.Column MakeWeightColumn(string weightColumn) - { - if (weightColumn == null) - return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); - } - - private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); - } - - private static SchemaShape.Column MakeFeatureColumn(string featureColumn) - { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); - } protected override LinearRegressionPredictor CreatePredictor(VBuffer[] weights, Float[] bias) { @@ -143,6 +121,14 @@ protected override Float TuneDefaultL2(IChannel ch, int maxIterations, long rowC return l2; } + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } + protected override RegressionPredictionTransformer MakeTransformer(LinearRegressionPredictor model, ISchema trainSchema) => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); } From 471a3683e7f9ca5b245ec89039936ec0f1148e4a Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 19 Sep 2018 22:41:13 -0700 Subject: [PATCH 2/8] Fit and Finish. 
--- src/Microsoft.ML.FastTree/FastTree.cs | 20 +++++++++---------- .../FastTreeArguments.cs | 2 +- .../FastTreeClassification.cs | 8 ++++---- .../FastTreeRegression.cs | 8 ++++---- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 8 ++++---- .../GamClassification.cs | 4 ++-- src/Microsoft.ML.FastTree/GamRegression.cs | 4 ++-- src/Microsoft.ML.FastTree/GamTrainer.cs | 12 +++++------ .../RandomForestClassification.cs | 16 +++++---------- .../RandomForestRegression.cs | 3 +-- .../LightGbmMulticlassTrainer.cs | 5 ++--- .../LightGbmRankingTrainer.cs | 1 - .../LightGbmRegressionTrainer.cs | 1 - .../LightGbmTrainerBase.cs | 5 +++-- 14 files changed, 44 insertions(+), 53 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index fedf642f2e..0628a5d76f 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2,16 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - -using System; -using System.Collections; -using System.Collections.Generic; -using System.ComponentModel; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Text; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -26,6 +16,15 @@ using Microsoft.ML.Runtime.Training; using Microsoft.ML.Runtime.TreePredictor; using Newtonsoft.Json.Linq; +using System; +using System.Collections; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using Float = System.Single; // All of these reviews apply in general to fast tree and random forest implementations. //REVIEW: Decouple train method in Application.cs to have boosting and random forest logic seperate. 
@@ -101,6 +100,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); Args.LabelColumn = label.Name; + Args.FeatureColumn = featureColumn; if (weightColumn != null) Args.WeightColumn = weightColumn; diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 614250c0b8..e1eb8790de 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.Internal.Internallearn; +using System; [assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Arguments))] [assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Arguments))] diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index dccf57d45a..1f7b747218 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -2,9 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Collections.Generic; -using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; @@ -15,6 +12,9 @@ using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using System; +using System.Collections.Generic; +using System.Linq; [assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -112,6 +112,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer : internal const string ShortName = "ftc"; private bool[] _trainSetLabels; + private double _sigmoidParameter; /// /// Initializes a new instance of @@ -129,7 +130,6 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelCol // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; } - private double _sigmoidParameter; /// /// Initializes a new instance of by using the legacy class. diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 2b224ac8ee..604e8a217e 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -2,18 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Linq; -using System.Text; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.FastTree.Internal; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Runtime.Internal.Internallearn; +using System; +using System.Linq; +using System.Text; [assembly: LoadableClass(FastTreeRegressionTrainer.Summary, typeof(FastTreeRegressionTrainer), typeof(FastTreeRegressionTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 9dfc9556a4..2a1385e254 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -2,19 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Linq; -using System.Text; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.FastTree.Internal; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Runtime.Internal.Internallearn; +using System; +using System.Linq; +using System.Text; [assembly: LoadableClass(FastTreeTweedieTrainer.Summary, typeof(FastTreeTweedieTrainer), typeof(FastTreeTweedieTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 9e18b2936b..22aaacf4d2 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Threading.Tasks; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; @@ -14,6 +12,8 @@ using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using System; +using System.Threading.Tasks; [assembly: LoadableClass(BinaryClassificationGamTrainer.Summary, typeof(BinaryClassificationGamTrainer), typeof(BinaryClassificationGamTrainer.Arguments), diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 107e3af55e..4a198f88bc 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -42,10 +42,10 @@ public partial class Arguments : ArgumentsBase public override PredictionKind PredictionKind => PredictionKind.Regression; public RegressionGamTrainer(IHostEnvironment env, Arguments args) - : base(env, args, LoadNameValue, TrainerUtils.MakeR4VecLabel(args.LabelColumn)) { } + : base(env, args, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } public RegressionGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4VecLabel(labelColumn), featureColumn, weightColumn, advancedSettings) { } + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) { } internal override void CheckLabel(RoleMappedData data) { diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 0198e9f3a7..ebb5b83d50 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -2,11 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Threading; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Command; @@ -21,6 +16,11 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; using Timer = Microsoft.ML.Runtime.FastTree.Internal.Timer; [assembly: LoadableClass(typeof(GamPredictorBase.VisualizationCommand), typeof(GamPredictorBase.VisualizationCommand.Arguments), typeof(SignatureCommand), @@ -30,8 +30,8 @@ namespace Microsoft.ML.Runtime.FastTree { - using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo; using AutoResetEvent = System.Threading.AutoResetEvent; + using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo; /// /// Generalized Additive Model Learner. diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index c416e54dc3..0a82ecb602 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; @@ -15,6 +13,8 @@ using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using System; +using System.Linq; [assembly: LoadableClass(FastForestClassification.Summary, typeof(FastForestClassification), typeof(FastForestClassification.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -77,8 +77,7 @@ private static VersionInfo GetVersionInfo() /// public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - internal FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, - string innerArgs) + internal FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } @@ -144,7 +143,7 @@ public sealed class Arguments : FastForestArgumentsBase /// A delegate to apply all the advanced arguments to the algorithm. public FastForestClassification(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { } @@ -152,7 +151,7 @@ public FastForestClassification(IHostEnvironment env, string labelColumn, string /// Initializes a new instance of by using the legacy class. 
/// public FastForestClassification(IHostEnvironment env, Arguments args) - : base(env, args, MakeLabelColumn(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { } @@ -194,11 +193,6 @@ protected override void PrepareLabels(IChannel ch) _trainSetLabels = TrainSet.Ratings.Select(x => x >= 1).ToArray(TrainSet.NumDocs); } - private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - } - protected override Test ConstructTestForTrainingData() { return new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, 1); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index ace1568bc1..b2eac72ea0 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -58,8 +58,7 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010006; - public FastForestRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, - string innerArgs, int samplesCount) + public FastForestRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { _quantileSampleCount = samplesCount; diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 3a4fdccffd..af47913e2c 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -2,9 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Globalization; -using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; @@ -13,6 +10,8 @@ using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.Training; +using System; +using System.Linq; [assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(LightGbmArguments), new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) }, diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 5e1c232487..25e820381e 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -7,7 +7,6 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; -using Microsoft.ML.Runtime.FastTree.Internal; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 64bfb93da5..b20d36fa12 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -7,7 +7,6 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; -using Microsoft.ML.Runtime.FastTree.Internal; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index feb7d021cc..5e3b8bec59 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Training; +using System; +using System.Collections.Generic; namespace Microsoft.ML.Runtime.LightGBM { @@ -65,6 +65,7 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaS //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); Args.LabelColumn = label.Name; + Args.FeatureColumn = featureColumn; if (weightColumn != null) Args.WeightColumn = weightColumn; From 045c79845175ec4989586b7caab4509b2a8259cf Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 20 Sep 2018 13:00:27 -0700 Subject: [PATCH 3/8] Adding tests and commenting code --- .../Training/TrainerUtils.cs | 91 ++++--- .../FastTreeClassification.cs | 3 + src/Microsoft.ML.FastTree/FastTreeRanking.cs | 7 +- .../FastTreeRegression.cs | 2 + src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 5 +- .../GamClassification.cs | 14 ++ src/Microsoft.ML.FastTree/GamRegression.cs | 14 +- .../RandomForestClassification.cs | 8 +- .../RandomForestRegression.cs | 4 +- .../LightGbmBinaryTrainer.cs | 11 + .../LightGbmMulticlassTrainer.cs | 11 + .../LightGbmRankingTrainer.cs | 16 +- .../LightGbmRegressionTrainer.cs | 11 + .../Standard/Online/AveragedPerceptron.cs | 7 +- .../Standard/Online/OnlineGradientDescent.cs | 7 +- .../Standard/SdcaMultiClass.cs | 3 + .../Standard/SdcaRegression.cs | 3 + .../TrainerEstimators/MetalinearEstimators.cs | 2 +- .../TrainerEstimators/TreeEstimators.cs | 226 ++++++++++++++---- 19 files changed, 326 insertions(+), 119 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 2c4bda721f..5322149331 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -7,7 +7,6 @@ 
using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; -using Float = System.Single; namespace Microsoft.ML.Runtime.Training { @@ -238,9 +237,7 @@ private static Func CreatePredicate(RoleMappedData data, CursOpt opt, /// This does not verify that the columns exist, but merely activates the ones that do exist. /// public static IRowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, IRandom rand, IEnumerable extraCols = null) - { - return data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand); - } + => data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand); /// /// Create a row cursor set for the RoleMappedData with the indicated standard columns active. @@ -248,9 +245,7 @@ public static IRowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, /// public static IRowCursor[] CreateRowCursorSet(this RoleMappedData data, out IRowCursorConsolidator consolidator, CursOpt opt, int n, IRandom rand, IEnumerable extraCols = null) - { - return data.Data.GetRowCursorSet(out consolidator, CreatePredicate(data, opt, extraCols), n, rand); - } + => data.Data.GetRowCursorSet(out consolidator, CreatePredicate(data, opt, extraCols), n, rand); private static void AddOpt(HashSet cols, ColumnInfo info) { @@ -260,32 +255,32 @@ private static void AddOpt(HashSet cols, ColumnInfo info) } /// - /// Get the getter for the feature column, assuming it is a vector of Float. + /// Get the getter for the feature column, assuming it is a vector of float. 
/// - public static ValueGetter> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!"); Contracts.CheckParam(schema.Feature != null, nameof(schema), "Missing feature column"); - return row.GetGetter>(schema.Feature.Index); + return row.GetGetter>(schema.Feature.Index); } /// - /// Get the getter for the feature column, assuming it is a vector of Float. + /// Get the getter for the feature column, assuming it is a vector of float. /// - public static ValueGetter> GetFeatureFloatVectorGetter(this IRow row, RoleMappedData data) + public static ValueGetter> GetFeatureFloatVectorGetter(this IRow row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetFeatureFloatVectorGetter(row, data.Schema); } /// - /// Get a getter for the label as a Float. This assumes that the label column type + /// Get a getter for the label as a float. This assumes that the label column type /// has already been validated as appropriate for the kind of training being done. /// - public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); @@ -296,10 +291,10 @@ public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedSc } /// - /// Get a getter for the label as a Float. This assumes that the label column type + /// Get a getter for the label as a float. This assumes that the label column type /// has already been validated as appropriate for the kind of training being done. 
/// - public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedData data) + public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetLabelFloatGetter(row, data.Schema); @@ -308,7 +303,7 @@ public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedDa /// /// Get the getter for the weight column, or null if there is no weight column. /// - public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); @@ -318,10 +313,10 @@ public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMapp var col = schema.Weight; if (col == null) return null; - return RowCursorUtils.GetGetterAs(NumberType.Float, row, col.Index); + return RowCursorUtils.GetGetterAs(NumberType.Float, row, col.Index); } - public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMappedData data) + public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetOptWeightFloatGetter(row, data.Schema); @@ -349,34 +344,38 @@ public static ValueGetter GetOptGroupGetter(this IRow row, RoleMappedData return GetOptGroupGetter(row, data.Schema); } + /// + /// The for the label column for binary classification tasks. + /// + /// name of the label column public static SchemaShape.Column MakeBoolScalarLabel(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - } + => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + /// + /// The for the label column for regression tasks. 
+ /// + /// name of the label column public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); - } + => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + /// + /// The for the label column for regression tasks. + /// + /// name of the label column public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); - } + => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); /// /// The for the feature column. /// /// name of the feature column public static SchemaShape.Column MakeR4VecFeature(string featureColumn) - { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); - } + => new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); /// - /// The for the feature column. + /// The for the weight column. /// - /// name of the feature column + /// name of the weight column public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) { if (weightColumn == null) @@ -628,11 +627,11 @@ public void Signal(CursOpt opt) } /// - /// This supports Weight (Float), Group (ulong), and Id (DvInt8) columns. + /// This supports Weight (float), Group (ulong), and Id (DvInt8) columns. 
/// public class StandardScalarCursor : TrainingCursorBase { - private readonly ValueGetter _getWeight; + private readonly ValueGetter _getWeight; private readonly ValueGetter _getGroup; private readonly ValueGetter _getId; private readonly bool _keepBadWeight; @@ -643,7 +642,7 @@ public class StandardScalarCursor : TrainingCursorBase public long BadWeightCount { get { return _badWeightCount; } } public long BadGroupCount { get { return _badGroupCount; } } - public Float Weight; + public float Weight; public ulong Group; public UInt128 Id; @@ -690,7 +689,7 @@ public override bool Accept() if (_getWeight != null) { _getWeight(ref Weight); - if (!_keepBadWeight && !(0 < Weight && Weight < Float.PositiveInfinity)) + if (!_keepBadWeight && !(0 < Weight && Weight < float.PositiveInfinity)) { _badWeightCount++; return false; @@ -718,9 +717,7 @@ public Factory(RoleMappedData data, CursOpt opt) } protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) - { - return new StandardScalarCursor(input, data, opt, signal); - } + => new StandardScalarCursor(input, data, opt, signal); } } @@ -730,13 +727,13 @@ protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleM /// public class FeatureFloatVectorCursor : StandardScalarCursor { - private readonly ValueGetter> _get; + private readonly ValueGetter> _get; private readonly bool _keepBad; private long _badCount; public long BadFeaturesRowCount { get { return _badCount; } } - public VBuffer Features; + public VBuffer Features; public FeatureFloatVectorCursor(RoleMappedData data, CursOpt opt = CursOpt.Features, IRandom rand = null, params int[] extraCols) @@ -793,18 +790,18 @@ protected override FeatureFloatVectorCursor CreateCursorCore(IRowCursor input, R } /// - /// This derives from the FeatureFloatVectorCursor and adds the Label (Float) column. + /// This derives from the FeatureFloatVectorCursor and adds the Label (float) column. 
/// public class FloatLabelCursor : FeatureFloatVectorCursor { - private readonly ValueGetter _get; + private readonly ValueGetter _get; private readonly bool _keepBad; private long _badCount; public long BadLabelCount { get { return _badCount; } } - public Float Label; + public float Label; public FloatLabelCursor(RoleMappedData data, CursOpt opt = CursOpt.Label, IRandom rand = null, params int[] extraCols) @@ -866,13 +863,13 @@ protected override FloatLabelCursor CreateCursorCore(IRowCursor input, RoleMappe public class MultiClassLabelCursor : FeatureFloatVectorCursor { private readonly int _classCount; - private readonly ValueGetter _get; + private readonly ValueGetter _get; private readonly bool _keepBad; private long _badCount; public long BadLabelCount { get { return _badCount; } } - private Float _raw; + private float _raw; public int Label; public MultiClassLabelCursor(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label, diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 1f7b747218..8553fc1596 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -127,6 +127,9 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelCol string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; } diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 
2b14de03ea..59a83d4a47 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -70,15 +70,18 @@ public sealed partial class FastTreeRankingTrainer /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeU4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn)); + Host.CheckValue(featureColumn, nameof(featureColumn)); + Host.CheckValue(groupIdColumn, nameof(groupIdColumn)); } /// /// Initializes a new instance of by using the legacy class. /// public FastTreeRankingTrainer(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 604e8a217e..24c1a1fe62 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -64,6 +64,8 @@ public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, strin string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs 
b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 2a1385e254..d211377cbb 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -58,9 +58,12 @@ public sealed partial class FastTreeTweedieTrainer /// The name for the column containing the initial weight. /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn, string weightColumn = null, Action advancedSettings = null) + string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Initialize(); } diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 22aaacf4d2..889e5efff9 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -46,15 +46,29 @@ public sealed class Arguments : ArgumentsBase public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; private protected override bool NeedCalibration => true; + /// + /// Initializes a new instance of + /// public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) : base(env, args, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { _sigmoidParameter = 1; } + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
public BinaryClassificationGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + _sigmoidParameter = 1; } diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 4a198f88bc..5d12c86f96 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -44,8 +44,20 @@ public partial class Arguments : ArgumentsBase public RegressionGamTrainer(IHostEnvironment env, Arguments args) : base(env, args, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
public RegressionGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) { } + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) + { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + } internal override void CheckLabel(RoleMappedData data) { diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 0a82ecb602..acb4abba08 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -145,6 +145,8 @@ public FastForestClassification(IHostEnvironment env, string labelColumn, string string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); } /// @@ -205,9 +207,9 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc { return new[] { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) + new SchemaShape.Column(DefaultColumnNames.Score, 
SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) }; } diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index b2eac72ea0..0f24d14b4d 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -166,6 +166,8 @@ public FastForestRegression(IHostEnvironment env, string labelColumn, string fea string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, true, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); } /// @@ -217,7 +219,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc { return new[] { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) }; } diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index 9b9fb3db7a..d22390b842 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -97,10 +97,21 @@ public 
LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) { } + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); } private protected override IPredictorWithFeatureWeights CreatePredictor() diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index af47913e2c..f88aa3ab0e 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -39,10 +39,21 @@ public LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) _numClass = -1; } + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
public LightGbmMulticlassTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); _numClass = -1; } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 25e820381e..4975b2f19b 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -80,14 +80,26 @@ public sealed class LightGbmRankingTrainer : LightGbmTrainerBase PredictionKind.Ranking; public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, LoadNameValue, args, TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + : base(env, LoadNameValue, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
public LightGbmRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckValue(groupIdColumn, nameof(groupIdColumn), "groupIdColumn should not be null."); } protected override void CheckDataValid(IChannel ch, RoleMappedData data) diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index b20d36fa12..bb98ae6c6b 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -83,10 +83,21 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase PredictionKind.Regression; + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { + Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); + Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); } public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index f371322dca..f5a1dc44f5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -53,7 +53,7 @@ public class Arguments : AveragedLinearArguments } public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) - : base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn)) + : base(args, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { _args = args; LossFunction = _args.LossFunction.CreateComponent(env); @@ -94,11 +94,6 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) error(); } - private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - } - protected override LinearBinaryPredictor CreatePredictor() { Contracts.Assert(WeightsScale == 1); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index cf707e783f..768bb1da75 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ 
b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -52,7 +52,7 @@ public Arguments() } public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args) - : base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn)) + : base(args, env, UserNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { LossFunction = args.LossFunction.CreateComponent(env); @@ -92,11 +92,6 @@ protected override LinearRegressionPredictor CreatePredictor() return new LinearRegressionPredictor(Host, ref weights, bias); } - private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); - } - [TlcModule.EntryPoint(Name = "Trainers.OnlineGradientDescentRegressor", Desc = "Train a Online gradient descent perceptron.", UserName = UserNameValue, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index e1611c46fb..4b8b0aa6bf 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -52,6 +52,9 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), args, TrainerUtils.MakeR4VecFeature(featureColumn), TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { + Host.CheckValue(labelColumn, nameof(labelColumn)); + Host.CheckValue(featureColumn, nameof(featureColumn)); + _loss = args.LossFunction.CreateComponent(env); Loss = _loss; _args = args; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 74ca15ea44..15f247fbda 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ 
-56,6 +56,9 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featur : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), args, TrainerUtils.MakeR4VecFeature(featureColumn), TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { + Host.CheckValue(labelColumn, nameof(labelColumn)); + Host.CheckValue(featureColumn, nameof(featureColumn)); + _loss = args.LossFunction.CreateComponent(env); Loss = _loss; _args = args; diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index eb4e845a6c..27828a6a9c 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -125,6 +125,6 @@ private TextLoader.Arguments GetIrisLoaderArgs() new TextLoader.Column("Label", DataKind.Text, 4) } }; - } + } } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index fa6784bd69..4985f7dd55 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -2,8 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.RunTests; using System.Linq; using Xunit; @@ -25,9 +27,62 @@ public TreeEstimators(ITestOutputHelper output) : base(output) [Fact] public void FastTreeBinaryEstimator() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var reader = new TextLoader(env, + var (pipeline, data) = GetBinaryClassificationPipeline(); + + pipeline.Append(new FastTreeBinaryClassificationTrainer(Env, "Label", "Features", advancedSettings: s => { + s.NumTrees = 10; + s.NumThreads = 1; + s.NumLeaves = 5; + })); + + TestEstimatorCore(pipeline, data); + } + + [Fact] + public void LightGBMBinaryEstimator() + { + var (pipeline, data) = GetBinaryClassificationPipeline(); + + pipeline.Append(new LightGbmBinaryTrainer(Env, "Label", "Features", advancedSettings: s => { + s.NumLeaves = 10; + s.NThread = 1; + s.MinDataPerLeaf = 2; + })); + + TestEstimatorCore(pipeline, data); + } + + + [Fact] + public void GAMClassificationEstimator() + { + var (pipeline, data) = GetBinaryClassificationPipeline(); + + pipeline.Append(new BinaryClassificationGamTrainer(Env, "Label", "Features", advancedSettings: s => { + s.GainConfidenceLevel = 0; + s.NumIterations = 15; + })); + + TestEstimatorCore(pipeline, data); + } + + + [Fact] + public void FastForestClassificationEstimator() + { + var (pipeline, data) = GetBinaryClassificationPipeline(); + + pipeline.Append(new FastForestClassification(Env, "Label", "Features", advancedSettings: s => { + s.NumLeaves = 10; + s.NumTrees = 20; + })); + + TestEstimatorCore(pipeline, data); + } + + private (IEstimator, IDataView) GetBinaryClassificationPipeline() + { + var data = new TextLoader(Env, new TextLoader.Arguments() { Separator = "\t", @@ -37,53 +92,62 @@ public void FastTreeBinaryEstimator() new TextLoader.Column("Label", DataKind.BL, 0), new TextLoader.Column("SentimentText", DataKind.Text, 1) 
} - }); + }).Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); + // Pipeline. + var pipeline = new TextTransform(Env, "SentimentText", "Features"); + + return (pipeline, data); + } + + /// + /// FastTreeRankingTrainer TrainerEstimator test + /// + [Fact] + public void FastTreeRankerEstimator() + { + var (pipeline, data) = GetRankingPipeline(); - // Pipeline. - var pipeline = new TextTransform(env, "SentimentText", "Features") - .Append(new FastTreeBinaryClassificationTrainer(env, "Label", "Features", advancedSettings: s => { - s.NumTrees = 10; - s.NumThreads = 1; - s.NumLeaves = 5; - })); + pipeline.Append(new FastTreeRankingTrainer(Env, "Label0", "NumericFeatures", "Group", + advancedSettings: s => { s.NumTrees = 10; })); - TestEstimatorCore(pipeline, data); - } + TestEstimatorCore(pipeline, data); } /// /// FastTreeBinaryClassification TrainerEstimator test /// [Fact] - public void FastTreeRankerEstimator() + public void LightGBMRankerEstimator() + { + var (pipeline, data) = GetRankingPipeline(); + + pipeline.Append(new LightGbmRankingTrainer(Env, "Label0", "NumericFeatures", "Group", + advancedSettings: s => { s.LearningRate = 0.4; })); + + TestEstimatorCore(pipeline, data); + } + + private (IEstimator, IDataView) GetRankingPipeline() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) + var data = new TextLoader(Env, new TextLoader.Arguments { - var reader = new TextLoader(env, new TextLoader.Arguments - { - HasHeader = true, - Separator ="\t", - Column = new[] - { + HasHeader = true, + Separator = "\t", + Column = new[] + { new TextLoader.Column("Label", DataKind.R4, 0), new TextLoader.Column("Workclass", DataKind.Text, 1), new TextLoader.Column("NumericFeatures", DataKind.R4, new [] { new TextLoader.Range(9, 14) }) } - }); - var data = reader.Read(new
MultiFileSource(GetDataPath(TestDatasets.adultRanking.trainFilename))); - + }).Read(new MultiFileSource(GetDataPath(TestDatasets.adultRanking.trainFilename))); - // Pipeline. - var pipeline = new TermEstimator(env, new[]{ + // Pipeline. + var pipeline = new TermEstimator(Env, new[]{ new TermTransform.ColumnInfo("Workclass", "Group"), - new TermTransform.ColumnInfo("Label", "Label0") }) - .Append(new FastTreeRankingTrainer(env, "Label0", "NumericFeatures", "Group", - advancedSettings: s => { s.NumTrees = 10; })); + new TermTransform.ColumnInfo("Label", "Label0") }); - TestEstimatorCore(pipeline, data); - } + return (pipeline, data); } /// @@ -92,10 +156,86 @@ public void FastTreeRankerEstimator() [Fact] public void FastTreeRegressorEstimator() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - // "loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+}" - var reader = new TextLoader(env, + + // Pipeline. + var pipeline = new FastTreeRegressionTrainer(Env, "Label", "Features", advancedSettings: s => { + s.NumTrees = 10; + s.NumThreads = 1; + s.NumLeaves = 5; + }); + + TestEstimatorCore(pipeline, GetRegressionData()); + } + + /// + /// LightGbmRegressorTrainer TrainerEstimator test + /// + [Fact] + public void LightGBMRegressorEstimator() + { + + // Pipeline. + var pipeline = new LightGbmRegressorTrainer(Env, "Label", "Features", advancedSettings: s => { + s.NThread = 1; + s.NormalizeFeatures = NormalizeOption.Warn; + s.CatL2 = 5; + }); + + TestEstimatorCore(pipeline, GetRegressionData()); + } + + + /// + /// RegressionGamTrainer TrainerEstimator test + /// + [Fact] + public void GAMRegressorEstimator() + { + + // Pipeline.
+ var pipeline = new RegressionGamTrainer(Env, "Label", "Features", advancedSettings: s => { + s.EnablePruning = false; + s.NumIterations = 15; + }); + + TestEstimatorCore(pipeline, GetRegressionData()); + } + + /// + /// FastTreeTweedieTrainer TrainerEstimator test + /// + [Fact] + public void TweedieRegressorEstimator() + { + + // Pipeline. + var pipeline = new FastTreeTweedieTrainer(Env, "Label", "Features", advancedSettings: s => { + s.EntropyCoefficient = 0.3; + s.OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent; + }); + + TestEstimatorCore(pipeline, GetRegressionData()); + } + + /// + /// FastForestRegression TrainerEstimator test + /// + [Fact] + public void FastForestRegressorEstimator() + { + + // Pipeline. + var pipeline = new FastForestRegression(Env, "Label", "Features", advancedSettings: s => { + s.BaggingSize = 2; + s.NumTrees = 10; + }); + + TestEstimatorCore(pipeline, GetRegressionData()); + } + + private IDataView GetRegressionData() + { + return new TextLoader(Env, new TextLoader.Arguments() { Separator = ";", @@ -105,19 +245,7 @@ public void FastTreeRegressorEstimator() new TextLoader.Column("Label", DataKind.R4, 11), new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(0, 10) } ) } - }); - - var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename))); - - // Pipeline.
- var pipeline = new FastTreeRegressionTrainer(env, "Label", "Features", advancedSettings: s => { - s.NumTrees = 10; - s.NumThreads = 1; - s.NumLeaves = 5; - }); - - TestEstimatorCore(pipeline, data); - } + }).Read(new MultiFileSource(GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename))); } } } From b157aa82eb500663e46f88537f46284082b95b78 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 20 Sep 2018 14:05:03 -0700 Subject: [PATCH 4/8] post merge work, fit and finish, --- src/Microsoft.ML.FastTree/GamTrainer.cs | 2 +- .../LightGbmTrainerBase.cs | 2 +- .../TrainerEstimators/MetalinearEstimators.cs | 92 +++++------------- .../TrainerEstimators/PriorRandomTests.cs | 40 +++----- .../TrainerEstimators/TrainerEstimators.cs | 92 ++++++++++++++++++ .../TrainerEstimators/TreeEstimators.cs | 97 ++++++------------- 6 files changed, 163 insertions(+), 162 deletions(-) diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 69b9b22f3c..7692ce7613 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -37,7 +37,7 @@ namespace Microsoft.ML.Runtime.FastTree /// Generalized Additive Model Learner. /// public abstract partial class GamTrainerBase : TrainerEstimatorBase - where TTransformer: IPredictionTransformer + where TTransformer: ISingleFeaturePredictionTransformer where TArgs : GamTrainerBase.ArgumentsBase, new() where TPredictor : IPredictorProducing { diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index 5e3b8bec59..8bbdda871f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -26,7 +26,7 @@ internal static class LightGbmShared /// Base class for all training with LightGBM. 
/// public abstract class LightGbmTrainerBase : TrainerEstimatorBase - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictorProducing { private sealed class CategoricalMetaData diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 43f4d01612..3075b908f8 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -20,21 +20,15 @@ public partial class TrainerEstimators [Fact] public void OVAWithExplicitCalibrator() { - var dataPath = GetDataPath(IrisDataPath); + var (pipeline, data) = GetMultiClassPipeline(); + var calibrator = new PavCalibratorTrainer(Env); - using (var env = new ConsoleEnvironment()) - { - var calibrator = new PavCalibratorTrainer(env); + var sdcaTrainer = new LinearClassificationTrainer(Env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"); + pipeline.Append(new Ova(Env, sdcaTrainer, "Label", calibrator: calibrator, maxCalibrationExamples: 990000)) + .Append(new KeyToValueEstimator(Env, "PredictedLabel")); - var data = new TextLoader(env, GetIrisLoaderArgs()).Read(new MultiFileSource(dataPath)); - - var sdcaTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"); - var pipeline = new TermEstimator(env, "Label") - .Append(new Ova(env, sdcaTrainer, "Label", calibrator: calibrator, maxCalibrationExamples: 990000)) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); - - TestEstimatorCore(pipeline, data); - } + TestEstimatorCore(pipeline, data); + Done(); } /// @@ -43,23 +37,15 @@ public void OVAWithExplicitCalibrator() [Fact] public void OVAWithAllConstructorArgs() { - var dataPath = 
GetDataPath(IrisDataPath); - string featNam = "Features"; - string labNam = "Label"; - - using (var env = new ConsoleEnvironment()) - { - var calibrator = new FixedPlattCalibratorTrainer(env, new FixedPlattCalibratorTrainer.Arguments()); + var (pipeline, data) = GetMultiClassPipeline(); + var calibrator = new PlattCalibratorTrainer(Env); + var averagePerceptron = new AveragedPerceptronTrainer(Env, new AveragedPerceptronTrainer.Arguments { FeatureColumn = "Features", LabelColumn = "Label", Shuffle = true, Calibrator = null }); - var data = new TextLoader(env, GetIrisLoaderArgs()).Read(new MultiFileSource(dataPath)); + pipeline.Append(new Ova(Env, averagePerceptron, "Label", true, calibrator: calibrator, 10000, true)) + .Append(new KeyToValueEstimator(Env, "PredictedLabel")); - var averagePerceptron = new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments { FeatureColumn = featNam, LabelColumn = labNam, Shuffle = true, Calibrator = null }); - var pipeline = new TermEstimator(env, labNam) - .Append(new Ova(env, averagePerceptron, labNam, true, calibrator: calibrator, 10000, true)) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); - - TestEstimatorCore(pipeline, data); - } + TestEstimatorCore(pipeline, data); + Done(); } /// @@ -68,19 +54,15 @@ public void OVAWithAllConstructorArgs() [Fact] public void OVAUncalibrated() { - var dataPath = GetDataPath(IrisDataPath); + var (pipeline, data) = GetMultiClassPipeline(); - using (var env = new ConsoleEnvironment()) - { - var data = new TextLoader(env, GetIrisLoaderArgs()).Read(new MultiFileSource(dataPath)); + var sdcaTrainer = new LinearClassificationTrainer(Env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1, Calibrator = null }, "Features", "Label"); - var sdcaTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1, Calibrator = null }, "Features", "Label"); - 
var pipeline = new TermEstimator(env, "Label") - .Append(new Ova(env, sdcaTrainer, useProbabilities: false)) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); + pipeline.Append(new Ova(Env, sdcaTrainer, useProbabilities: false)) + .Append(new KeyToValueEstimator(Env, "PredictedLabel")); - TestEstimatorCore(pipeline, data); - } + TestEstimatorCore(pipeline, data); + Done(); } /// @@ -89,36 +71,14 @@ public void OVAUncalibrated() [Fact(Skip = "The test fails the check for valid input to fit")] public void Pkpd() { - var dataPath = GetDataPath(IrisDataPath); - - using (var env = new ConsoleEnvironment()) - { - var calibrator = new PavCalibratorTrainer(env); + var (pipeline, data) = GetMultiClassPipeline(); - var data = new TextLoader(env, GetIrisLoaderArgs()) - .Read(new MultiFileSource(dataPath)); + var sdcaTrainer = new LinearClassificationTrainer(Env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"); + pipeline.Append(new Pkpd(Env, sdcaTrainer)) + .Append(new KeyToValueEstimator(Env, "PredictedLabel")); - var sdcaTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"); - var pipeline = new TermEstimator(env, "Label") - .Append(new Pkpd(env, sdcaTrainer)) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); - - TestEstimatorCore(pipeline, data); - } - } - - private TextLoader.Arguments GetIrisLoaderArgs() - { - return new TextLoader.Arguments() - { - Separator = "comma", - HasHeader = true, - Column = new[] - { - new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(0, 3) }), - new TextLoader.Column("Label", DataKind.Text, 4) - } - }; + TestEstimatorCore(pipeline, data); + Done(); } } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs index 
b510ab3bb0..da499f91fd 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs @@ -2,44 +2,32 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.RunTests; using Xunit; -using Xunit.Abstractions; -namespace Microsoft.ML.Tests +namespace Microsoft.ML.Tests.TrainerEstimators { - public class SimpleEstimatorTests : TestDataPipeBase + public partial class TrainerEstimators { private IDataView GetBreastCancerDataviewWithTextColumns() { - var dataPath = GetDataPath("breast-cancer.txt"); - var inputFile = new SimpleFileHandle(Env, dataPath, false, false); - return ImportTextData.TextLoader(Env, new ImportTextData.LoaderInput() - { - Arguments = - { - HasHeader = true, - Column = new[] - { - new TextLoader.Column("Label", type: null, 0), - new TextLoader.Column("F1", DataKind.Text, 1), - new TextLoader.Column("F2", DataKind.I4, 2), - new TextLoader.Column("Rest", type: null, new [] { new TextLoader.Range(3, 9) }) - } - }, - - InputFile = inputFile - }).Data; - } - public SimpleEstimatorTests(ITestOutputHelper output) : base(output) - { + return new TextLoader(Env, + new TextLoader.Arguments() + { + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Label", type: null, 0), + new TextLoader.Column("F1", DataKind.Text, 1), + new TextLoader.Column("F2", DataKind.I4, 2), + new TextLoader.Column("Rest", type: null, new [] { new TextLoader.Range(3, 9) }) + } + }).Read(new MultiFileSource(GetDataPath(TestDatasets.breastCancer.trainFilename))); } [Fact] diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs 
b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index dc28fccc97..74fa97807a 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -16,5 +16,97 @@ public partial class TrainerEstimators : TestDataPipeBase public TrainerEstimators(ITestOutputHelper helper) : base(helper) { } + + private (IEstimator, IDataView) GetBinaryClassificationPipeline() + { + var data = new TextLoader(Env, + new TextLoader.Arguments() + { + Separator = "\t", + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Label", DataKind.BL, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) + } + }).Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); + + // Pipeline. + var pipeline = new TextTransform(Env, "SentimentText", "Features"); + + return (pipeline, data); + } + + + private (IEstimator, IDataView) GetRankingPipeline() + { + var data = new TextLoader(Env, new TextLoader.Arguments + { + HasHeader = true, + Separator = "\t", + Column = new[] + { + new TextLoader.Column("Label", DataKind.R4, 0), + new TextLoader.Column("Workclass", DataKind.Text, 1), + new TextLoader.Column("NumericFeatures", DataKind.R4, new [] { new TextLoader.Range(9, 14) }) + } + }).Read(new MultiFileSource(GetDataPath(TestDatasets.adultRanking.trainFilename))); + + // Pipeline. 
+ var pipeline = new TermEstimator(Env, new[]{ + new TermTransform.ColumnInfo("Workclass", "Group"), + new TermTransform.ColumnInfo("Label", "Label0") }); + + return (pipeline, data); + } + + private IDataView GetRegressionPipeline() + { + return new TextLoader(Env, + new TextLoader.Arguments() + { + Separator = ";", + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Label", DataKind.R4, 11), + new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(0, 10) } ) + } + }).Read(new MultiFileSource(GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename))); + } + + private TextLoader.Arguments GetIrisLoaderArgs() + { + return new TextLoader.Arguments() + { + Separator = "comma", + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(0, 3) }), + new TextLoader.Column("Label", DataKind.Text, 4) + } + }; + } + + private (IEstimator, IDataView) GetMultiClassPipeline() + { + + var data = new TextLoader(Env, new TextLoader.Arguments() + { + Separator = "comma", + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(0, 3) }), + new TextLoader.Column("Label", DataKind.Text, 4) + } + }) + .Read(new MultiFileSource(GetDataPath(IrisDataPath))); + + var pipeline = new TermEstimator(Env, "Label"); + + return (pipeline, data); + } } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 4985f7dd55..a16ecd9a4f 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -2,25 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.RunTests; using System.Linq; using Xunit; -using Xunit.Abstractions; namespace Microsoft.ML.Tests.TrainerEstimators { - public partial class TreeEstimators : TestDataPipeBase + public partial class TrainerEstimators : TestDataPipeBase { - - public TreeEstimators(ITestOutputHelper output) : base(output) - { - } - - /// /// FastTreeBinaryClassification TrainerEstimator test /// @@ -36,6 +28,7 @@ public void FastTreeBinaryEstimator() })); TestEstimatorCore(pipeline, data); + Done(); } [Fact] @@ -50,6 +43,7 @@ public void LightGBMBinaryEstimator() })); TestEstimatorCore(pipeline, data); + Done(); } @@ -64,6 +58,7 @@ public void GAMClassificationEstimator() })); TestEstimatorCore(pipeline, data); + Done(); } @@ -78,28 +73,9 @@ public void FastForestClassificationEstimator() })); TestEstimatorCore(pipeline, data); + Done(); } - private (IEstimator, IDataView) GetBinaryClassificationPipeline() - { - var data = new TextLoader(Env, - new TextLoader.Arguments() - { - Separator = "\t", - HasHeader = true, - Column = new[] - { - new TextLoader.Column("Label", DataKind.BL, 0), - new TextLoader.Column("SentimentText", DataKind.Text, 1) - } - }).Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - - // Pipeline. 
- var pipeline = new TextTransform(Env, "SentimentText", "Features"); - - return (pipeline, data); - } - /// /// FastTreeBinaryClassification TrainerEstimator test /// @@ -112,6 +88,7 @@ public void FastTreeRankerEstimator() advancedSettings: s => { s.NumTrees = 10; })); TestEstimatorCore(pipeline, data); + Done(); } /// @@ -126,28 +103,7 @@ public void LightGBMRankerEstimator() advancedSettings: s => { s.LearningRate = 0.4; })); TestEstimatorCore(pipeline, data); - } - - private (IEstimator, IDataView) GetRankingPipeline() - { - var data = new TextLoader(Env, new TextLoader.Arguments - { - HasHeader = true, - Separator = "\t", - Column = new[] - { - new TextLoader.Column("Label", DataKind.R4, 0), - new TextLoader.Column("Workclass", DataKind.Text, 1), - new TextLoader.Column("NumericFeatures", DataKind.R4, new [] { new TextLoader.Range(9, 14) }) - } - }).Read(new MultiFileSource(GetDataPath(TestDatasets.adultRanking.trainFilename))); - - // Pipeline. - var pipeline = new TermEstimator(Env, new[]{ - new TermTransform.ColumnInfo("Workclass", "Group"), - new TermTransform.ColumnInfo("Label", "Label0") }); - - return (pipeline, data); + Done(); } /// @@ -164,7 +120,8 @@ public void FastTreeRegressorEstimator() s.NumLeaves = 5; }); - TestEstimatorCore(pipeline, GetRegressionData()); + TestEstimatorCore(pipeline, GetRegressionPipeline()); + Done(); } /// @@ -181,7 +138,8 @@ public void LightGBMRegressorEstimator() s.CatL2 = 5; }); - TestEstimatorCore(pipeline, GetRegressionData()); + TestEstimatorCore(pipeline, GetRegressionPipeline()); + Done(); } @@ -198,7 +156,8 @@ public void GAMRegressorEstimator() s.NumIterations = 15; }); - TestEstimatorCore(pipeline, GetRegressionData()); + TestEstimatorCore(pipeline, GetRegressionPipeline()); + Done(); } /// @@ -214,7 +173,8 @@ public void TweedieRegressorEstimator() s.OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent; }); - TestEstimatorCore(pipeline, GetRegressionData()); + 
TestEstimatorCore(pipeline, GetRegressionPipeline()); + Done(); } /// @@ -230,22 +190,23 @@ public void FastForestRegressorEstimator() s.NumTrees = 10; }); - TestEstimatorCore(pipeline, GetRegressionData()); + TestEstimatorCore(pipeline, GetRegressionPipeline()); + Done(); } - private IDataView GetRegressionData() + /// + /// LightGbmMulticlass TrainerEstimator test + /// + [Fact] + public void LightGbmMultiClassEstimator() { - return new TextLoader(Env, - new TextLoader.Arguments() - { - Separator = ";", - HasHeader = true, - Column = new[] - { - new TextLoader.Column("Label", DataKind.R4, 11), - new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(0, 10) } ) - } - }).Read(new MultiFileSource(GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename))); + var (pipeline, data) = GetMultiClassPipeline(); + + pipeline.Append(new LightGbmMulticlassTrainer(Env, "Label", "Features", advancedSettings: s => { s.LearningRate = 0.4; })) + .Append(new KeyToValueEstimator(Env, "PredictedLabel")); + + TestEstimatorCore(pipeline, data); + Done(); } } } From 45629c5c8126123eded05e2ecc9635360a7a87a8 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 20 Sep 2018 14:22:29 -0700 Subject: [PATCH 5/8] formatting --- test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs index da499f91fd..e729915d33 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs @@ -15,7 +15,6 @@ public partial class TrainerEstimators { private IDataView GetBreastCancerDataviewWithTextColumns() { - return new TextLoader(Env, new TextLoader.Arguments() { From 227788a9bd60c692127b2ed48c26b44c40124ffd Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 20 Sep 2018 14:36:51 -0700 Subject: [PATCH 6/8]
CheckNonEmpty, rather than CheckValue for null checks. --- src/Microsoft.ML.FastTree/FastTreeClassification.cs | 4 ++-- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 6 +++--- src/Microsoft.ML.FastTree/FastTreeRegression.cs | 4 ++-- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 4 ++-- src/Microsoft.ML.FastTree/GamClassification.cs | 4 ++-- src/Microsoft.ML.FastTree/GamRegression.cs | 4 ++-- src/Microsoft.ML.FastTree/RandomForestClassification.cs | 4 ++-- src/Microsoft.ML.FastTree/RandomForestRegression.cs | 4 ++-- src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs | 4 ++-- src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs | 4 ++-- src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs | 6 +++--- src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs | 4 ++-- 12 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 8553fc1596..8ed4305935 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -127,8 +127,8 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelCol string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 59a83d4a47..1d251cbddf 100644 --- 
a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -72,9 +72,9 @@ public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string f string weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn)); - Host.CheckValue(featureColumn, nameof(featureColumn)); - Host.CheckValue(groupIdColumn, nameof(groupIdColumn)); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 24c1a1fe62..315dd55fc9 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -64,8 +64,8 @@ public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, strin string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index d211377cbb..9ec436a6df 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -61,8 +61,8 @@ public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string f string groupIdColumn = null, string 
weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Initialize(); } diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 889e5efff9..52e89c0246 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -66,8 +66,8 @@ public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) public BinaryClassificationGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); _sigmoidParameter = 1; } diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 5d12c86f96..0e6a8745ab 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -55,8 +55,8 @@ public RegressionGamTrainer(IHostEnvironment env, Arguments args) public RegressionGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, 
weightColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } internal override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index acb4abba08..7ce9dde504 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -145,8 +145,8 @@ public FastForestClassification(IHostEnvironment env, string labelColumn, string string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } /// diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 0f24d14b4d..d066f0299b 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -166,8 +166,8 @@ public FastForestRegression(IHostEnvironment env, string labelColumn, string fea string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, true, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - 
Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } /// diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index d22390b842..a7abb77e7f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -110,8 +110,8 @@ public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string fe string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } private protected override IPredictorWithFeatureWeights CreatePredictor() diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index f88aa3ab0e..7a8fd940d9 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -52,8 +52,8 @@ public LightGbmMulticlassTrainer(IHostEnvironment env, string labelColumn, strin string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); + 
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); _numClass = -1; } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 4975b2f19b..b76f28c096 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -97,9 +97,9 @@ public LightGbmRankingTrainer(IHostEnvironment env, string labelColumn, string f string groupIdColumn, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), "featureColumn should not be null."); - Host.CheckValue(groupIdColumn, nameof(groupIdColumn), "groupIdColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } protected override void CheckDataValid(IChannel ch, RoleMappedData data) diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index bb98ae6c6b..4a67fa8222 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -96,8 +96,8 @@ public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { - Host.CheckValue(labelColumn, nameof(labelColumn), "labelColumn should not be null."); - Host.CheckValue(featureColumn, nameof(featureColumn), 
"featureColumn should not be null."); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) From c5b46e99c4a5068d411aac3adc19f365b8138acf Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 21 Sep 2018 08:59:50 -0700 Subject: [PATCH 7/8] adjusting ctor visibility --- src/Microsoft.ML.FastTree/FastTree.cs | 4 ++-- src/Microsoft.ML.FastTree/FastTreeClassification.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeRegression.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 2 +- src/Microsoft.ML.FastTree/GamClassification.cs | 2 +- src/Microsoft.ML.FastTree/GamRegression.cs | 2 +- src/Microsoft.ML.FastTree/RandomForestClassification.cs | 2 +- src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs | 2 +- src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs | 2 +- src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs | 2 +- src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs | 2 +- 12 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 2ef7b0d24c..8302bb1341 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -89,7 +89,7 @@ public abstract class FastTreeTrainerBase : private protected virtual bool NeedCalibration => false; /// - /// Constructor to use when instantiating the classing deriving from here through the API. + /// Constructor to use when instantiating the classes deriving from here through the API. 
/// private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) @@ -120,7 +120,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l } /// - /// Legacy constructor that is used when invoking the classsing deriving from this, through maml. + /// Legacy constructor that is used when invoking the classes deriving from this, through maml. /// private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.WeightColumn)) diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 8ed4305935..8fbc4098a7 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -137,7 +137,7 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelCol /// /// Initializes a new instance of by using the legacy class. /// - public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args) + internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args) : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 1d251cbddf..490d6aca25 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -80,7 +80,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string f /// /// Initializes a new instance of by using the legacy class. 
/// - public FastTreeRankingTrainer(IHostEnvironment env, Arguments args) + internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args) : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 315dd55fc9..d0b6a2b6c6 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -71,7 +71,7 @@ public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, strin /// /// Initializes a new instance of by using the legacy class. /// - public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) + internal FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 9ec436a6df..4f85e31e1d 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -70,7 +70,7 @@ public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string f /// /// Initializes a new instance of by using the legacy class. 
/// - public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) + internal FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { Initialize(); diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 52e89c0246..b87acc0e53 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -49,7 +49,7 @@ public sealed class Arguments : ArgumentsBase /// /// Initializes a new instance of /// - public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) + internal BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) : base(env, args, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { _sigmoidParameter = 1; diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 0e6a8745ab..db6fffe8f2 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -41,7 +41,7 @@ public partial class Arguments : ArgumentsBase public override PredictionKind PredictionKind => PredictionKind.Regression; - public RegressionGamTrainer(IHostEnvironment env, Arguments args) + internal RegressionGamTrainer(IHostEnvironment env, Arguments args) : base(env, args, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } /// diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 7ce9dde504..f21376b07f 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -77,7 +77,7 @@ private static VersionInfo GetVersionInfo() /// public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - internal FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int 
featureCount, string innerArgs) + public FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index a7abb77e7f..6846164378 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -92,7 +92,7 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase PredictionKind.BinaryClassification; - public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) + internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) : base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 7a8fd940d9..262dc1caf9 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -33,7 +33,7 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase PredictionKind.MultiClassClassification; - public LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) + internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) : base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { _numClass = -1; diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index b76f28c096..936728a7a4 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -79,7 +79,7 @@ public sealed class LightGbmRankingTrainer : LightGbmTrainerBase PredictionKind.Ranking; - public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) 
+ internal LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) : base(env, LoadNameValue, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 4a67fa8222..a43352bba8 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -100,7 +100,7 @@ public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } - public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) + internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) : base(env, LoadNameValue, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } From 5588944729ee99eb86c62a90726651bf560c102a Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 21 Sep 2018 09:09:30 -0700 Subject: [PATCH 8/8] adjusting constructor --- .../ScenariosWithDirectInstantiation/TensorflowTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 71f5f95f33..6ee86513bc 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -303,7 +303,7 @@ public void TensorFlowTransformMNISTConvTest() trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); trans = new ConcatTransform(env, "Features", "Softmax", "dense/Relu").Transform(trans); - var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments()); + var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); var cached = new CacheDataView(env, trans, 
prefetch: null); var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");