|
20 | 20 | using System.Threading;
|
21 | 21 | using System.Threading.Tasks;
|
22 | 22 |
|
23 |
| -[assembly: LoadableClass(typeof(LinearClassificationTrainer), typeof(LinearClassificationTrainer.Arguments), |
| 23 | +[assembly: LoadableClass(typeof(StochasticDualCoordinateAscent), typeof(StochasticDualCoordinateAscent.Arguments), |
24 | 24 | new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
|
25 |
| - LinearClassificationTrainer.UserNameValue, |
26 |
| - LinearClassificationTrainer.LoadNameValue, |
| 25 | + StochasticDualCoordinateAscent.UserNameValue, |
| 26 | + StochasticDualCoordinateAscent.LoadNameValue, |
27 | 27 | "LinearClassifier",
|
28 | 28 | "lc",
|
29 | 29 | "sasdca")]
|
@@ -1361,7 +1361,7 @@ public void Add(Double summand)
|
1361 | 1361 | }
|
1362 | 1362 | }
|
1363 | 1363 |
|
1364 |
| - public sealed class LinearClassificationTrainer : SdcaTrainerBase<BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor> |
| 1364 | + public sealed class StochasticDualCoordinateAscent : SdcaTrainerBase<BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor> |
1365 | 1365 | {
|
1366 | 1366 | public const string LoadNameValue = "SDCA";
|
1367 | 1367 | internal const string UserNameValue = "Fast Linear (SA-SDCA)";
|
@@ -1401,7 +1401,7 @@ internal override void Check(IHostEnvironment env)
|
1401 | 1401 |
|
1402 | 1402 | public override TrainerInfo Info { get; }
|
1403 | 1403 |
|
1404 |
| - public LinearClassificationTrainer(IHostEnvironment env, Arguments args, |
| 1404 | + public StochasticDualCoordinateAscent(IHostEnvironment env, Arguments args, |
1405 | 1405 | string featureColumn, string labelColumn, string weightColumn = null)
|
1406 | 1406 | : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), args, MakeFeatureColumn(featureColumn), MakeLabelColumn(labelColumn), MakeWeightColumn(weightColumn))
|
1407 | 1407 | {
|
@@ -1431,7 +1431,7 @@ public LinearClassificationTrainer(IHostEnvironment env, Arguments args,
|
1431 | 1431 |
|
1432 | 1432 | }
|
1433 | 1433 |
|
1434 |
| - public LinearClassificationTrainer(IHostEnvironment env, Arguments args) |
| 1434 | + public StochasticDualCoordinateAscent(IHostEnvironment env, Arguments args) |
1435 | 1435 | : this(env, args, args.FeatureColumn, args.LabelColumn)
|
1436 | 1436 | {
|
1437 | 1437 | }
|
@@ -1903,19 +1903,19 @@ public static partial class Sdca
|
1903 | 1903 | {
|
1904 | 1904 | [TlcModule.EntryPoint(Name = "Trainers.StochasticDualCoordinateAscentBinaryClassifier",
|
1905 | 1905 | Desc = "Train an SDCA binary model.",
|
1906 |
| - UserName = LinearClassificationTrainer.UserNameValue, |
1907 |
| - ShortName = LinearClassificationTrainer.LoadNameValue, |
| 1906 | + UserName = StochasticDualCoordinateAscent.UserNameValue, |
| 1907 | + ShortName = StochasticDualCoordinateAscent.LoadNameValue, |
1908 | 1908 | XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
|
1909 | 1909 | @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentBinaryClassifier""]/*'/>" })]
|
1910 |
| - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, LinearClassificationTrainer.Arguments input) |
| 1910 | + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, StochasticDualCoordinateAscent.Arguments input) |
1911 | 1911 | {
|
1912 | 1912 | Contracts.CheckValue(env, nameof(env));
|
1913 | 1913 | var host = env.Register("TrainSDCA");
|
1914 | 1914 | host.CheckValue(input, nameof(input));
|
1915 | 1915 | EntryPointUtils.CheckInputArgs(host, input);
|
1916 | 1916 |
|
1917 |
| - return LearnerEntryPointsUtils.Train<LinearClassificationTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, |
1918 |
| - () => new LinearClassificationTrainer(host, input), |
| 1917 | + return LearnerEntryPointsUtils.Train<StochasticDualCoordinateAscent.Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, |
| 1918 | + () => new StochasticDualCoordinateAscent(host, input), |
1919 | 1919 | () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
|
1920 | 1920 | calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
|
1921 | 1921 | }
|
|
0 commit comments