Skip to content

Commit 58dbbac

Browse files
committed
Removing the arguments from the generics definition
1 parent 101c2e8 commit 58dbbac

File tree

5 files changed

+19
-11
lines changed

5 files changed

+19
-11
lines changed

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs

+6-4
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,12 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
5353
public Float AveragedTolerance = (Float)1e-2;
5454
}
5555

56-
public abstract class AveragedLinearTrainer<TArguments, TTransformer, TModel> : OnlineLinearTrainer<TArguments, TTransformer, TModel>
57-
where TArguments : AveragedLinearArguments
56+
public abstract class AveragedLinearTrainer<TTransformer, TModel> : OnlineLinearTrainer<TTransformer, TModel>
5857
where TTransformer : IPredictionTransformer<TModel>
5958
where TModel : IPredictor
6059
{
60+
protected readonly new AveragedLinearArguments Args;
6161
protected IScalarOutputLoss LossFunction;
62-
6362
protected Float Gain;
6463

6564
// For computing averaged weights and bias (if needed)
@@ -76,15 +75,18 @@ public abstract class AveragedLinearTrainer<TArguments, TTransformer, TModel> :
7675
// We'll keep a few things global to prevent garbage collection
7776
protected int NumNoUpdates;
7877

79-
protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
78+
protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
8079
: base(args, env, name, label)
8180
{
8281
Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
8382
Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);
83+
8484
// Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
8585
Contracts.CheckUserArg(0 <= args.L2RegularizerWeight && args.L2RegularizerWeight < 0.5, nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)");
8686
Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative);
8787
Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative);
88+
89+
Args = args;
8890
}
8991

9092
protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ namespace Microsoft.ML.Runtime.Learners
3030
// - Feature normalization. By default, rescaling between min and max values for every feature
3131
// - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration.
3232
/// <include file='doc.xml' path='doc/members/member[@name="AP"]/*' />
33-
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<AveragedPerceptronTrainer.Arguments, BinaryPredictionTransformer<LinearBinaryPredictor> , LinearBinaryPredictor>
33+
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor> , LinearBinaryPredictor>
3434
{
3535
public const string LoadNameValue = "AveragedPerceptron";
3636
internal const string UserNameValue = "Averaged Perceptron";
3737
internal const string ShortName = "ap";
3838
internal const string Summary = "Averaged Perceptron Binary Classifier.";
3939

40+
internal new readonly Arguments Args;
41+
4042
public class Arguments : AveragedLinearArguments
4143
{
4244
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
@@ -52,6 +54,7 @@ public class Arguments : AveragedLinearArguments
5254
public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
5355
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
5456
{
57+
Args = args;
5558
LossFunction = Args.LossFunction.CreateComponent(env);
5659

5760
OutputColumns = new[]

src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace Microsoft.ML.Runtime.Learners
3232
/// <summary>
3333
/// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf
3434
/// </summary>
35-
public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, BinaryPredictionTransformer<LinearBinaryPredictor>, LinearBinaryPredictor>
35+
public sealed class LinearSvm : OnlineLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor>, LinearBinaryPredictor>
3636
{
3737
public const string LoadNameValue = "LinearSVM";
3838
public const string ShortName = "svm";
@@ -42,6 +42,8 @@ public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, BinaryP
4242
+ "and all the negative examples are on the other. After this mapping, quadratic programming is used to find the separating hyperplane that maximizes the "
4343
+ "margin, i.e., the minimal distance between it and the instances.";
4444

45+
internal new readonly Arguments Args;
46+
4547
public sealed class Arguments : OnlineLinearArguments
4648
{
4749
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer constant", ShortName = "lambda", SortOrder = 50)]
@@ -89,6 +91,8 @@ public LinearSvm(IHostEnvironment env, Arguments args)
8991
Contracts.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), UserErrorPositive);
9092
Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);
9193

94+
Args = args;
95+
9296
OutputColumns = new[]
9397
{
9498
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace Microsoft.ML.Runtime.Learners
2929
using TPredictor = LinearRegressionPredictor;
3030

3131
/// <include file='doc.xml' path='doc/members/member[@name="OGD"]/*' />
32-
public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<OnlineGradientDescentTrainer.Arguments, RegressionPredictionTransformer<LinearRegressionPredictor>, LinearRegressionPredictor>
32+
public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<RegressionPredictionTransformer<LinearRegressionPredictor>, LinearRegressionPredictor>
3333
{
3434
internal const string LoadNameValue = "OnlineGradientDescent";
3535
internal const string UserNameValue = "Stochastic Gradient Descent (Regression)";

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs

+3-4
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,11 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
4242
public int StreamingCacheSize = 1000000;
4343
}
4444

45-
public abstract class OnlineLinearTrainer<TArguments, TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
45+
public abstract class OnlineLinearTrainer<TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
4646
where TTransformer : IPredictionTransformer<TModel>
4747
where TModel : IPredictor
48-
where TArguments : OnlineLinearArguments
4948
{
50-
protected readonly TArguments Args;
49+
protected readonly OnlineLinearArguments Args;
5150
protected readonly string Name;
5251

5352
// Initialized by InitCore
@@ -77,7 +76,7 @@ public abstract class OnlineLinearTrainer<TArguments, TTransformer, TModel> : Tr
7776

7877
protected virtual bool NeedCalibration => false;
7978

80-
protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
79+
protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
8180
: base(Contracts.CheckRef(env, nameof(env)).Register(name), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.InitialWeights))
8281
{
8382
Contracts.CheckValue(args, nameof(args));

0 commit comments

Comments
 (0)