diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs index 5f08dee906..ca4f9531a7 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using Microsoft.ML.Data; +using Microsoft.ML.Trainers; namespace Microsoft.ML.Samples.Dynamic { @@ -59,15 +60,13 @@ public static void SDCA_BinaryClassification() // If we wanted to specify more advanced parameters for the algorithm, // we could do so by tweaking the 'advancedSetting'. var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features") - .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent - (labelColumn: "Sentiment", - featureColumn: "Features", - advancedSettings: s=> - { - s.ConvergenceTolerance = 0.01f; // The learning rate for adjusting bias from being regularized - s.NumThreads = 2; // Degree of lock-free parallelism - }) - ); + .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { + LabelColumn = "Sentiment", + FeatureColumn = "Features", + ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized + NumThreads = 2, // Degree of lock-free parallelism + })); // Run Cross-Validation on this second pipeline. var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3); diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index b4c5f1a506..313f66837d 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -47,6 +47,8 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi FastTreeRegressionTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); return new FastTreeRegressionTrainer(env, options); } @@ -85,6 +87,8 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica FastTreeBinaryClassificationTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); return new FastTreeBinaryClassificationTrainer(env, options); } @@ -125,6 +129,8 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer FastTreeRankingTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); return new FastTreeRankingTrainer(env, options); } @@ -213,6 +219,8 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr FastTreeTweedieTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); return new FastTreeTweedieTrainer(env, options); } @@ -251,6 +259,8 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT FastForestRegression.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); return new FastForestRegression(env, options); } @@ -289,6 +299,8 @@ public static FastForestClassification FastForest(this BinaryClassificationConte FastForestClassification.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); return new FastForestClassification(env, options); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index 1648314c9d..01c57d27f9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -23,7 +23,7 @@ using Microsoft.ML.Training; using Microsoft.ML.Transforms; -[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Arguments), +[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, SdcaBinaryTrainer.UserNameValue, SdcaBinaryTrainer.LoadNameValue, @@ -31,7 +31,7 @@ "lc", "sasdca")] -[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Arguments), +[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, StochasticGradientDescentClassificationTrainer.UserNameValue, StochasticGradientDescentClassificationTrainer.LoadNameValue, @@ -246,21 +246,19 @@ protected enum MetricKind private const string RegisterName = nameof(SdcaTrainerBase); - private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, Action advancedSettings = null) + private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn) { var args = new TArgs(); - // Apply the advanced args, if the user supplied any. - advancedSettings?.Invoke(args); args.FeatureColumn = featureColumn; args.LabelColumn = labelColumn.Name; return args; } internal SdcaTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, - SchemaShape.Column weight = default, Action advancedSettings = null, float? l2Const = null, + SchemaShape.Column weight = default, float? l2Const = null, float? l1Threshold = null, int? maxIterations = null) - : this(env, ArgsInit(featureColumn, labelColumn, advancedSettings), labelColumn, weight, l2Const, l1Threshold, maxIterations) + : this(env, ArgsInit(featureColumn, labelColumn), labelColumn, weight, l2Const, l1Threshold, maxIterations) { } @@ -1391,12 +1389,12 @@ public void Add(Double summand) } } - public sealed class SdcaBinaryTrainer : SdcaTrainerBase, TScalarPredictor> + public sealed class SdcaBinaryTrainer : SdcaTrainerBase, TScalarPredictor> { - public const string LoadNameValue = "SDCA"; + internal const string LoadNameValue = "SDCA"; internal const string UserNameValue = "Fast Linear (SA-SDCA)"; - public sealed class Arguments : ArgumentsBase + public sealed class Options : ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); @@ -1441,21 +1439,16 @@ internal override void Check(IHostEnvironment env) /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . - public SdcaBinaryTrainer(IHostEnvironment env, + internal SdcaBinaryTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, ISupportSdcaClassificationLoss loss = null, float? l2Const = null, float? l1Threshold = null, - int? maxIterations = null, - Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, - l2Const, l1Threshold, maxIterations) + int? maxIterations = null) + : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), + l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -1495,11 +1488,10 @@ public SdcaBinaryTrainer(IHostEnvironment env, _outputColumns = outCols.ToArray(); } - internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args, - string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + internal SdcaBinaryTrainer(IHostEnvironment env, Options options) + : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { - _loss = args.LossFunction.CreateComponent(env); + _loss = options.LossFunction.CreateComponent(env); Loss = _loss; Info = new TrainerInfo(calibration: !(_loss is LogLoss)); _positiveInstanceWeight = Args.PositiveInstanceWeight; @@ -1533,12 +1525,6 @@ internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args, }; _outputColumns = outCols.ToArray(); - - } - - public SdcaBinaryTrainer(IHostEnvironment env, Arguments args) - : this(env, args, args.FeatureColumn, args.LabelColumn) - { } protected override void CheckLabelCompatible(SchemaShape.Column labelCol) @@ -1594,7 +1580,7 @@ public sealed class StochasticGradientDescentClassificationTrainer : internal const string UserNameValue = "Hogwild SGD (binary)"; internal const string ShortName = "HogwildSGD"; - public sealed class Arguments : LearnerInputBaseWithWeight + public sealed class Options : LearnerInputBaseWithWeight { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportClassificationLossFactory LossFunction = new LogLossFactory(); @@ -1669,9 +1655,9 @@ internal static class Defaults } private readonly IClassificationLoss _loss; - private readonly Arguments _args; + private readonly Options _options; - protected override bool ShuffleData => _args.Shuffle; + protected override bool ShuffleData => _options.Shuffle; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; @@ -1688,53 +1674,50 @@ internal static class Defaults /// The initial learning rate used by SGD. /// The L2 regularizer constant. /// The loss function to use. - /// A delegate to apply all the advanced arguments to the algorithm. - public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, + internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, - int maxIterations = Arguments.Defaults.MaxIterations, - double initLearningRate = Arguments.Defaults.InitLearningRate, - float l2Weight = Arguments.Defaults.L2Weight, - ISupportClassificationLossFactory loss = null, - Action advancedSettings = null) + int maxIterations = Options.Defaults.MaxIterations, + double initLearningRate = Options.Defaults.InitLearningRate, + float l2Weight = Options.Defaults.L2Weight, + ISupportClassificationLossFactory loss = null) : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - _args = new Arguments(); - _args.MaxIterations = maxIterations; - _args.InitLearningRate = initLearningRate; - _args.L2Weight = l2Weight; - - // Apply the advanced args, if the user supplied any. - advancedSettings?.Invoke(_args); + _options = new Options(); + _options.MaxIterations = maxIterations; + _options.InitLearningRate = initLearningRate; + _options.L2Weight = l2Weight; - _args.FeatureColumn = featureColumn; - _args.LabelColumn = labelColumn; - _args.WeightColumn = weightColumn; + _options.FeatureColumn = featureColumn; + _options.LabelColumn = labelColumn; + _options.WeightColumn = weightColumn != null ? Optional.Explicit(weightColumn) : Optional.Implicit(DefaultColumnNames.Weight); if (loss != null) - _args.LossFunction = loss; - _args.Check(env); + _options.LossFunction = loss; + _options.Check(env); - _loss = _args.LossFunction.CreateComponent(env); + _loss = _options.LossFunction.CreateComponent(env); Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true); - NeedShuffle = _args.Shuffle; + NeedShuffle = _options.Shuffle; } /// /// Initializes a new instance of /// - internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args) - : base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), args.WeightColumn) + /// The environment to use. + /// Advanced arguments to the algorithm. + internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options options) + : base(env, options.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn), options.WeightColumn.IsExplicit ? options.WeightColumn.Value : null) { - args.Check(env); - _loss = args.LossFunction.CreateComponent(env); + options.Check(env); + _loss = options.LossFunction.CreateComponent(env); Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true); - NeedShuffle = args.Shuffle; - _args = args; + NeedShuffle = options.Shuffle; + _options = options; } protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -1767,22 +1750,22 @@ private protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedDat var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight); int numThreads; - if (_args.NumThreads.HasValue) + if (_options.NumThreads.HasValue) { - numThreads = _args.NumThreads.Value; - ch.CheckUserArg(numThreads > 0, nameof(_args.NumThreads), "The number of threads must be either null or a positive integer."); + numThreads = _options.NumThreads.Value; + ch.CheckUserArg(numThreads > 0, nameof(_options.NumThreads), "The number of threads must be either null or a positive integer."); } else numThreads = ComputeNumThreads(cursorFactory); ch.Assert(numThreads > 0); - int checkFrequency = _args.CheckFrequency ?? numThreads; + int checkFrequency = _options.CheckFrequency ?? numThreads; if (checkFrequency <= 0) checkFrequency = int.MaxValue; - var l2Weight = _args.L2Weight; + var l2Weight = _options.L2Weight; var lossFunc = _loss; var pOptions = new ParallelOptions { MaxDegreeOfParallelism = numThreads }; - var positiveInstanceWeight = _args.PositiveInstanceWeight; + var positiveInstanceWeight = _options.PositiveInstanceWeight; var weights = default(VBuffer); float bias = 0.0f; if (predictor != null) @@ -1804,7 +1787,7 @@ private protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedDat // REVIEW: Investigate using parallel row cursor set instead of getting cursor independently. The convergence of SDCA need to be verified. Action checkConvergence = (e, pch) => { - if (e % checkFrequency == 0 && e != _args.MaxIterations) + if (e % checkFrequency == 0 && e != _options.MaxIterations) { Double trainTime = watch.Elapsed.TotalSeconds; var lossSum = new CompensatedSum(); @@ -1830,8 +1813,8 @@ private protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedDat improvement = improvement == 0 ? loss - newLoss : 0.5 * (loss - newLoss + improvement); loss = newLoss; - pch.Checkpoint(loss, improvement, e, _args.MaxIterations); - converged = improvement < _args.ConvergenceTolerance; + pch.Checkpoint(loss, improvement, e, _options.MaxIterations); + converged = improvement < _options.ConvergenceTolerance; } }; @@ -1840,17 +1823,17 @@ private protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedDat //Reference: Leon Bottou. Stochastic Gradient Descent Tricks. //https://research.microsoft.com/pubs/192769/tricks-2012.pdf - var trainingTasks = new Action[_args.MaxIterations]; - var rands = new Random[_args.MaxIterations]; - var ilr = _args.InitLearningRate; + var trainingTasks = new Action[_options.MaxIterations]; + var rands = new Random[_options.MaxIterations]; + var ilr = _options.InitLearningRate; long t = 0; - for (int epoch = 1; epoch <= _args.MaxIterations; epoch++) + for (int epoch = 1; epoch <= _options.MaxIterations; epoch++) { int e = epoch; //localize the modified closure rands[e - 1] = RandomUtils.Create(Host.Rand.Next()); trainingTasks[e - 1] = (rand, pch) => { - using (var cursor = _args.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create()) + using (var cursor = _options.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create()) { var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights); while (cursor.MoveNext()) @@ -1905,9 +1888,9 @@ private protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedDat { int iter = 0; pch.SetHeader(new ProgressHeader(new[] { "Loss", "Improvement" }, new[] { "iterations" }), - entry => entry.SetProgress(0, iter, _args.MaxIterations)); + entry => entry.SetProgress(0, iter, _options.MaxIterations)); // Synchorized SGD. - for (int i = 0; i < _args.MaxIterations; i++) + for (int i = 0; i < _options.MaxIterations; i++) { iter = i; trainingTasks[i](rands[i], pch); @@ -1925,7 +1908,7 @@ private protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedDat // REVIEW: technically, we could keep track of how many iterations have started, // but this needs more synchronization than Parallel.For allows. }); - Parallel.For(0, _args.MaxIterations, pOptions, i => trainingTasks[i](rands[i], pch)); + Parallel.For(0, _options.MaxIterations, pOptions, i => trainingTasks[i](rands[i], pch)); //note that P.Invoke will wait until all tasks finish } } @@ -1947,14 +1930,14 @@ private protected override void CheckLabel(RoleMappedData examples, out int weig } [TlcModule.EntryPoint(Name = "Trainers.StochasticGradientDescentBinaryClassifier", Desc = "Train an Hogwild SGD binary model.", UserName = UserNameValue, ShortName = ShortName)] - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input) + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainHogwildSGD"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new StochasticGradientDescentClassificationTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn), @@ -1974,14 +1957,14 @@ public static partial class Sdca ShortName = SdcaBinaryTrainer.LoadNameValue, XmlInclude = new[] { @"", @"" })] - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Arguments input) + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainSDCA"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new SdcaBinaryTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 0a069f5d8c..873ee81d41 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -19,7 +19,7 @@ using Microsoft.ML.Training; using Float = System.Single; -[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Arguments), +[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Options), new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, SdcaMultiClassTrainer.UserNameValue, SdcaMultiClassTrainer.LoadNameValue, @@ -29,14 +29,14 @@ namespace Microsoft.ML.Trainers { // SDCA linear multiclass trainer. /// - public class SdcaMultiClassTrainer : SdcaTrainerBase, MulticlassLogisticRegressionModelParameters> + public class SdcaMultiClassTrainer : SdcaTrainerBase, MulticlassLogisticRegressionModelParameters> { public const string LoadNameValue = "SDCAMC"; public const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)"; public const string ShortName = "sasdcamc"; internal const string Summary = "The SDCA linear multi-class classification trainer."; - public sealed class Arguments : ArgumentsBase + public sealed class Options : ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); @@ -57,21 +57,16 @@ public sealed class Arguments : ArgumentsBase /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . - public SdcaMultiClassTrainer(IHostEnvironment env, + internal SdcaMultiClassTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weights = null, ISupportSdcaClassificationLoss loss = null, float? l2Const = null, float? l1Threshold = null, - int? maxIterations = null, - Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), advancedSettings, - l2Const, l1Threshold, maxIterations) + int? maxIterations = null) + : base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), + l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -79,19 +74,19 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Loss = _loss; } - internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, + internal SdcaMultiClassTrainer(IHostEnvironment env, Options options, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(env, options, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); - _loss = args.LossFunction.CreateComponent(env); + _loss = options.LossFunction.CreateComponent(env); Loss = _loss; } - internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args) - : this(env, args, args.FeatureColumn, args.LabelColumn) + internal SdcaMultiClassTrainer(IHostEnvironment env, Options options) + : this(env, options, options.FeatureColumn, options.LabelColumn) { } @@ -455,14 +450,14 @@ public static partial class Sdca ShortName = SdcaMultiClassTrainer.ShortName, XmlInclude = new[] { @"", @"" })] - public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments input) + public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainSDCA"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new SdcaMultiClassTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 8700f69605..b8a8b286e1 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -15,7 +15,7 @@ using Microsoft.ML.Trainers; using Microsoft.ML.Training; -[assembly: LoadableClass(SdcaRegressionTrainer.Summary, typeof(SdcaRegressionTrainer), typeof(SdcaRegressionTrainer.Arguments), +[assembly: LoadableClass(SdcaRegressionTrainer.Summary, typeof(SdcaRegressionTrainer), typeof(SdcaRegressionTrainer.Options), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, SdcaRegressionTrainer.UserNameValue, SdcaRegressionTrainer.LoadNameValue, @@ -24,19 +24,19 @@ namespace Microsoft.ML.Trainers { /// - public sealed class SdcaRegressionTrainer : SdcaTrainerBase, LinearRegressionModelParameters> + public sealed class SdcaRegressionTrainer : SdcaTrainerBase, LinearRegressionModelParameters> { internal const string LoadNameValue = "SDCAR"; internal const string UserNameValue = "Fast Linear Regression (SA-SDCA)"; internal const string ShortName = "sasdcar"; internal const string Summary = "The SDCA linear regression trainer."; - public sealed class Arguments : ArgumentsBase + public sealed class Options : ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportSdcaRegressionLossFactory LossFunction = new SquaredLossFactory(); - public Arguments() + public Options() { // Using a higher default tolerance for better RMS. ConvergenceTolerance = 0.01f; @@ -61,21 +61,16 @@ public Arguments() /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . - public SdcaRegressionTrainer(IHostEnvironment env, + internal SdcaRegressionTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weights = null, ISupportSdcaRegressionLoss loss = null, float? l2Const = null, float? l1Threshold = null, - int? maxIterations = null, - Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), advancedSettings, - l2Const, l1Threshold, maxIterations) + int? maxIterations = null) + : base(env, featureColumn, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), + l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -83,18 +78,18 @@ public SdcaRegressionTrainer(IHostEnvironment env, Loss = _loss; } - internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + internal SdcaRegressionTrainer(IHostEnvironment env, Options options, string featureColumn, string labelColumn, string weightColumn = null) + : base(env, options, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); - _loss = args.LossFunction.CreateComponent(env); + _loss = options.LossFunction.CreateComponent(env); Loss = _loss; } - internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args) - : this(env, args, args.FeatureColumn, args.LabelColumn) + internal SdcaRegressionTrainer(IHostEnvironment env, Options options) + : this(env, options, options.FeatureColumn, options.LabelColumn) { } @@ -178,14 +173,14 @@ public static partial class Sdca ShortName = SdcaRegressionTrainer.ShortName, XmlInclude = new[] { @"", @"" })] - public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, SdcaRegressionTrainer.Arguments input) + public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, SdcaRegressionTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainSDCA"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new SdcaRegressionTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } diff --git a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs index c5018748f8..2b62f817e0 100644 --- a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML { using LRArguments = LogisticRegression.Arguments; - using SgdArguments = StochasticGradientDescentClassificationTrainer.Arguments; + using SgdOptions = StochasticGradientDescentClassificationTrainer.Options; /// /// TrainerEstimator extension methods. @@ -31,20 +31,33 @@ public static class StandardLearnersCatalog /// The initial learning rate used by SGD. /// The L2 regularization constant. /// The loss function to use. - /// A delegate to apply all the advanced arguments to the algorithm. public static StochasticGradientDescentClassificationTrainer StochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weights = null, - int maxIterations = SgdArguments.Defaults.MaxIterations, - double initLearningRate = SgdArguments.Defaults.InitLearningRate, - float l2Weight = SgdArguments.Defaults.L2Weight, - ISupportClassificationLossFactory loss = null, - Action advancedSettings = null) + int maxIterations = SgdOptions.Defaults.MaxIterations, + double initLearningRate = SgdOptions.Defaults.InitLearningRate, + float l2Weight = SgdOptions.Defaults.L2Weight, + ISupportClassificationLossFactory loss = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new StochasticGradientDescentClassificationTrainer(env, labelColumn, featureColumn, weights, maxIterations, initLearningRate, l2Weight, loss); + } + + /// + /// Predict a target using a linear binary classification model trained with the trainer. + /// + /// The binary classificaiton context trainer object. + /// Advanced arguments to the algorithm. + public static StochasticGradientDescentClassificationTrainer StochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + SgdOptions options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); - return new StochasticGradientDescentClassificationTrainer(env, labelColumn, featureColumn, weights, maxIterations, initLearningRate, l2Weight, loss, advancedSettings); + return new StochasticGradientDescentClassificationTrainer(env, options); } /// @@ -58,10 +71,6 @@ public static StochasticGradientDescentClassificationTrainer StochasticGradientD /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// The custom loss, if unspecified will be . - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . public static SdcaRegressionTrainer StochasticDualCoordinateAscent(this RegressionContext.RegressionTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -69,12 +78,26 @@ public static SdcaRegressionTrainer StochasticDualCoordinateAscent(this Regressi ISupportSdcaRegressionLoss loss = null, float? l2Const = null, float? l1Threshold = null, - int? maxIterations = null, - Action advancedSettings = null) + int? maxIterations = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new SdcaRegressionTrainer(env, labelColumn, featureColumn, weights, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + return new SdcaRegressionTrainer(env, labelColumn, featureColumn, weights, loss, l2Const, l1Threshold, maxIterations); + } + + /// + /// Predict a target using a linear regression model trained with the SDCA trainer. + /// + /// The regression context trainer object. + /// Advanced arguments to the algorithm. + public static SdcaRegressionTrainer StochasticDualCoordinateAscent(this RegressionContext.RegressionTrainers ctx, + SdcaRegressionTrainer.Options options) + { + Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + + var env = CatalogUtils.GetEnvironment(ctx); + return new SdcaRegressionTrainer(env, options); } /// @@ -88,16 +111,6 @@ public static SdcaRegressionTrainer StochasticDualCoordinateAscent(this Regressi /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . - /// - /// - /// - /// /// /// /// advancedSettings = null - ) + int? maxIterations = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new SdcaBinaryTrainer(env, labelColumn, featureColumn, weights, loss, l2Const, l1Threshold, maxIterations); + } + + /// + /// Predict a target using a linear binary classification model trained with the SDCA trainer. + /// + /// The binary classification context trainer object. + /// Advanced arguments to the algorithm. + public static SdcaBinaryTrainer StochasticDualCoordinateAscent( + this BinaryClassificationContext.BinaryClassificationTrainers ctx, + SdcaBinaryTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); - return new SdcaBinaryTrainer(env, labelColumn, featureColumn, weights, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + return new SdcaBinaryTrainer(env, options); } /// @@ -132,10 +159,6 @@ public static SdcaBinaryTrainer StochasticDualCoordinateAscent( /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -143,12 +166,26 @@ public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this Multicla ISupportSdcaClassificationLoss loss = null, float? l2Const = null, float? l1Threshold = null, - int? maxIterations = null, - Action advancedSettings = null) + int? maxIterations = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new SdcaMultiClassTrainer(env, labelColumn, featureColumn, weights, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + return new SdcaMultiClassTrainer(env, labelColumn, featureColumn, weights, loss, l2Const, l1Threshold, maxIterations); + } + + /// + /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. + /// + /// The multiclass classification context trainer object. + /// Advanced arguments to the algorithm. + public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + SdcaMultiClassTrainer.Options options) + { + Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + + var env = CatalogUtils.GetEnvironment(ctx); + return new SdcaMultiClassTrainer(env, options); } /// @@ -177,6 +214,7 @@ public static AveragedPerceptronTrainer AveragedPerceptron( Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); return new AveragedPerceptronTrainer(env, labelColumn, featureColumn, weights, lossFunction ?? new LogLoss(), learningRate, decreaseLearningRate, l2RegularizerWeight, numIterations, advancedSettings); } diff --git a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs index 8743d8e77c..0754484c5b 100644 --- a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs @@ -26,10 +26,6 @@ public static class SdcaStaticExtensions /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// The custom loss, if unspecified will be . - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -48,7 +44,6 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, float? l1Threshold = null, int? maxIterations = null, ISupportSdcaRegressionLoss loss = null, - Action advancedSettings = null, Action onFit = null) { Contracts.CheckValue(label, nameof(label)); @@ -58,13 +53,58 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified."); Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified"); Contracts.CheckValueOrNull(loss); - Contracts.CheckValueOrNull(advancedSettings); Contracts.CheckValueOrNull(onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { - var trainer = new SdcaRegressionTrainer(env, labelName, featuresName, weightsName, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + var trainer = new SdcaRegressionTrainer(env, labelName, featuresName, weightsName, loss, l2Const, l1Threshold, maxIterations); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Score; + } + + /// + /// Predict a target using a linear regression model trained with the SDCA trainer. + /// + /// The regression context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Advanced arguments to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to + /// be informed about what was learnt. + /// The predicted output. + /// + /// + /// + /// + public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, + Scalar label, Vector features, Scalar weights, + SdcaRegressionTrainer.Options options, + Action onFit = null) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckValueOrNull(options); + Contracts.CheckValueOrNull(onFit); + + var rec = new TrainerEstimatorReconciler.Regression( + (env, labelName, featuresName, weightsName) => + { + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + + var trainer = new SdcaRegressionTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -83,10 +123,6 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -106,7 +142,6 @@ public static (Scalar score, Scalar probability, Scalar pred float? l2Const = null, float? l1Threshold = null, int? maxIterations = null, - Action advancedSettings = null, Action onFit = null) { Contracts.CheckValue(label, nameof(label)); @@ -115,13 +150,69 @@ public static (Scalar score, Scalar probability, Scalar pred Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative, if specified."); Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified."); Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified"); - Contracts.CheckValueOrNull(advancedSettings); Contracts.CheckValueOrNull(onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { - var trainer = new SdcaBinaryTrainer(env, labelName, featuresName, weightsName, loss: new LogLoss(), l2Const, l1Threshold, maxIterations, advancedSettings); + var trainer = new SdcaBinaryTrainer(env, labelName, featuresName, weightsName, loss: new LogLoss(), l2Const, l1Threshold, maxIterations); + if (onFit != null) + { + return trainer.WithOnFitDelegate(trans => + { + // Under the default log-loss we assume a calibrated predictor. + var model = trans.Model; + var cali = (ParameterMixingCalibratedPredictor)model; + var pred = (LinearBinaryModelParameters)cali.SubPredictor; + onFit(pred, cali); + }); + } + return trainer; + }, label, features, weights); + + return rec.Output; + } + + /// + /// Predict a target using a linear binary classification model trained with the SDCA trainer, and log-loss. + /// + /// The binary classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Advanced arguments to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained, as well as the calibrator on top of that model. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted binary classification score (which will range + /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. + /// + /// + /// + /// + public static (Scalar score, Scalar probability, Scalar predictedLabel) Sdca( + this BinaryClassificationContext.BinaryClassificationTrainers ctx, + Scalar label, Vector features, Scalar weights, + SdcaBinaryTrainer.Options options, + Action onFit = null) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckValueOrNull(options); + Contracts.CheckValueOrNull(onFit); + + var rec = new TrainerEstimatorReconciler.BinaryClassifier( + (env, labelName, featuresName, weightsName) => + { + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + + var trainer = new SdcaBinaryTrainer(env, options); if (onFit != null) { return trainer.WithOnFitDelegate(trans => @@ -152,10 +243,6 @@ public static (Scalar score, Scalar probability, Scalar pred /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -163,7 +250,6 @@ public static (Scalar score, Scalar probability, Scalar pred /// result in any way; it is only a way for the caller to be informed about what was learnt. /// The set of output columns including in order the predicted binary classification score (which will range /// from negative to positive infinity), and the predicted label. - /// public static (Scalar score, Scalar predictedLabel) Sdca( this BinaryClassificationContext.BinaryClassificationTrainers ctx, Scalar label, Vector features, @@ -172,7 +258,6 @@ public static (Scalar score, Scalar predictedLabel) Sdca( float? l2Const = null, float? l1Threshold = null, int? maxIterations = null, - Action advancedSettings = null, Action onFit = null ) { @@ -183,7 +268,6 @@ public static (Scalar score, Scalar predictedLabel) Sdca( Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative, if specified."); Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified."); Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified"); - Contracts.CheckValueOrNull(advancedSettings); Contracts.CheckValueOrNull(onFit); bool hasProbs = loss is LogLoss; @@ -191,7 +275,66 @@ public static (Scalar score, Scalar predictedLabel) Sdca( var rec = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration( (env, labelName, featuresName, weightsName) => { - var trainer = new SdcaBinaryTrainer(env, labelName, featuresName, weightsName, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + var trainer = new SdcaBinaryTrainer(env, labelName, featuresName, weightsName, loss, l2Const, l1Threshold, maxIterations); + if (onFit != null) + { + return trainer.WithOnFitDelegate(trans => + { + var model = trans.Model; + if (model is ParameterMixingCalibratedPredictor cali) + onFit((LinearBinaryModelParameters)cali.SubPredictor); + else + onFit((LinearBinaryModelParameters)model); + }); + } + return trainer; + }, label, features, weights, hasProbs); + + return rec.Output; + } + + /// + /// Predict a target using a linear binary classification model trained with the SDCA trainer, and a custom loss. + /// Note that because we cannot be sure that all loss functions will produce naturally calibrated outputs, setting + /// a custom loss function will not produce a calibrated probability column. + /// + /// The binary classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The custom loss. + /// The optional example weights. + /// Advanced arguments to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained, as well as the calibrator on top of that model. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted binary classification score (which will range + /// from negative to positive infinity), and the predicted label. + public static (Scalar score, Scalar predictedLabel) Sdca( + this BinaryClassificationContext.BinaryClassificationTrainers ctx, + Scalar label, Vector features, + Scalar weights, + ISupportSdcaClassificationLoss loss, + SdcaBinaryTrainer.Options options, + Action onFit = null + ) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckValueOrNull(options); + Contracts.CheckValueOrNull(onFit); + + bool hasProbs = loss is LogLoss; + + var rec = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration( + (env, labelName, featuresName, weightsName) => + { + options.FeatureColumn = featuresName; + options.LabelColumn = labelName; + + var trainer = new SdcaBinaryTrainer(env, options); if (onFit != null) { return trainer.WithOnFitDelegate(trans => @@ -220,10 +363,6 @@ public static (Scalar score, Scalar predictedLabel) Sdca( /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. - /// The settings here will override the ones provided in the direct method signature, - /// if both are present and have different values. - /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -239,7 +378,6 @@ public static (Vector score, Key predictedLabel) float? l2Const = null, float? l1Threshold = null, int? maxIterations = null, - Action advancedSettings = null, Action onFit = null) { Contracts.CheckValue(label, nameof(label)); @@ -249,13 +387,55 @@ public static (Vector score, Key predictedLabel) Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative, if specified."); Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified."); Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified"); - Contracts.CheckValueOrNull(advancedSettings); Contracts.CheckValueOrNull(onFit); var rec = new TrainerEstimatorReconciler.MulticlassClassifier( (env, labelName, featuresName, weightsName) => { - var trainer = new SdcaMultiClassTrainer(env, labelName, featuresName, weightsName, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + var trainer = new SdcaMultiClassTrainer(env, labelName, featuresName, weightsName, loss, l2Const, l1Threshold, maxIterations); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Output; + } + + /// + /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. + /// + /// The multiclass classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Advanced arguments to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) + Sdca(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + Key label, + Vector features, + Scalar weights, + SdcaMultiClassTrainer.Options options, + Action onFit = null) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckValueOrNull(options); + Contracts.CheckValueOrNull(onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassifier( + (env, labelName, featuresName, weightsName) => + { + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + + var trainer = new SdcaMultiClassTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; diff --git a/src/Microsoft.ML.StaticPipe/SgdStatic.cs b/src/Microsoft.ML.StaticPipe/SgdStatic.cs index 5e5bd98d08..e5c579cc12 100644 --- a/src/Microsoft.ML.StaticPipe/SgdStatic.cs +++ b/src/Microsoft.ML.StaticPipe/SgdStatic.cs @@ -3,13 +3,15 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.StaticPipe.Runtime; using Microsoft.ML.Trainers; namespace Microsoft.ML.StaticPipe { - using Arguments = StochasticGradientDescentClassificationTrainer.Arguments; + using Options = StochasticGradientDescentClassificationTrainer.Options; /// /// Binary Classification trainer estimators. @@ -27,7 +29,6 @@ public static class SgdStaticExtensions /// The initial learning rate used by SGD. /// The L2 regularization constant. /// The loss function to use. - /// A delegate to apply all the advanced arguments to the algorithm. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -38,17 +39,55 @@ public static (Scalar score, Scalar probability, Scalar pred Scalar label, Vector features, Scalar weights = null, - int maxIterations = Arguments.Defaults.MaxIterations, - double initLearningRate = Arguments.Defaults.InitLearningRate, - float l2Weight = Arguments.Defaults.L2Weight, + int maxIterations = Options.Defaults.MaxIterations, + double initLearningRate = Options.Defaults.InitLearningRate, + float l2Weight = Options.Defaults.L2Weight, ISupportClassificationLossFactory loss = null, - Action advancedSettings = null, Action> onFit = null) { var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { - var trainer = new StochasticGradientDescentClassificationTrainer(env, labelName, featuresName, weightsName, maxIterations, initLearningRate, l2Weight, loss, advancedSettings); + var trainer = new StochasticGradientDescentClassificationTrainer(env, labelName, featuresName, weightsName, maxIterations, initLearningRate, l2Weight, loss); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + + }, label, features, weights); + + return rec.Output; + } + + /// + /// Predict a target using a linear binary classification model trained with the trainer. + /// + /// The binary classificaiton context trainer object. + /// The name of the label column. + /// The name of the feature column. + /// The name for the example weight column. + /// Advanced arguments to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to + /// be informed about what was learnt. + /// The predicted output. + public static (Scalar score, Scalar probability, Scalar predictedLabel) StochasticGradientDescentClassificationTrainer(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + Scalar label, + Vector features, + Scalar weights, + Options options, + Action> onFit = null) + { + var rec = new TrainerEstimatorReconciler.BinaryClassifier( + (env, labelName, featuresName, weightsName) => + { + options.FeatureColumn = featuresName; + options.LabelColumn = labelName; + options.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + + var trainer = new StochasticGradientDescentClassificationTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); diff --git a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs index 7bac497fb9..ba0a8456c0 100644 --- a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.StaticPipe.Runtime; using Microsoft.ML.Trainers.FastTree; @@ -93,6 +95,10 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + options.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + var trainer = new FastTreeRegressionTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -185,6 +191,10 @@ public static (Scalar score, Scalar probability, Scalar pred var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + options.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + var trainer = new FastTreeBinaryClassificationTrainer(env, options); if (onFit != null) @@ -265,6 +275,11 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { + options.LabelColumn = labelName; + options.FeatureColumn = featuresName; + options.GroupIdColumn = groupIdName; + options.WeightColumn = weightsName != null ? Optional.Explicit(weightsName) : Optional.Implicit(DefaultColumnNames.Weight); + var trainer = new FastTreeRankingTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 74d896997f..948aeafd9a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -65,10 +65,10 @@ Trainers.OnlineGradientDescentRegressor Train a Online gradient descent perceptr Trainers.OrdinaryLeastSquaresRegressor Train an OLS regression model. Microsoft.ML.Trainers.HalLearners.OlsLinearRegressionTrainer TrainRegression Microsoft.ML.Trainers.HalLearners.OlsLinearRegressionTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.PcaAnomalyDetector Train an PCA Anomaly model. Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer TrainPcaAnomaly Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+AnomalyDetectionOutput Trainers.PoissonRegressor Train an Poisson regression model. Microsoft.ML.Trainers.PoissonRegression TrainRegression Microsoft.ML.Trainers.PoissonRegression+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput -Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary model. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.SdcaBinaryTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMultiClass Microsoft.ML.Trainers.SdcaMultiClassTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput -Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Trainers.Sdca TrainRegression Microsoft.ML.Trainers.SdcaRegressionTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput -Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary model. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.SdcaBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMultiClass Microsoft.ML.Trainers.SdcaMultiClassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput +Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Trainers.Sdca TrainRegression Microsoft.ML.Trainers.SdcaRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Trainers.SymSgd.SymSgdClassificationTrainer TrainSymSgd Microsoft.ML.Trainers.SymSgd.SymSgdClassificationTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Transforms.ApproximateBootstrapSampler Approximate bootstrap sampling. Microsoft.ML.Transforms.BootstrapSample GetSample Microsoft.ML.Transforms.BootstrapSamplingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class. Microsoft.ML.EntryPoints.ScoreModel RenameBinaryPredictionScoreColumns Microsoft.ML.EntryPoints.ScoreModel+RenameBinaryPredictionScoreColumnsInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs index f0269f4d7c..76d3d2aea4 100644 --- a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs +++ b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs @@ -53,7 +53,8 @@ public void SetupIrisPipeline() IDataView data = reader.Read(_irisDataPath); var pipeline = new ColumnConcatenatingEstimator(env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" }) - .Append(new SdcaMultiClassTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; })); + .Append(env.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaMultiClassTrainer.Options {NumThreads = 1, ConvergenceTolerance = 1e-2f, })); var model = pipeline.Fit(data); @@ -82,7 +83,8 @@ public void SetupSentimentPipeline() IDataView data = reader.Read(_sentimentDataPath); var pipeline = new TextFeaturizingEstimator(env, "SentimentText", "Features") - .Append(new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; })); + .Append(env.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options {NumThreads = 1, ConvergenceTolerance = 1e-2f, })); var model = pipeline.Fit(data); @@ -110,7 +112,8 @@ public void SetupBreastCancerPipeline() IDataView data = reader.Read(_breastCancerDataPath); - var pipeline = new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; }); + var pipeline = env.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1, ConvergenceTolerance = 1e-2f, }); var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index 7e8dcd97c5..afbc5e1d4f 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -68,7 +68,7 @@ private TransformerChain (r.label, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, - onFit: p => pred = p, advancedSettings: s => s.NumThreads = 1))); + .Append(r => (r.label, score: ctx.Trainers.Sdca(r.label, r.features, null, + new SdcaRegressionTrainer.Options() { MaxIterations = 2, NumThreads = 1 }, + onFit: p => pred = p))); var pipe = reader.Append(est); @@ -88,7 +89,8 @@ public void SdcaRegressionNameCollision() separator: ';', hasHeader: true); var est = reader.MakeNewEstimator() - .Append(r => (r.label, r.Score, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, advancedSettings: s => s.NumThreads = 1))); + .Append(r => (r.label, r.Score, score: ctx.Trainers.Sdca(r.label, r.features, null, + new SdcaRegressionTrainer.Options() { MaxIterations = 2, NumThreads = 1 }))); var pipe = reader.Append(est); @@ -119,10 +121,9 @@ public void SdcaBinaryClassification() ParameterMixingCalibratedPredictor cali = null; var est = reader.MakeNewEstimator() - .Append(r => (r.label, preds: ctx.Trainers.Sdca(r.label, r.features, - maxIterations: 2, - onFit: (p, c) => { pred = p; cali = c; }, - advancedSettings: s => s.NumThreads = 1))); + .Append(r => (r.label, preds: ctx.Trainers.Sdca(r.label, r.features, null, + new SdcaBinaryTrainer.Options { MaxIterations = 2, NumThreads = 1 }, + onFit: (p, c) => { pred = p; cali = c; }))); var pipe = reader.Append(est); @@ -167,10 +168,9 @@ public void SdcaBinaryClassificationNoCalibration() // With a custom loss function we no longer get calibrated predictions. var est = reader.MakeNewEstimator() - .Append(r => (r.label, preds: ctx.Trainers.Sdca(r.label, r.features, - maxIterations: 2, - loss: loss, onFit: p => pred = p, - advancedSettings: s => s.NumThreads = 1))); + .Append(r => (r.label, preds: ctx.Trainers.Sdca(r.label, r.features, null, loss, + new SdcaBinaryTrainer.Options { MaxIterations = 2, NumThreads = 1 }, + onFit: p => pred = p))); var pipe = reader.Append(est); @@ -951,10 +951,9 @@ public void HogwildSGDBinaryClassification() IPredictorWithFeatureWeights pred = null; var est = reader.MakeNewEstimator() - .Append(r => (r.label, preds: ctx.Trainers.StochasticGradientDescentClassificationTrainer(r.label, r.features, - l2Weight: 0, - onFit: (p) => { pred = p; }, - advancedSettings: s => s.NumThreads = 1))); + .Append(r => (r.label, preds: ctx.Trainers.StochasticGradientDescentClassificationTrainer(r.label, r.features, null, + new StochasticGradientDescentClassificationTrainer.Options { L2Weight = 0, NumThreads = 1 }, + onFit: (p) => { pred = p; }))); var pipe = reader.Append(est); @@ -1105,8 +1104,8 @@ public void MultiClassLightGbmStaticPipelineWithInMemoryData() // The predicted label below should be with probability 0.922597349. Console.WriteLine("Our predicted label to this example is {0} with probability {1}", - nativeLabels[(int)nativePrediction.PredictedLabelIndex-1], - nativePrediction.Scores[(int)nativePrediction.PredictedLabelIndex-1]); + nativeLabels[(int)nativePrediction.PredictedLabelIndex - 1], + nativePrediction.Scores[(int)nativePrediction.PredictedLabelIndex - 1]); } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/FeatureContributionTests.cs b/test/Microsoft.ML.Tests/FeatureContributionTests.cs index e41cf6954d..cac36a2e31 100644 --- a/test/Microsoft.ML.Tests/FeatureContributionTests.cs +++ b/test/Microsoft.ML.Tests/FeatureContributionTests.cs @@ -9,6 +9,7 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Microsoft.ML.Training; using Microsoft.ML.Transforms; using Xunit; @@ -72,7 +73,8 @@ public void TestFastTreeTweedieRegression() [Fact] public void TestSDCARegression() { - TestFeatureContribution(ML.Regression.Trainers.StochasticDualCoordinateAscent(advancedSettings: args => { args.NumThreads = 1; }), GetSparseDataset(numberOfInstances: 100), "SDCARegression"); + TestFeatureContribution(ML.Regression.Trainers.StochasticDualCoordinateAscent( + new SdcaRegressionTrainer.Options { NumThreads = 1, }), GetSparseDataset(numberOfInstances: 100), "SDCARegression"); } [Fact] @@ -146,13 +148,16 @@ public void TestLightGbmBinary() [Fact] public void TestSDCABinary() { - TestFeatureContribution(ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: args => { args.NumThreads = 1; }), GetSparseDataset(TaskType.BinaryClassification, 100), "SDCABinary"); + TestFeatureContribution(ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1, }), GetSparseDataset(TaskType.BinaryClassification, 100), "SDCABinary"); } [Fact] public void TestSGDBinary() { - TestFeatureContribution(ML.BinaryClassification.Trainers.StochasticGradientDescent(advancedSettings: args => { args.NumThreads = 1; }), GetSparseDataset(TaskType.BinaryClassification, 100), "SGDBinary"); + TestFeatureContribution(ML.BinaryClassification.Trainers.StochasticGradientDescent( + new StochasticGradientDescentClassificationTrainer.Options { NumThreads = 1}), + GetSparseDataset(TaskType.BinaryClassification, 100), "SGDBinary"); } [Fact] diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs index 6016cfbdb0..11b34438e6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -26,7 +27,8 @@ void CrossValidation() // Pipeline. var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") - .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: (s) => { s.ConvergenceTolerance = 1f; s.NumThreads = 1; })); + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { ConvergenceTolerance = 1f, NumThreads = 1, })); var cvResult = ml.BinaryClassification.CrossValidate(data, pipeline); } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs index 94f53e65f2..3d8164e353 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -5,6 +5,7 @@ using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Conversions; using Xunit; @@ -31,7 +32,8 @@ void DecomposableTrainAndPredict() var pipeline = new ColumnConcatenatingEstimator (ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(ml, "Label"), TransformerScope.TrainTest) - .Append(ml.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features",advancedSettings: s => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; })) + .Append(ml.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaMultiClassTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1, })) .Append(new KeyToValueMappingEstimator(ml, "PredictedLabel")); var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs index 5d890cf7b8..c5bbce6f12 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -24,7 +25,8 @@ public void Evaluation() // Pipeline. var pipeline = ml.Data.CreateTextReader(TestDatasets.Sentiment.GetLoaderColumns(), hasHeader: true) .Append(ml.Transforms.Text.FeaturizeText("SentimentText", "Features")) - .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)); + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1 })); // Train. var readerModel = pipeline.Fit(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs index 84bd6691e9..94160d8caa 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs @@ -6,6 +6,7 @@ using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Conversions; using Xunit; @@ -40,7 +41,8 @@ void Extensibility() var pipeline = new ColumnConcatenatingEstimator (ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new CustomMappingEstimator(ml, action, null), TransformerScope.TrainTest) .Append(new ValueToKeyMappingEstimator(ml, "Label"), TransformerScope.TrainTest) - .Append(ml.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; })) + .Append(ml.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaMultiClassTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1 })) .Append(new KeyToValueMappingEstimator(ml, "PredictedLabel")); var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs index afae98455c..e27d6360a0 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Data.IO; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -37,7 +38,8 @@ void FileBasedSavingOfData() DataSaverUtils.SaveDataView(ch, saver, trainData, file); } - var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1); + var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1 }); var loadedTrainData = new BinaryLoader(ml, new BinaryLoader.Arguments(), new MultiFileSource(path)); // Train. diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs index 023cff8d76..eae28ef7f4 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -32,7 +33,8 @@ public void IntrospectiveTraining() var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") .AppendCacheCheckpoint(ml) - .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)); + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1 })); // Train. var model = pipeline.Fit(data); @@ -40,7 +42,6 @@ public void IntrospectiveTraining() // Get feature weights. VBuffer weights = default; model.LastTransformer.Model.GetFeatureWeights(ref weights); - } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs index 70b6b0bbb5..b3c33095d0 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs @@ -24,7 +24,8 @@ public void Metacomponents() var ml = new MLContext(); var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.irisData.trainFilename), separatorChar: ','); - var sdcaTrainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; }); + var sdcaTrainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1, }); var pipeline = new ColumnConcatenatingEstimator (ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(ml, "Label"), TransformerScope.TrainTest) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs index e710956462..cb6074578d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs @@ -5,6 +5,7 @@ using System.Threading.Tasks; using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -29,7 +30,8 @@ void MultithreadedPrediction() // Pipeline. var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") .AppendCacheCheckpoint(ml) - .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)); + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1 })); // Train. var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs index 5b9482ae6e..e6d4592a84 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -29,7 +30,9 @@ public void ReconfigurablePrediction() var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") .Fit(data); - var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: (s) => s.NumThreads = 1); + var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1 }); + var trainData = ml.Data.Cache(pipeline.Transform(data)); // Cache the data right before the trainer to boost the training speed. var model = trainer.Fit(trainData); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs index af2c7ffa99..838c840f5b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -5,6 +5,7 @@ using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -26,7 +27,8 @@ public void SimpleTrainAndPredict() // Pipeline. var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") .AppendCacheCheckpoint(ml) - .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)); + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1 })); // Train. var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index d9cfa732a9..8375a59665 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -28,7 +29,8 @@ public void TrainSaveModelAndPredict() // Pipeline. var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") .AppendCacheCheckpoint(ml) - .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)); + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1 })); // Train. var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs index 4954ecea4a..cac3c7f9be 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -30,7 +31,9 @@ public void TrainWithInitialPredictor() var trainData = ml.Data.Cache(pipeline.Fit(data).Transform(data)); // Train the first predictor. - var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features",advancedSettings: s => s.NumThreads = 1); + var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { NumThreads = 1 }); + var firstModel = trainer.Fit(trainData); // Train the second predictor on the same data. diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index f3906ca806..5ebbb6e03a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Data; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; +using Microsoft.ML.Trainers; using Xunit; using Xunit.Abstractions; @@ -30,7 +31,8 @@ public void TrainAndPredictIrisModelTest() var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(mlContext.Transforms.Normalize("Features")) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)); + .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaMultiClassTrainer.Options { NumThreads = 1 })); // Read training and test data sets string dataPath = GetDataPath(TestDatasets.iris.trainFilename); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index ff38fbebe5..2b4c989e4c 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Data; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Scenarios @@ -36,7 +37,8 @@ public void TrainAndPredictIrisModelWithStringLabelTest() .Append(mlContext.Transforms.Normalize("Features")) .Append(mlContext.Transforms.Conversion.MapValueToKey("IrisPlantType", "Label"), TransformerScope.TrainTest) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)) + .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaMultiClassTrainer.Options { NumThreads = 1 })) .Append(mlContext.Transforms.Conversion.MapKeyToValue(("PredictedLabel", "Plant"))); // Train the pipeline diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs index 430ae1c28e..8140220de4 100644 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -41,7 +41,7 @@ public void TensorFlowTransforCifarEndToEndTest() .Append(new ColumnConcatenatingEstimator(mlContext, "Features", "Output")) .Append(new ValueToKeyMappingEstimator(mlContext, "Label")) .AppendCacheCheckpoint(mlContext) - .Append(new SdcaMultiClassTrainer(mlContext)); + .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent()); var transformer = pipeEstimator.Fit(data); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index 646eb7b148..bbafc04150 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Scenarios @@ -28,7 +29,8 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(mlContext.Transforms.Normalize("Features")) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)); + .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaMultiClassTrainer.Options { NumThreads = 1 })); // Read training and test data sets string dataPath = GetDataPath(TestDatasets.iris.trainFilename); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 8371e3e415..9ed4676c42 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -43,7 +43,8 @@ public void OVAWithAllConstructorArgs() public void OVAUncalibrated() { var (pipeline, data) = GetMultiClassPipeline(); - var sdcaTrainer = new SdcaBinaryTrainer(Env, "Label", "Features", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; s.Calibrator = null; }); + var sdcaTrainer = ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1, Calibrator = null }); pipeline = pipeline.Append(new Ova(Env, sdcaTrainer, useProbabilities: false)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); @@ -60,7 +61,9 @@ public void Pkpd() { var (pipeline, data) = GetMultiClassPipeline(); - var sdcaTrainer = new SdcaBinaryTrainer(Env, "Label", "Features", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; }); + var sdcaTrainer = ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); + pipeline = pipeline.Append(new Pkpd(Env, sdcaTrainer)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); @@ -74,7 +77,14 @@ public void MetacomponentsFeaturesRenamed() var data = new TextLoader(Env, TestDatasets.irisData.GetLoaderColumns(), separatorChar: ',') .Read(GetDataPath(TestDatasets.irisData.trainFilename)); - var sdcaTrainer = new SdcaBinaryTrainer(Env, "Label", "Vars", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; }); + var sdcaTrainer = ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { + LabelColumn = "Label", + FeatureColumn = "Vars", + MaxIterations = 100, + Shuffle = true, + NumThreads = 1, }); + var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest) .Append(new Ova(Env, sdcaTrainer)) diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs index f32242cedc..bef612edeb 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs @@ -18,13 +18,16 @@ public void SdcaWorkout() var data = TextLoaderStatic.CreateReader(Env, ctx => (Label: ctx.LoadFloat(0), Features: ctx.LoadFloat(1, 10))) .Read(dataPath).Cache(); - var binaryTrainer = new SdcaBinaryTrainer(Env, "Label", "Features", advancedSettings: (s) => s.ConvergenceTolerance = 1e-2f); + var binaryTrainer = ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaBinaryTrainer.Options { ConvergenceTolerance = 1e-2f }); TestEstimatorCore(binaryTrainer, data.AsDynamic); - var regressionTrainer = new SdcaRegressionTrainer(Env, "Label", "Features", advancedSettings: (s) => s.ConvergenceTolerance = 1e-2f); + var regressionTrainer = ML.Regression.Trainers.StochasticDualCoordinateAscent( + new SdcaRegressionTrainer.Options { ConvergenceTolerance = 1e-2f }); TestEstimatorCore(regressionTrainer, data.AsDynamic); - var mcTrainer = new SdcaMultiClassTrainer(Env, "Label", "Features", advancedSettings: (s) => s.ConvergenceTolerance = 1e-2f); + var mcTrainer = ML.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( + new SdcaMultiClassTrainer.Options { ConvergenceTolerance = 1e-2f }); TestEstimatorCore(mcTrainer, data.AsDynamic); Done(); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs index 30dc030aed..53edbc9e33 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs @@ -32,7 +32,7 @@ public void TestEstimatorSymSgdInitPredictor() (var pipe, var dataView) = GetBinaryClassificationPipeline(); var transformedData = pipe.Fit(dataView).Transform(dataView); - var initPredictor = new SdcaBinaryTrainer(Env, "Label", "Features").Fit(transformedData); + var initPredictor = ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent().Fit(transformedData); var data = initPredictor.Transform(transformedData); var withInitPredictor = new SymSgdClassificationTrainer(Env, "Label", "Features").Train(transformedData, initialPredictor: initPredictor.Model); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index eef3bd0aeb..6b5cbef050 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -89,7 +89,7 @@ public void KMeansEstimator() public void TestEstimatorHogwildSGD() { (IEstimator pipe, IDataView dataView) = GetBinaryClassificationPipeline(); - var trainer = new StochasticGradientDescentClassificationTrainer(Env, "Label", "Features"); + var trainer = ML.BinaryClassification.Trainers.StochasticGradientDescent(); var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView);