From 96fd88ecb83a6b5c7de07d24b6af9b6134298912 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 6 Sep 2018 11:39:01 -0700 Subject: [PATCH 1/8] Converting AveragePerceptron, OGD and Linear SVM to estimators. --- .../Training/TrainerEstimatorBase.cs | 6 ++++ .../Standard/LinearPredictor.cs | 4 +-- .../Standard/Online/AveragedLinear.cs | 10 ++++--- .../Standard/Online/AveragedPerceptron.cs | 28 +++++++++++++++---- .../Standard/Online/LinearSvm.cs | 23 +++++++++++++-- .../Standard/Online/OnlineGradientDescent.cs | 24 +++++++++++++--- .../Standard/Online/OnlineLinear.cs | 28 +++++++++++++++---- 7 files changed, 97 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 62625604ab..1bdb4a543b 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -17,6 +17,12 @@ public abstract class TrainerEstimatorBase : ITrainerEstim where TTransformer : IPredictionTransformer where TModel : IPredictor { + /// + /// A standard string to use in errors or warnings by subclasses, to communicate the idea that no valid + /// instances were able to be found. + /// + protected const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features."; + /// /// The feature column that the trainer expects. /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 82069d413d..309b114e5b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -487,9 +487,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public override PredictionKind PredictionKind { - get { return PredictionKind.BinaryClassification; } - } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; /// /// Combine a bunch of models into one by averaging parameters diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index 7a3ac55edf..99191bd0d1 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Numeric; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Core.Data; // TODO: Check if it works properly if Averaged is set to false @@ -52,9 +53,10 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments public Float AveragedTolerance = (Float)1e-2; } - public abstract class AveragedLinearTrainer : OnlineLinearTrainer + public abstract class AveragedLinearTrainer : OnlineLinearTrainer where TArguments : AveragedLinearArguments - where TPredictor : IPredictorProducing + where TTransformer : IPredictionTransformer + where TModel : IPredictor { protected IScalarOutputLoss LossFunction; @@ -74,8 +76,8 @@ public abstract class AveragedLinearTrainer : OnlineLine // We'll keep a few things global to prevent garbage collection protected int NumNoUpdates; - protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name) - : base(args, env, name) + protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label) + : base(args, env, name, label) { Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive); Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index c7c2d1d627..a03655930a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.Numeric; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Core.Data; [assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -29,8 +30,7 @@ namespace Microsoft.ML.Runtime.Learners // - Feature normalization. By default, rescaling between min and max values for every feature // - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration. /// - public sealed class AveragedPerceptronTrainer : - AveragedLinearTrainer + public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer , LinearBinaryPredictor> { public const string LoadNameValue = "AveragedPerceptron"; internal const string UserNameValue = "Averaged Perceptron"; @@ -49,15 +49,23 @@ public class Arguments : AveragedLinearArguments public int MaxCalibrationExamples = 1000000; } - protected override bool NeedCalibration => true; - public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) - : base(args, env, UserNameValue) + : base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn)) { LossFunction = Args.LossFunction.CreateComponent(env); + + OutputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) + }; } - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + + protected override bool NeedCalibration => true; + + protected override SchemaShape.Column[] OutputColumns { get; } protected override void CheckLabel(RoleMappedData data) { @@ -65,6 +73,11 @@ protected override void CheckLabel(RoleMappedData data) data.CheckBinaryLabel(); } + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + } + protected override LinearBinaryPredictor CreatePredictor() { Contracts.Assert(WeightsScale == 1); @@ -87,6 +100,9 @@ protected override LinearBinaryPredictor CreatePredictor() return new LinearBinaryPredictor(Host, ref weights, bias); } + protected override BinaryPredictionTransformer MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema) + => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + [TlcModule.EntryPoint(Name = "Trainers.AveragedPerceptronBinaryClassifier", Desc = Summary, UserName = UserNameValue, diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index d435539e95..aa6431662f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -26,12 +26,13 @@ namespace Microsoft.ML.Runtime.Learners { + using Microsoft.ML.Core.Data; using TPredictor = LinearBinaryPredictor; /// /// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf /// - public sealed class LinearSvm : OnlineLinearTrainer + public sealed class LinearSvm : OnlineLinearTrainer, LinearBinaryPredictor> { public const string LoadNameValue = "LinearSVM"; public const string ShortName = "svm"; @@ -83,13 +84,21 @@ public sealed class Arguments : OnlineLinearArguments protected override bool NeedCalibration => true; public LinearSvm(IHostEnvironment env, Arguments args) - : base(args, env, UserNameValue) + : base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn)) { Contracts.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), UserErrorPositive); Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive); + + OutputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) + }; } - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + + protected override SchemaShape.Column[] OutputColumns { get; } protected override void CheckLabel(RoleMappedData data) { @@ -105,6 +114,11 @@ protected override Float Margin(ref VBuffer feat) return Bias + VectorUtils.DotProduct(ref feat, ref Weights) * WeightsScale; } + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + } + protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor) { base.InitCore(ch, numFeatures, predictor); @@ -237,5 +251,8 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples); } + + protected override BinaryPredictionTransformer MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema) + => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index 78dc5ea3b2..c3e22f60da 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -25,10 +25,11 @@ namespace Microsoft.ML.Runtime.Learners { + using Microsoft.ML.Core.Data; using TPredictor = LinearRegressionPredictor; /// - public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer + public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer, LinearRegressionPredictor> { internal const string LoadNameValue = "OnlineGradientDescent"; internal const string UserNameValue = "Stochastic Gradient Descent (Regression)"; @@ -53,19 +54,26 @@ public Arguments() } public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args) - : base(args, env, UserNameValue) + : base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn)) { LossFunction = args.LossFunction.CreateComponent(env); + + OutputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false) + }; } - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + public override PredictionKind PredictionKind => PredictionKind.Regression; + + protected override SchemaShape.Column[] OutputColumns { get; } protected override void CheckLabel(RoleMappedData data) { data.CheckRegressionLabel(); } - protected override TPredictor CreatePredictor() + protected override LinearRegressionPredictor CreatePredictor() { Contracts.Assert(WeightsScale == 1); VBuffer weights = default(VBuffer); @@ -85,6 +93,11 @@ protected override TPredictor CreatePredictor() return new LinearRegressionPredictor(Host, ref weights, bias); } + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, true); + } + [TlcModule.EntryPoint(Name = "Trainers.OnlineGradientDescentRegressor", Desc = "Train a Online gradient descent perceptron.", UserName = UserNameValue, @@ -102,5 +115,8 @@ public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment en () => new OnlineGradientDescentTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } + + protected override RegressionPredictionTransformer MakeTransformer(TPredictor model, ISchema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 4976bf20d3..7eab10e3c5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -4,6 +4,7 @@ using System; using System.Globalization; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -41,11 +42,13 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel public int StreamingCacheSize = 1000000; } - public abstract class OnlineLinearTrainer : TrainerBase + public abstract class OnlineLinearTrainer : TrainerEstimatorBase + where TTransformer : IPredictionTransformer + where TModel : IPredictor where TArguments : OnlineLinearArguments - where TPredictor : IPredictorProducing { protected readonly TArguments Args; + protected readonly string Name; // Initialized by InitCore protected int NumFeatures; @@ -74,8 +77,8 @@ public abstract class OnlineLinearTrainer : TrainerBase< protected virtual bool NeedCalibration => false; - protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name) - : base(env, name) + protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.InitialWeights)) { Contracts.CheckValue(args, nameof(args)); Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive); @@ -83,6 +86,7 @@ protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive); Args = args; + Name = name; // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue. Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true); } @@ -111,7 +115,7 @@ protected void ScaleWeightsIfNeeded() ScaleWeights(); } - public override TPredictor Train(TrainContext context) + protected override TModel TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var initPredictor = context.InitialPredictor; @@ -148,10 +152,22 @@ public override TPredictor Train(TrainContext context) return CreatePredictor(); } - protected abstract TPredictor CreatePredictor(); + protected abstract TModel CreatePredictor(); protected abstract void CheckLabel(RoleMappedData data); + private static SchemaShape.Column MakeWeightColumn(string weightColumn) + { + if (weightColumn == null) + return null; + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } + + private static SchemaShape.Column MakeFeatureColumn(string featureColumn) + { + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + } + protected virtual void TrainCore(IChannel ch, RoleMappedData data) { bool shuffle = Args.Shuffle; From d47014e3ef32e70e1e4106a65764fc98db75ef4e Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 6 Sep 2018 11:43:03 -0700 Subject: [PATCH 2/8] Added Propability to the output columns of binary --- src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs | 2 +- .../Standard/Online/AveragedPerceptron.cs | 1 + src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 309b114e5b..2a5d73705f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -487,7 +487,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; /// /// Combine a bunch of models into one by averaging parameters diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index a03655930a..350386f3b9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -57,6 +57,7 @@ public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) OutputColumns = new[] { new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) }; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index aa6431662f..4fde5fdb1f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -92,6 +92,7 @@ public LinearSvm(IHostEnvironment env, Arguments args) OutputColumns = new[] { new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) }; } From 23da0eb70c96dff1fe7aacd34a3d4e4937303da0 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 6 Sep 2018 12:56:27 -0700 Subject: [PATCH 3/8] fixing MakeLabel for OGD --- .../Standard/Online/OnlineGradientDescent.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index c3e22f60da..864eaa1673 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -95,7 +95,7 @@ protected override LinearRegressionPredictor CreatePredictor() private static SchemaShape.Column MakeLabelColumn(string labelColumn) { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, true); + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } [TlcModule.EntryPoint(Name = "Trainers.OnlineGradientDescentRegressor", From 101c2e8d90ff2dfe0a11ea4b7c3e42426984d192 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 6 Sep 2018 14:03:05 -0700 Subject: [PATCH 4/8] Fit should take an optional InitialPredictor for the OnlineTrainers. Updated test. --- src/Microsoft.ML.Core/Data/IEstimator.cs | 2 +- .../DataLoadSave/EstimatorChain.cs | 2 +- .../DataLoadSave/EstimatorExtensions.cs | 8 ++--- .../DataLoadSave/TrivialEstimator.cs | 4 +-- .../Training/TrainerEstimatorBase.cs | 2 +- .../Transforms/CopyColumnsTransform.cs | 2 +- .../Transforms/Normalizer.cs | 2 +- .../Transforms/TermEstimator.cs | 2 +- .../Standard/LinearClassificationTrainer.cs | 2 +- .../StaticPipeFakes.cs | 2 +- .../Estimators/TrainWithInitialPredictor.cs | 4 +-- .../Scenarios/Api/Estimators/Wrappers.cs | 33 +++---------------- 12 files changed, 21 insertions(+), 44 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 54b0c64cd9..0b88564b51 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -237,7 +237,7 @@ public interface IEstimator /// /// Train and return a transformer. /// - TTransformer Fit(IDataView input); + TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null); /// /// Schema propagation for estimators. diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs index 727dfa9e0b..c308118bbc 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs @@ -42,7 +42,7 @@ public EstimatorChain() LastEstimator = null; } - public TransformerChain Fit(IDataView input) + public TransformerChain Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) { // REVIEW: before fitting, run schema propagation. // Currently, it throws. diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs index ecbc28ebdf..19d173bb5f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs @@ -90,7 +90,7 @@ public DelegateEstimator(IEstimator estimator, Action - /// Given an estimator, return a wrapping object that will call a delegate once + /// Given an estimator, return a wrapping object that will call a delegate once /// is called. It is often important for an estimator to return information about what was fit, which is why the - /// method returns a specifically typed object, rather than just a general + /// method returns a specifically typed object, rather than just a general /// . However, at the same time, are often formed into pipelines /// with many objects, so we may need to build a chain of estimators via where the /// estimator for which we want to get the transformer is buried somewhere in this chain. For that scenario, we can through this @@ -113,7 +113,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) /// The type of returned by /// The estimator to wrap /// The delegate that is called with the resulting instances once - /// is called. Because + /// is called. Because /// may be called multiple times, this delegate may also be called multiple times. /// A wrapping estimator that calls the indicated delegate whenever fit is called public static IEstimator WithOnFitDelegate(this IEstimator estimator, Action onFit) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs index 29c081ac35..4f228317b6 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs @@ -8,7 +8,7 @@ namespace Microsoft.ML.Runtime.Data { /// /// The trivial implementation of that already has - /// the transformer and returns it on every call to . + /// the transformer and returns it on every call to . /// /// Concrete implementations still have to provide the schema propagation mechanism, since /// there is no easy way to infer it from the transformer. @@ -28,7 +28,7 @@ protected TrivialEstimator(IHost host, TTransformer transformer) Transformer = transformer; } - public TTransformer Fit(IDataView input) => Transformer; + public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => Transformer; public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); } diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 1bdb4a543b..64b971b788 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -67,7 +67,7 @@ public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape. WeightColumn = weight; } - public TTransformer Fit(IDataView input) => TrainTransformer(input); + public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(input, validationData, initialPredictor); public SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 9a287dd386..180f3c66b7 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -54,7 +54,7 @@ public CopyColumnsEstimator(IHostEnvironment env, params (string source, string _columns = columns; } - public CopyColumnsTransform Fit(IDataView input) + public CopyColumnsTransform Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) { // Invoke schema validation. GetOutputSchema(SchemaShape.Create(input.Schema)); diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 5d888cd433..c2e2850a3a 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -170,7 +170,7 @@ public Normalizer(IHostEnvironment env, params ColumnBase[] columns) _columns = columns.ToArray(); } - public NormalizerTransformer Fit(IDataView input) + public NormalizerTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) { _host.CheckValue(input, nameof(input)); return NormalizerTransformer.Train(_host, input, _columns); diff --git a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs index 33d23b9e8c..2aaf7b9e61 100644 --- a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs @@ -23,7 +23,7 @@ public TermEstimator(IHostEnvironment env, params TermTransform.ColumnInfo[] col _columns = columns; } - public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns); + public TermTransform Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => new TermTransform(_host, input, _columns); public SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 52d2d3aef0..9e35d94975 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1478,7 +1478,7 @@ protected override void CheckLabel(RoleMappedData examples, out int weightSetCou protected override BinaryPredictionTransformer MakeTransformer(TScalarPredictor model, ISchema trainSchema) => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public BinaryPredictionTransformer Train(IDataView trainData, IDataView validationData) => TrainTransformer(trainData, validationData); + public BinaryPredictionTransformer Train(IDataView trainData, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(trainData, validationData, initialPredictor); } public sealed class StochasticGradientDescentClassificationTrainer : diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs index 8313d94e7f..bee9d83fa2 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs @@ -51,7 +51,7 @@ public override IEstimator Reconcile( private sealed class FakeEstimator : IEstimator { - public ITransformer Fit(IDataView input) => throw new NotImplementedException(); + public ITransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => throw new NotImplementedException(); public SchemaShape GetOutputSchema(SchemaShape inputSchema) => throw new NotImplementedException(); } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs index a08c0f4b65..724a8838c8 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs @@ -37,8 +37,8 @@ public void New_TrainWithInitialPredictor() var firstModel = trainer.Fit(trainData); // Train the second predictor on the same data. - var secondTrainer = new MyAveragedPerceptron(env, new AveragedPerceptronTrainer.Arguments(), "Features", "Label"); - var finalModel = secondTrainer.Train(trainData, firstModel.Model); + var secondTrainer = new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments()); + var finalModel = secondTrainer.Fit(trainData, initialPredictor: firstModel.Model); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index ed4283f233..a34bac4ecf 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -245,7 +245,7 @@ protected TrainerBase(IHostEnvironment env, TrainerInfo trainerInfo, string feat Info = trainerInfo; } - public TTransformer Fit(IDataView input) + public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) { return TrainTransformer(input); } @@ -311,7 +311,7 @@ public MyTextTransform(IHostEnvironment env, TextTransform.Arguments args) _args = args; } - public TransformWrapper Fit(IDataView input) + public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) { var xf = TextTransform.Create(_env, _args, input); var empty = new EmptyDataView(_env, input.Schema); @@ -338,7 +338,7 @@ public MyConcatTransform(IHostEnvironment env, string name, params string[] sour _source = source; } - public TransformWrapper Fit(IDataView input) + public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) { var xf = new ConcatTransform(_env, input, _name, _source); var empty = new EmptyDataView(_env, input.Schema); @@ -365,7 +365,7 @@ public MyKeyToValueTransform(IHostEnvironment env, string name, string source = _source = source; } - public TransformWrapper Fit(IDataView input) + public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) { var xf = new KeyToValueTransform(_env, input, _name, _source); var empty = new EmptyDataView(_env, input.Schema); @@ -379,29 +379,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) } } - public sealed class MyAveragedPerceptron : TrainerBase, IPredictor> - { - private readonly AveragedPerceptronTrainer _trainer; - - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - - public MyAveragedPerceptron(IHostEnvironment env, AveragedPerceptronTrainer.Arguments args, string featureCol, string labelCol) - : base(env, new TrainerInfo(caching: false), featureCol, labelCol) - { - _trainer = new AveragedPerceptronTrainer(env, args); - } - - protected override IPredictor TrainCore(TrainContext trainContext) => _trainer.Train(trainContext); - - public ITransformer Train(IDataView trainData, IPredictor initialPredictor) - { - return TrainTransformer(trainData, initPredictor: initialPredictor); - } - - protected override BinaryScorerWrapper MakeScorer(IPredictor predictor, RoleMappedData data) - => new BinaryScorerWrapper(_env, predictor, data.Data.Schema, _featureCol, new BinaryClassifierScorer.Arguments()); - } - public sealed class MyPredictionEngine where TSrc : class where TDst : class, new() @@ -533,7 +510,7 @@ public MyLambdaTransform(IHostEnvironment env, Action action) _action = action; } - public TransformWrapper Fit(IDataView input) + public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) { var xf = LambdaTransform.CreateMap(_env, input, _action); var empty = new EmptyDataView(_env, input.Schema); From 58dbbacd58f1c79f541b6941f18ca9b2b0a4200e Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 6 Sep 2018 14:36:01 -0700 Subject: [PATCH 5/8] Removing the arguments from the generics definition --- .../Standard/Online/AveragedLinear.cs | 10 ++++++---- .../Standard/Online/AveragedPerceptron.cs | 5 ++++- .../Standard/Online/LinearSvm.cs | 6 +++++- .../Standard/Online/OnlineGradientDescent.cs | 2 +- .../Standard/Online/OnlineLinear.cs | 7 +++---- 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index 99191bd0d1..98773cdc91 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -53,13 +53,12 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments public Float AveragedTolerance = (Float)1e-2; } - public abstract class AveragedLinearTrainer : OnlineLinearTrainer - where TArguments : AveragedLinearArguments + public abstract class AveragedLinearTrainer : OnlineLinearTrainer where TTransformer : IPredictionTransformer where TModel : IPredictor { + protected readonly new AveragedLinearArguments Args; protected IScalarOutputLoss LossFunction; - protected Float Gain; // For computing averaged weights and bias (if needed) @@ -76,15 +75,18 @@ public abstract class AveragedLinearTrainer : // We'll keep a few things global to prevent garbage collection protected int NumNoUpdates; - protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label) + protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label) : base(args, env, name, label) { Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive); Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive); + // Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible. Contracts.CheckUserArg(0 <= args.L2RegularizerWeight && args.L2RegularizerWeight < 0.5, nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)"); Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative); Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative); + + Args = args; } protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index 350386f3b9..3a8153dbe2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -30,13 +30,15 @@ namespace Microsoft.ML.Runtime.Learners // - Feature normalization. By default, rescaling between min and max values for every feature // - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration. /// - public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer , LinearBinaryPredictor> + public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer , LinearBinaryPredictor> { public const string LoadNameValue = "AveragedPerceptron"; internal const string UserNameValue = "Averaged Perceptron"; internal const string ShortName = "ap"; internal const string Summary = "Averaged Perceptron Binary Classifier."; + internal new readonly Arguments Args; + public class Arguments : AveragedLinearArguments { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] @@ -52,6 +54,7 @@ public class Arguments : AveragedLinearArguments public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) : base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn)) { + Args = args; LossFunction = Args.LossFunction.CreateComponent(env); OutputColumns = new[] diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index 4fde5fdb1f..7992f1f7e1 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -32,7 +32,7 @@ namespace Microsoft.ML.Runtime.Learners /// /// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf /// - public sealed class LinearSvm : OnlineLinearTrainer, LinearBinaryPredictor> + public sealed class LinearSvm : OnlineLinearTrainer, LinearBinaryPredictor> { public const string LoadNameValue = "LinearSVM"; public const string ShortName = "svm"; @@ -42,6 +42,8 @@ public sealed class LinearSvm : OnlineLinearTrainer 0, nameof(args.Lambda), UserErrorPositive); Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive); + Args = args; + OutputColumns = new[] { new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index 864eaa1673..bb47b26ef9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -29,7 +29,7 @@ namespace Microsoft.ML.Runtime.Learners using TPredictor = LinearRegressionPredictor; /// - public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer, LinearRegressionPredictor> + public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer, LinearRegressionPredictor> { internal const string LoadNameValue = "OnlineGradientDescent"; internal const string UserNameValue = "Stochastic Gradient Descent (Regression)"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 7eab10e3c5..72429341db 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -42,12 +42,11 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel public int StreamingCacheSize = 1000000; } - public abstract class OnlineLinearTrainer : TrainerEstimatorBase + public abstract class OnlineLinearTrainer : TrainerEstimatorBase where TTransformer : IPredictionTransformer where TModel : IPredictor - where TArguments : OnlineLinearArguments { - protected readonly TArguments Args; + protected readonly OnlineLinearArguments Args; protected readonly string Name; // Initialized by InitCore @@ -77,7 +76,7 @@ public abstract class OnlineLinearTrainer : Tr protected virtual bool NeedCalibration => false; - protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label) + protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label) : base(Contracts.CheckRef(env, nameof(env)).Register(name), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.InitialWeights)) { Contracts.CheckValue(args, nameof(args)); From 0c176aab5bc601689414a1cc2fb03acba1750eeb Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 6 Sep 2018 16:44:18 -0700 Subject: [PATCH 6/8] Reverting the signature change on Fit() --- src/Microsoft.ML.Core/Data/IEstimator.cs | 2 +- src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs | 2 +- .../DataLoadSave/EstimatorExtensions.cs | 8 ++++---- src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs | 4 ++-- src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs | 2 +- .../Transforms/CopyColumnsTransform.cs | 2 +- src/Microsoft.ML.Data/Transforms/Normalizer.cs | 2 +- src/Microsoft.ML.Data/Transforms/TermEstimator.cs | 2 +- .../StaticPipeFakes.cs | 2 +- .../Api/Estimators/TrainWithInitialPredictor.cs | 5 ++++- .../Scenarios/Api/Estimators/Wrappers.cs | 10 +++++----- 11 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 0b88564b51..54b0c64cd9 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -237,7 +237,7 @@ public interface IEstimator /// /// Train and return a transformer. /// - TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null); + TTransformer Fit(IDataView input); /// /// Schema propagation for estimators. diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs index c308118bbc..727dfa9e0b 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs @@ -42,7 +42,7 @@ public EstimatorChain() LastEstimator = null; } - public TransformerChain Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) + public TransformerChain Fit(IDataView input) { // REVIEW: before fitting, run schema propagation. // Currently, it throws. diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs index 19d173bb5f..ecbc28ebdf 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs @@ -90,7 +90,7 @@ public DelegateEstimator(IEstimator estimator, Action - /// Given an estimator, return a wrapping object that will call a delegate once + /// Given an estimator, return a wrapping object that will call a delegate once /// is called. It is often important for an estimator to return information about what was fit, which is why the - /// method returns a specifically typed object, rather than just a general + /// method returns a specifically typed object, rather than just a general /// . However, at the same time, are often formed into pipelines /// with many objects, so we may need to build a chain of estimators via where the /// estimator for which we want to get the transformer is buried somewhere in this chain. For that scenario, we can through this @@ -113,7 +113,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) /// The type of returned by /// The estimator to wrap /// The delegate that is called with the resulting instances once - /// is called. Because + /// is called. Because /// may be called multiple times, this delegate may also be called multiple times. /// A wrapping estimator that calls the indicated delegate whenever fit is called public static IEstimator WithOnFitDelegate(this IEstimator estimator, Action onFit) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs index 4f228317b6..29c081ac35 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs @@ -8,7 +8,7 @@ namespace Microsoft.ML.Runtime.Data { /// /// The trivial implementation of that already has - /// the transformer and returns it on every call to . + /// the transformer and returns it on every call to . /// /// Concrete implementations still have to provide the schema propagation mechanism, since /// there is no easy way to infer it from the transformer. @@ -28,7 +28,7 @@ protected TrivialEstimator(IHost host, TTransformer transformer) Transformer = transformer; } - public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => Transformer; + public TTransformer Fit(IDataView input) => Transformer; public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); } diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 64b971b788..1bdb4a543b 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -67,7 +67,7 @@ public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape. WeightColumn = weight; } - public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(input, validationData, initialPredictor); + public TTransformer Fit(IDataView input) => TrainTransformer(input); public SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 180f3c66b7..9a287dd386 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -54,7 +54,7 @@ public CopyColumnsEstimator(IHostEnvironment env, params (string source, string _columns = columns; } - public CopyColumnsTransform Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) + public CopyColumnsTransform Fit(IDataView input) { // Invoke schema validation. GetOutputSchema(SchemaShape.Create(input.Schema)); diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index c2e2850a3a..5d888cd433 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -170,7 +170,7 @@ public Normalizer(IHostEnvironment env, params ColumnBase[] columns) _columns = columns.ToArray(); } - public NormalizerTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) + public NormalizerTransformer Fit(IDataView input) { _host.CheckValue(input, nameof(input)); return NormalizerTransformer.Train(_host, input, _columns); diff --git a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs index 2aaf7b9e61..33d23b9e8c 100644 --- a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs @@ -23,7 +23,7 @@ public TermEstimator(IHostEnvironment env, params TermTransform.ColumnInfo[] col _columns = columns; } - public TermTransform Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => new TermTransform(_host, input, _columns); + public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns); public SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs index bee9d83fa2..8313d94e7f 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs @@ -51,7 +51,7 @@ public override IEstimator Reconcile( private sealed class FakeEstimator : IEstimator { - public ITransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => throw new NotImplementedException(); + public ITransformer Fit(IDataView input) => throw new NotImplementedException(); public SchemaShape GetOutputSchema(SchemaShape inputSchema) => throw new NotImplementedException(); } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs index 724a8838c8..e38a78c083 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; using Xunit; @@ -38,7 +39,9 @@ public void New_TrainWithInitialPredictor() // Train the second predictor on the same data. var secondTrainer = new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments()); - var finalModel = secondTrainer.Fit(trainData, initialPredictor: firstModel.Model); + + var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features"); + var finalModel = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: firstModel.Model)); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index a34bac4ecf..4ef94407a5 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -245,7 +245,7 @@ protected TrainerBase(IHostEnvironment env, TrainerInfo trainerInfo, string feat Info = trainerInfo; } - public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) + public TTransformer Fit(IDataView input) { return TrainTransformer(input); } @@ -311,7 +311,7 @@ public MyTextTransform(IHostEnvironment env, TextTransform.Arguments args) _args = args; } - public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) + public TransformWrapper Fit(IDataView input) { var xf = TextTransform.Create(_env, _args, input); var empty = new EmptyDataView(_env, input.Schema); @@ -338,7 +338,7 @@ public MyConcatTransform(IHostEnvironment env, string name, params string[] sour _source = source; } - public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) + public TransformWrapper Fit(IDataView input) { var xf = new ConcatTransform(_env, input, _name, _source); var empty = new EmptyDataView(_env, input.Schema); @@ -365,7 +365,7 @@ public MyKeyToValueTransform(IHostEnvironment env, string name, string source = _source = source; } - public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) + public TransformWrapper Fit(IDataView input) { var xf = new KeyToValueTransform(_env, input, _name, _source); var empty = new EmptyDataView(_env, input.Schema); @@ -510,7 +510,7 @@ public MyLambdaTransform(IHostEnvironment env, Action action) _action = action; } - public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) + public TransformWrapper Fit(IDataView input) { var xf = LambdaTransform.CreateMap(_env, input, _action); var empty = new EmptyDataView(_env, input.Schema); From 6a019262c81fc75a486a51ccf1a40c160d6f307f Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 6 Sep 2018 22:26:34 -0700 Subject: [PATCH 7/8] addressing comments --- .../Standard/Online/AveragedPerceptron.cs | 8 ++++---- .../Standard/Online/OnlineGradientDescent.cs | 6 ++---- .../Standard/Online/OnlineLinear.cs | 3 ++- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index 3a8153dbe2..7a38d3dbac 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -37,7 +37,7 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer weights = default(VBuffer); Float bias; - if (!Args.Averaged) + if (!_args.Averaged) { Weights.CopyTo(ref weights); bias = Bias; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index bb47b26ef9..0d22870070 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -4,9 +4,9 @@ using Float = System.Single; -using System; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Learners; @@ -25,8 +25,6 @@ namespace Microsoft.ML.Runtime.Learners { - using Microsoft.ML.Core.Data; - using TPredictor = LinearRegressionPredictor; /// public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer, LinearRegressionPredictor> @@ -116,7 +114,7 @@ public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment en () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } - protected override RegressionPredictionTransformer MakeTransformer(TPredictor model, ISchema trainSchema) + protected override RegressionPredictionTransformer MakeTransformer(LinearRegressionPredictor model, ISchema trainSchema) => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 72429341db..d7b7d4cf2f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Float = System.Single; + using System; using System.Globalization; using Microsoft.ML.Core.Data; @@ -16,7 +18,6 @@ namespace Microsoft.ML.Runtime.Learners { - using Float = System.Single; public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel { From 40fcd333a61a801710a52d24a092a0ad227b9573 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 7 Sep 2018 08:45:39 -0700 Subject: [PATCH 8/8] ordering usings --- .../Standard/Online/AveragedLinear.cs | 2 +- .../Standard/Online/AveragedPerceptron.cs | 2 +- .../Standard/Online/OnlineGradientDescent.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index 98773cdc91..402d227fad 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -5,13 +5,13 @@ using Float = System.Single; using System; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Numeric; using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Core.Data; // TODO: Check if it works properly if Averaged is set to false diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index 7a38d3dbac..82e760d41c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -4,6 +4,7 @@ using Float = System.Single; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -13,7 +14,6 @@ using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.Numeric; using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Core.Data; [assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index 0d22870070..60569c4319 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -4,9 +4,9 @@ using Float = System.Single; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Learners;