Skip to content

Commit 96fd88e

Browse files
committed
Converting AveragePerceptron, OGD and Linear SVM to estimators.
1 parent ff8e21b commit 96fd88e

File tree

7 files changed

+97
-26
lines changed

7 files changed

+97
-26
lines changed

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
1717
where TTransformer : IPredictionTransformer<TModel>
1818
where TModel : IPredictor
1919
{
20+
/// <summary>
21+
/// A standard string to use in errors or warnings by subclasses, to communicate the idea that no valid
22+
/// instances were able to be found.
23+
/// </summary>
24+
protected const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";
25+
2026
/// <summary>
2127
/// The feature column that the trainer expects.
2228
/// </summary>

src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,7 @@ protected override void SaveCore(ModelSaveContext ctx)
487487
ctx.SetVersionInfo(GetVersionInfo());
488488
}
489489

490-
public override PredictionKind PredictionKind {
491-
get { return PredictionKind.BinaryClassification; }
492-
}
490+
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
493491

494492
/// <summary>
495493
/// Combine a bunch of models into one by averaging parameters

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Microsoft.ML.Runtime.Internal.Utilities;
1212
using Microsoft.ML.Runtime.Numeric;
1313
using Microsoft.ML.Runtime.Internal.Internallearn;
14+
using Microsoft.ML.Core.Data;
1415

1516
// TODO: Check if it works properly if Averaged is set to false
1617

@@ -52,9 +53,10 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
5253
public Float AveragedTolerance = (Float)1e-2;
5354
}
5455

55-
public abstract class AveragedLinearTrainer<TArguments, TPredictor> : OnlineLinearTrainer<TArguments, TPredictor>
56+
public abstract class AveragedLinearTrainer<TArguments, TTransformer, TModel> : OnlineLinearTrainer<TArguments, TTransformer, TModel>
5657
where TArguments : AveragedLinearArguments
57-
where TPredictor : IPredictorProducing<Float>
58+
where TTransformer : IPredictionTransformer<TModel>
59+
where TModel : IPredictor
5860
{
5961
protected IScalarOutputLoss LossFunction;
6062

@@ -74,8 +76,8 @@ public abstract class AveragedLinearTrainer<TArguments, TPredictor> : OnlineLine
7476
// We'll keep a few things global to prevent garbage collection
7577
protected int NumNoUpdates;
7678

77-
protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name)
78-
: base(args, env, name)
79+
protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
80+
: base(args, env, name, label)
7981
{
8082
Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
8183
Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using Microsoft.ML.Runtime.Learners;
1414
using Microsoft.ML.Runtime.Numeric;
1515
using Microsoft.ML.Runtime.Training;
16+
using Microsoft.ML.Core.Data;
1617

1718
[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments),
1819
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
@@ -29,8 +30,7 @@ namespace Microsoft.ML.Runtime.Learners
2930
// - Feature normalization. By default, rescaling between min and max values for every feature
3031
// - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration.
3132
/// <include file='doc.xml' path='doc/members/member[@name="AP"]/*' />
32-
public sealed class AveragedPerceptronTrainer :
33-
AveragedLinearTrainer<AveragedPerceptronTrainer.Arguments, LinearBinaryPredictor>
33+
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<AveragedPerceptronTrainer.Arguments, BinaryPredictionTransformer<LinearBinaryPredictor> , LinearBinaryPredictor>
3434
{
3535
public const string LoadNameValue = "AveragedPerceptron";
3636
internal const string UserNameValue = "Averaged Perceptron";
@@ -49,22 +49,35 @@ public class Arguments : AveragedLinearArguments
4949
public int MaxCalibrationExamples = 1000000;
5050
}
5151

52-
protected override bool NeedCalibration => true;
53-
5452
public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
55-
: base(args, env, UserNameValue)
53+
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
5654
{
5755
LossFunction = Args.LossFunction.CreateComponent(env);
56+
57+
OutputColumns = new[]
58+
{
59+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
60+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
61+
};
5862
}
5963

60-
public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
64+
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
65+
66+
protected override bool NeedCalibration => true;
67+
68+
protected override SchemaShape.Column[] OutputColumns { get; }
6169

6270
protected override void CheckLabel(RoleMappedData data)
6371
{
6472
Contracts.AssertValue(data);
6573
data.CheckBinaryLabel();
6674
}
6775

76+
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
77+
{
78+
return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
79+
}
80+
6881
protected override LinearBinaryPredictor CreatePredictor()
6982
{
7083
Contracts.Assert(WeightsScale == 1);
@@ -87,6 +100,9 @@ protected override LinearBinaryPredictor CreatePredictor()
87100
return new LinearBinaryPredictor(Host, ref weights, bias);
88101
}
89102

103+
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema)
104+
=> new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name);
105+
90106
[TlcModule.EntryPoint(Name = "Trainers.AveragedPerceptronBinaryClassifier",
91107
Desc = Summary,
92108
UserName = UserNameValue,

src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626

2727
namespace Microsoft.ML.Runtime.Learners
2828
{
29+
using Microsoft.ML.Core.Data;
2930
using TPredictor = LinearBinaryPredictor;
3031

3132
/// <summary>
3233
/// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf
3334
/// </summary>
34-
public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, TPredictor>
35+
public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, BinaryPredictionTransformer<LinearBinaryPredictor>, LinearBinaryPredictor>
3536
{
3637
public const string LoadNameValue = "LinearSVM";
3738
public const string ShortName = "svm";
@@ -83,13 +84,21 @@ public sealed class Arguments : OnlineLinearArguments
8384
protected override bool NeedCalibration => true;
8485

8586
public LinearSvm(IHostEnvironment env, Arguments args)
86-
: base(args, env, UserNameValue)
87+
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
8788
{
8889
Contracts.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), UserErrorPositive);
8990
Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);
91+
92+
OutputColumns = new[]
93+
{
94+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
95+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
96+
};
9097
}
9198

92-
public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
99+
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
100+
101+
protected override SchemaShape.Column[] OutputColumns { get; }
93102

94103
protected override void CheckLabel(RoleMappedData data)
95104
{
@@ -105,6 +114,11 @@ protected override Float Margin(ref VBuffer<Float> feat)
105114
return Bias + VectorUtils.DotProduct(ref feat, ref Weights) * WeightsScale;
106115
}
107116

117+
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
118+
{
119+
return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
120+
}
121+
108122
protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
109123
{
110124
base.InitCore(ch, numFeatures, predictor);
@@ -237,5 +251,8 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir
237251
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
238252
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
239253
}
254+
255+
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema)
256+
=> new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name);
240257
}
241258
}

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525

2626
namespace Microsoft.ML.Runtime.Learners
2727
{
28+
using Microsoft.ML.Core.Data;
2829
using TPredictor = LinearRegressionPredictor;
2930

3031
/// <include file='doc.xml' path='doc/members/member[@name="OGD"]/*' />
31-
public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<OnlineGradientDescentTrainer.Arguments, TPredictor>
32+
public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<OnlineGradientDescentTrainer.Arguments, RegressionPredictionTransformer<LinearRegressionPredictor>, LinearRegressionPredictor>
3233
{
3334
internal const string LoadNameValue = "OnlineGradientDescent";
3435
internal const string UserNameValue = "Stochastic Gradient Descent (Regression)";
@@ -53,19 +54,26 @@ public Arguments()
5354
}
5455

5556
public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args)
56-
: base(args, env, UserNameValue)
57+
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
5758
{
5859
LossFunction = args.LossFunction.CreateComponent(env);
60+
61+
OutputColumns = new[]
62+
{
63+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false)
64+
};
5965
}
6066

61-
public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
67+
public override PredictionKind PredictionKind => PredictionKind.Regression;
68+
69+
protected override SchemaShape.Column[] OutputColumns { get; }
6270

6371
protected override void CheckLabel(RoleMappedData data)
6472
{
6573
data.CheckRegressionLabel();
6674
}
6775

68-
protected override TPredictor CreatePredictor()
76+
protected override LinearRegressionPredictor CreatePredictor()
6977
{
7078
Contracts.Assert(WeightsScale == 1);
7179
VBuffer<Float> weights = default(VBuffer<Float>);
@@ -85,6 +93,11 @@ protected override TPredictor CreatePredictor()
8593
return new LinearRegressionPredictor(Host, ref weights, bias);
8694
}
8795

96+
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
97+
{
98+
return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, true);
99+
}
100+
88101
[TlcModule.EntryPoint(Name = "Trainers.OnlineGradientDescentRegressor",
89102
Desc = "Train a Online gradient descent perceptron.",
90103
UserName = UserNameValue,
@@ -102,5 +115,8 @@ public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment en
102115
() => new OnlineGradientDescentTrainer(host, input),
103116
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
104117
}
118+
119+
protected override RegressionPredictionTransformer<TPredictor> MakeTransformer(TPredictor model, ISchema trainSchema)
120+
=> new RegressionPredictionTransformer<LinearRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name);
105121
}
106122
}

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Globalization;
7+
using Microsoft.ML.Core.Data;
78
using Microsoft.ML.Runtime.CommandLine;
89
using Microsoft.ML.Runtime.Data;
910
using Microsoft.ML.Runtime.EntryPoints;
@@ -41,11 +42,13 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
4142
public int StreamingCacheSize = 1000000;
4243
}
4344

44-
public abstract class OnlineLinearTrainer<TArguments, TPredictor> : TrainerBase<TPredictor>
45+
public abstract class OnlineLinearTrainer<TArguments, TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
46+
where TTransformer : IPredictionTransformer<TModel>
47+
where TModel : IPredictor
4548
where TArguments : OnlineLinearArguments
46-
where TPredictor : IPredictorProducing<Float>
4749
{
4850
protected readonly TArguments Args;
51+
protected readonly string Name;
4952

5053
// Initialized by InitCore
5154
protected int NumFeatures;
@@ -74,15 +77,16 @@ public abstract class OnlineLinearTrainer<TArguments, TPredictor> : TrainerBase<
7477

7578
protected virtual bool NeedCalibration => false;
7679

77-
protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name)
78-
: base(env, name)
80+
protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
81+
: base(Contracts.CheckRef(env, nameof(env)).Register(name), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.InitialWeights))
7982
{
8083
Contracts.CheckValue(args, nameof(args));
8184
Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive);
8285
Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative);
8386
Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive);
8487

8588
Args = args;
89+
Name = name;
8690
// REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
8791
Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
8892
}
@@ -111,7 +115,7 @@ protected void ScaleWeightsIfNeeded()
111115
ScaleWeights();
112116
}
113117

114-
public override TPredictor Train(TrainContext context)
118+
protected override TModel TrainModelCore(TrainContext context)
115119
{
116120
Host.CheckValue(context, nameof(context));
117121
var initPredictor = context.InitialPredictor;
@@ -148,10 +152,22 @@ public override TPredictor Train(TrainContext context)
148152
return CreatePredictor();
149153
}
150154

151-
protected abstract TPredictor CreatePredictor();
155+
protected abstract TModel CreatePredictor();
152156

153157
protected abstract void CheckLabel(RoleMappedData data);
154158

159+
private static SchemaShape.Column MakeWeightColumn(string weightColumn)
160+
{
161+
if (weightColumn == null)
162+
return null;
163+
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
164+
}
165+
166+
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
167+
{
168+
return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
169+
}
170+
155171
protected virtual void TrainCore(IChannel ch, RoleMappedData data)
156172
{
157173
bool shuffle = Args.Shuffle;

0 commit comments

Comments
 (0)