Skip to content

AP, LinearSVM, OGD as estimators #849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
/// <summary>
/// A standard message for subclasses to use in errors or warnings when no valid
/// training instances could be found (i.e., every instance had missing features).
/// </summary>
protected const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";

/// <summary>
/// The feature column that the trainer expects.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,7 @@ protected override void CheckLabel(RoleMappedData examples, out int weightSetCou
protected override BinaryPredictionTransformer<TScalarPredictor> MakeTransformer(TScalarPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<TScalarPredictor>(Host, model, trainSchema, FeatureColumn.Name);

public BinaryPredictionTransformer<TScalarPredictor> Train(IDataView trainData, IDataView validationData) => TrainTransformer(trainData, validationData);
public BinaryPredictionTransformer<TScalarPredictor> Train(IDataView trainData, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(trainData, validationData, initialPredictor);
}

public sealed class StochasticGradientDescentClassificationTrainer :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,7 @@ protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

public override PredictionKind PredictionKind {
get { return PredictionKind.BinaryClassification; }
}
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

/// <summary>
/// Combine a bunch of models into one by averaging parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Float = System.Single;

using System;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
Expand Down Expand Up @@ -52,12 +53,12 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
public Float AveragedTolerance = (Float)1e-2;
}

public abstract class AveragedLinearTrainer<TArguments, TPredictor> : OnlineLinearTrainer<TArguments, TPredictor>
where TArguments : AveragedLinearArguments
where TPredictor : IPredictorProducing<Float>
public abstract class AveragedLinearTrainer<TTransformer, TModel> : OnlineLinearTrainer<TTransformer, TModel>
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
protected readonly new AveragedLinearArguments Args;
protected IScalarOutputLoss LossFunction;

protected Float Gain;

// For computing averaged weights and bias (if needed)
Expand All @@ -74,15 +75,18 @@ public abstract class AveragedLinearTrainer<TArguments, TPredictor> : OnlineLine
// We'll keep a few things global to prevent garbage collection
protected int NumNoUpdates;

protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name)
: base(args, env, name)
protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
: base(args, env, name, label)
{
Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);

// Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
Contracts.CheckUserArg(0 <= args.L2RegularizerWeight && args.L2RegularizerWeight < 0.5, nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)");
Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative);
Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative);

Args = args;
}

protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Float = System.Single;

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
Expand All @@ -29,14 +30,15 @@ namespace Microsoft.ML.Runtime.Learners
// - Feature normalization. By default, rescaling between min and max values for every feature
// - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration.
/// <include file='doc.xml' path='doc/members/member[@name="AP"]/*' />
public sealed class AveragedPerceptronTrainer :
AveragedLinearTrainer<AveragedPerceptronTrainer.Arguments, LinearBinaryPredictor>
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor> , LinearBinaryPredictor>
{
public const string LoadNameValue = "AveragedPerceptron";
internal const string UserNameValue = "Averaged Perceptron";
internal const string ShortName = "ap";
internal const string Summary = "Averaged Perceptron Binary Classifier.";

private readonly Arguments _args;

public class Arguments : AveragedLinearArguments
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
Expand All @@ -49,30 +51,45 @@ public class Arguments : AveragedLinearArguments
public int MaxCalibrationExamples = 1000000;
}

protected override bool NeedCalibration => true;

public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
{
LossFunction = Args.LossFunction.CreateComponent(env);
_args = args;
LossFunction = _args.LossFunction.CreateComponent(env);

OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
};
}

public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

protected override bool NeedCalibration => true;

protected override SchemaShape.Column[] OutputColumns { get; }

/// <summary>
/// Verifies that <paramref name="data"/> carries a label column suitable for binary classification.
/// </summary>
protected override void CheckLabel(RoleMappedData data)
{
// Assert the reference first; CheckBinaryLabel performs the actual schema validation.
Contracts.AssertValue(data);
data.CheckBinaryLabel();
}

/// <summary>
/// Describes the label column expected by this binary trainer: a scalar boolean column.
/// </summary>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
    => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);

protected override LinearBinaryPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);

VBuffer<Float> weights = default(VBuffer<Float>);
Float bias;

if (!Args.Averaged)
if (!_args.Averaged)
{
Weights.CopyTo(ref weights);
bias = Bias;
Expand All @@ -87,6 +104,9 @@ protected override LinearBinaryPredictor CreatePredictor()
return new LinearBinaryPredictor(Host, ref weights, bias);
}

/// <summary>
/// Wraps the trained <paramref name="model"/> in a binary prediction transformer that scores
/// the configured feature column against <paramref name="trainSchema"/>.
/// </summary>
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name);

[TlcModule.EntryPoint(Name = "Trainers.AveragedPerceptronBinaryClassifier",
Desc = Summary,
UserName = UserNameValue,
Expand Down
28 changes: 25 additions & 3 deletions src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@

namespace Microsoft.ML.Runtime.Learners
{
using Microsoft.ML.Core.Data;
using TPredictor = LinearBinaryPredictor;

/// <summary>
/// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf
/// </summary>
public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, TPredictor>
public sealed class LinearSvm : OnlineLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor>, LinearBinaryPredictor>
{
public const string LoadNameValue = "LinearSVM";
public const string ShortName = "svm";
Expand All @@ -41,6 +42,8 @@ public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, TPredic
+ "and all the negative examples are on the other. After this mapping, quadratic programming is used to find the separating hyperplane that maximizes the "
+ "margin, i.e., the minimal distance between it and the instances.";

internal new readonly Arguments Args;

public sealed class Arguments : OnlineLinearArguments
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer constant", ShortName = "lambda", SortOrder = 50)]
Expand Down Expand Up @@ -83,13 +86,24 @@ public sealed class Arguments : OnlineLinearArguments
protected override bool NeedCalibration => true;

public LinearSvm(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
{
Contracts.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), UserErrorPositive);
Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);

Args = args;

OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
};
}

public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

protected override SchemaShape.Column[] OutputColumns { get; }

protected override void CheckLabel(RoleMappedData data)
{
Expand All @@ -105,6 +119,11 @@ protected override Float Margin(ref VBuffer<Float> feat)
return Bias + VectorUtils.DotProduct(ref feat, ref Weights) * WeightsScale;
}

/// <summary>
/// Describes the label column expected by the SVM trainer: a scalar boolean column.
/// </summary>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
    => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);

protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
{
base.InitCore(ch, numFeatures, predictor);
Expand Down Expand Up @@ -237,5 +256,8 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
}

/// <summary>
/// Wraps the trained <paramref name="model"/> in a binary prediction transformer that scores
/// the configured feature column against <paramref name="trainSchema"/>.
/// </summary>
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

using Float = System.Single;

using System;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
Expand All @@ -25,10 +25,9 @@

namespace Microsoft.ML.Runtime.Learners
{
using TPredictor = LinearRegressionPredictor;

/// <include file='doc.xml' path='doc/members/member[@name="OGD"]/*' />
public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<OnlineGradientDescentTrainer.Arguments, TPredictor>
public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<RegressionPredictionTransformer<LinearRegressionPredictor>, LinearRegressionPredictor>
{
internal const string LoadNameValue = "OnlineGradientDescent";
internal const string UserNameValue = "Stochastic Gradient Descent (Regression)";
Expand All @@ -53,19 +52,26 @@ public Arguments()
}

public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
{
LossFunction = args.LossFunction.CreateComponent(env);

OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false)
};
}

public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
public override PredictionKind PredictionKind => PredictionKind.Regression;

protected override SchemaShape.Column[] OutputColumns { get; }

/// <summary>
/// Verifies that <paramref name="data"/> carries a label column suitable for regression.
/// </summary>
protected override void CheckLabel(RoleMappedData data)
{
    // Assert the reference first, mirroring the sibling CheckLabel overrides
    // (e.g. the averaged perceptron's), which validate the data reference
    // before checking the label column.
    Contracts.AssertValue(data);
    data.CheckRegressionLabel();
}

protected override TPredictor CreatePredictor()
protected override LinearRegressionPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);
VBuffer<Float> weights = default(VBuffer<Float>);
Expand All @@ -85,6 +91,11 @@ protected override TPredictor CreatePredictor()
return new LinearRegressionPredictor(Host, ref weights, bias);
}

/// <summary>
/// Describes the label column expected by this regression trainer: a scalar R4 column.
/// </summary>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
    => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);

[TlcModule.EntryPoint(Name = "Trainers.OnlineGradientDescentRegressor",
Desc = "Train a Online gradient descent perceptron.",
UserName = UserNameValue,
Expand All @@ -102,5 +113,8 @@ public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment en
() => new OnlineGradientDescentTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}

/// <summary>
/// Wraps the trained <paramref name="model"/> in a regression prediction transformer that scores
/// the configured feature column against <paramref name="trainSchema"/>.
/// </summary>
protected override RegressionPredictionTransformer<LinearRegressionPredictor> MakeTransformer(LinearRegressionPredictor model, ISchema trainSchema)
=> new RegressionPredictionTransformer<LinearRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name);
}
}
34 changes: 25 additions & 9 deletions src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @Zruty0 meant remove it altogether, but that's all right, we can do a sweep for this later.

using System;
using System.Globalization;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
Expand All @@ -15,7 +18,6 @@

namespace Microsoft.ML.Runtime.Learners
{
using Float = System.Single;

public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
{
Expand All @@ -41,11 +43,12 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
public int StreamingCacheSize = 1000000;
}

public abstract class OnlineLinearTrainer<TArguments, TPredictor> : TrainerBase<TPredictor>
where TArguments : OnlineLinearArguments
where TPredictor : IPredictorProducing<Float>
public abstract class OnlineLinearTrainer<TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
protected readonly TArguments Args;
protected readonly OnlineLinearArguments Args;
protected readonly string Name;

// Initialized by InitCore
protected int NumFeatures;
Expand Down Expand Up @@ -74,15 +77,16 @@ public abstract class OnlineLinearTrainer<TArguments, TPredictor> : TrainerBase<

protected virtual bool NeedCalibration => false;

protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name)
: base(env, name)
protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(name), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.InitialWeights))
{
Contracts.CheckValue(args, nameof(args));
Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive);
Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative);
Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive);

Args = args;
Name = name;
// REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
}
Expand Down Expand Up @@ -111,7 +115,7 @@ protected void ScaleWeightsIfNeeded()
ScaleWeights();
}

public override TPredictor Train(TrainContext context)
protected override TModel TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var initPredictor = context.InitialPredictor;
Expand Down Expand Up @@ -148,10 +152,22 @@ public override TPredictor Train(TrainContext context)
return CreatePredictor();
}

protected abstract TPredictor CreatePredictor();
protected abstract TModel CreatePredictor();

protected abstract void CheckLabel(RoleMappedData data);

/// <summary>
/// Describes the optional example-weight column as a scalar R4 column;
/// returns null when no weight column is configured.
/// </summary>
private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
    return weightColumn == null
        ? null
        : new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

/// <summary>
/// Describes the feature column expected by online linear trainers: a vector of R4 values.
/// </summary>
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
    => new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

protected virtual void TrainCore(IChannel ch, RoleMappedData data)
{
bool shuffle = Args.Shuffle;
Expand Down
Loading