|
23 | 23 | using Microsoft.ML.Training;
|
24 | 24 | using Microsoft.ML.Transforms;
|
25 | 25 |
|
26 |
| -[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Arguments), |
| 26 | +[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Options), |
27 | 27 | new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
|
28 | 28 | SdcaBinaryTrainer.UserNameValue,
|
29 | 29 | SdcaBinaryTrainer.LoadNameValue,
|
@@ -253,21 +253,19 @@ protected enum MetricKind
|
253 | 253 |
|
254 | 254 | private const string RegisterName = nameof(SdcaTrainerBase<TArgs, TTransformer, TModel>);
|
255 | 255 |
|
256 |
| - private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, Action<TArgs> advancedSettings = null) |
| 256 | + private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn) |
257 | 257 | {
|
258 | 258 | var args = new TArgs();
|
259 | 259 |
|
260 |
| - // Apply the advanced args, if the user supplied any. |
261 |
| - advancedSettings?.Invoke(args); |
262 | 260 | args.FeatureColumn = featureColumn;
|
263 | 261 | args.LabelColumn = labelColumn.Name;
|
264 | 262 | return args;
|
265 | 263 | }
|
266 | 264 |
|
267 | 265 | internal SdcaTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
|
268 |
| - SchemaShape.Column weight = default, Action<TArgs> advancedSettings = null, float? l2Const = null, |
| 266 | + SchemaShape.Column weight = default, float? l2Const = null, |
269 | 267 | float? l1Threshold = null, int? maxIterations = null)
|
270 |
| - : this(env, ArgsInit(featureColumn, labelColumn, advancedSettings), labelColumn, weight, l2Const, l1Threshold, maxIterations) |
| 268 | + : this(env, ArgsInit(featureColumn, labelColumn), labelColumn, weight, l2Const, l1Threshold, maxIterations) |
271 | 269 | {
|
272 | 270 | }
|
273 | 271 |
|
@@ -1398,13 +1396,13 @@ public void Add(Double summand)
|
1398 | 1396 | }
|
1399 | 1397 | }
|
1400 | 1398 |
|
1401 |
| - public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Arguments, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor> |
| 1399 | + public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Options, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor> |
1402 | 1400 | {
|
1403 | 1401 | public const string LoadNameValue = "SDCA";
|
1404 | 1402 |
|
1405 | 1403 | internal const string UserNameValue = "Fast Linear (SA-SDCA)";
|
1406 | 1404 |
|
1407 |
| - public sealed class Arguments : ArgumentsBase |
| 1405 | + public sealed class Options : ArgumentsBase |
1408 | 1406 | {
|
1409 | 1407 | [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
|
1410 | 1408 | public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
|
@@ -1449,21 +1447,16 @@ internal override void Check(IHostEnvironment env)
|
1449 | 1447 | /// <param name="l2Const">The L2 regularization hyperparameter.</param>
|
1450 | 1448 | /// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
|
1451 | 1449 | /// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
|
1452 |
| - /// <param name="advancedSettings">A delegate to set more settings. |
1453 |
| - /// The settings here will override the ones provided in the direct method signature, |
1454 |
| - /// if both are present and have different values. |
1455 |
| - /// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param> |
1456 |
| - public SdcaBinaryTrainer(IHostEnvironment env, |
| 1450 | + internal SdcaBinaryTrainer(IHostEnvironment env, |
1457 | 1451 | string labelColumn = DefaultColumnNames.Label,
|
1458 | 1452 | string featureColumn = DefaultColumnNames.Features,
|
1459 | 1453 | string weightColumn = null,
|
1460 | 1454 | ISupportSdcaClassificationLoss loss = null,
|
1461 | 1455 | float? l2Const = null,
|
1462 | 1456 | float? l1Threshold = null,
|
1463 |
| - int? maxIterations = null, |
1464 |
| - Action<Arguments> advancedSettings = null) |
1465 |
| - : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, |
1466 |
| - l2Const, l1Threshold, maxIterations) |
| 1457 | + int? maxIterations = null) |
| 1458 | + : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), |
| 1459 | + l2Const, l1Threshold, maxIterations) |
1467 | 1460 | {
|
1468 | 1461 | Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
|
1469 | 1462 | Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
|
@@ -1503,11 +1496,11 @@ public SdcaBinaryTrainer(IHostEnvironment env,
|
1503 | 1496 | _outputColumns = outCols.ToArray();
|
1504 | 1497 | }
|
1505 | 1498 |
|
1506 |
| - internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args, |
| 1499 | + internal SdcaBinaryTrainer(IHostEnvironment env, Options options, |
1507 | 1500 | string featureColumn, string labelColumn, string weightColumn = null)
|
1508 |
| - : base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) |
| 1501 | + : base(env, options, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) |
1509 | 1502 | {
|
1510 |
| - _loss = args.LossFunction.CreateComponent(env); |
| 1503 | + _loss = options.LossFunction.CreateComponent(env); |
1511 | 1504 | Loss = _loss;
|
1512 | 1505 | Info = new TrainerInfo(calibration: !(_loss is LogLoss));
|
1513 | 1506 | _positiveInstanceWeight = Args.PositiveInstanceWeight;
|
@@ -1544,8 +1537,8 @@ internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args,
|
1544 | 1537 |
|
1545 | 1538 | }
|
1546 | 1539 |
|
1547 |
| - public SdcaBinaryTrainer(IHostEnvironment env, Arguments args) |
1548 |
| - : this(env, args, args.FeatureColumn, args.LabelColumn) |
| 1540 | + internal SdcaBinaryTrainer(IHostEnvironment env, Options options) |
| 1541 | + : this(env, options, options.FeatureColumn, options.LabelColumn) |
1549 | 1542 | {
|
1550 | 1543 | }
|
1551 | 1544 |
|
@@ -1731,15 +1724,15 @@ internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
|
1731 | 1724 | /// Initializes a new instance of <see cref="StochasticGradientDescentClassificationTrainer"/>
|
1732 | 1725 | /// </summary>
|
1733 | 1726 | /// <param name="env">The environment to use.</param>
|
1734 |
| - /// <param name="args">Advanced arguments to the algorithm.</param> |
1735 |
| - internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options args) |
1736 |
| - : base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) |
| 1727 | + /// <param name="options">Advanced arguments to the algorithm.</param> |
| 1728 | + internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options options) |
| 1729 | + : base(env, options.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn, options.WeightColumn.IsExplicit)) |
1737 | 1730 | {
|
1738 |
| - args.Check(env); |
1739 |
| - _loss = args.LossFunction.CreateComponent(env); |
| 1731 | + options.Check(env); |
| 1732 | + _loss = options.LossFunction.CreateComponent(env); |
1740 | 1733 | Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true);
|
1741 |
| - NeedShuffle = args.Shuffle; |
1742 |
| - _args = args; |
| 1734 | + NeedShuffle = options.Shuffle; |
| 1735 | + _args = options; |
1743 | 1736 | }
|
1744 | 1737 |
|
1745 | 1738 | protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
|
@@ -1979,14 +1972,14 @@ public static partial class Sdca
|
1979 | 1972 | ShortName = SdcaBinaryTrainer.LoadNameValue,
|
1980 | 1973 | XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
|
1981 | 1974 | @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentBinaryClassifier""]/*'/>" })]
|
1982 |
| - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Arguments input) |
| 1975 | + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Options input) |
1983 | 1976 | {
|
1984 | 1977 | Contracts.CheckValue(env, nameof(env));
|
1985 | 1978 | var host = env.Register("TrainSDCA");
|
1986 | 1979 | host.CheckValue(input, nameof(input));
|
1987 | 1980 | EntryPointUtils.CheckInputArgs(host, input);
|
1988 | 1981 |
|
1989 |
| - return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, |
| 1982 | + return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input, |
1990 | 1983 | () => new SdcaBinaryTrainer(host, input),
|
1991 | 1984 | () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
|
1992 | 1985 | calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
|
|
0 commit comments