Skip to content

Ap, LinearSVM, OGD as estimators #849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 8, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
/// <summary>
/// A standard message for subclasses to use in errors or warnings when no valid training
/// instances could be found.
/// </summary>
protected const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";

/// <summary>
/// The feature column that the trainer expects.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,7 @@ protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

public override PredictionKind PredictionKind {
get { return PredictionKind.BinaryClassification; }
}
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

/// <summary>
/// Combine a bunch of models into one by averaging parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Numeric;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Core.Data;

// TODO: Check if it works properly if Averaged is set to false

Expand Down Expand Up @@ -52,9 +53,10 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
public Float AveragedTolerance = (Float)1e-2;
}

public abstract class AveragedLinearTrainer<TArguments, TPredictor> : OnlineLinearTrainer<TArguments, TPredictor>
public abstract class AveragedLinearTrainer<TArguments, TTransformer, TModel> : OnlineLinearTrainer<TArguments, TTransformer, TModel>
Copy link
Member Author

@sfilipi sfilipi Sep 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TArguments [](start = 48, length = 10)

I will remove this on the next iteration, together with updating the tests. #Closed

where TArguments : AveragedLinearArguments
where TPredictor : IPredictorProducing<Float>
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
protected IScalarOutputLoss LossFunction;

Expand All @@ -74,8 +76,8 @@ public abstract class AveragedLinearTrainer<TArguments, TPredictor> : OnlineLine
// We'll keep a few things global to prevent garbage collection
protected int NumNoUpdates;

protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name)
: base(args, env, name)
protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
: base(args, env, name, label)
{
Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.Numeric;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Core.Data;

[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
Expand All @@ -29,8 +30,7 @@ namespace Microsoft.ML.Runtime.Learners
// - Feature normalization. By default, rescaling between min and max values for every feature
// - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration.
/// <include file='doc.xml' path='doc/members/member[@name="AP"]/*' />
public sealed class AveragedPerceptronTrainer :
AveragedLinearTrainer<AveragedPerceptronTrainer.Arguments, LinearBinaryPredictor>
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<AveragedPerceptronTrainer.Arguments, BinaryPredictionTransformer<LinearBinaryPredictor> , LinearBinaryPredictor>
{
public const string LoadNameValue = "AveragedPerceptron";
internal const string UserNameValue = "Averaged Perceptron";
Expand All @@ -49,22 +49,36 @@ public class Arguments : AveragedLinearArguments
public int MaxCalibrationExamples = 1000000;
}

protected override bool NeedCalibration => true;

public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
{
LossFunction = Args.LossFunction.CreateComponent(env);

OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
};
}

public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

protected override bool NeedCalibration => true;

protected override SchemaShape.Column[] OutputColumns { get; }

/// <summary>
/// Checks the label of the training data before training starts.
/// </summary>
/// <param name="data">The role-mapped training data; asserted to be non-null.</param>
protected override void CheckLabel(RoleMappedData data)
{
Contracts.AssertValue(data);
// Delegates the actual label validation to the binary-label check on the data.
data.CheckBinaryLabel();
}

/// <summary>
/// Describes the label column expected by this trainer: a scalar column of
/// <see cref="BoolType"/> carrying the given name.
/// </summary>
/// <param name="labelColumn">Name of the label column.</param>
/// <returns>The schema-shape description of the label column.</returns>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
    => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);

protected override LinearBinaryPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);
Expand All @@ -87,6 +101,9 @@ protected override LinearBinaryPredictor CreatePredictor()
return new LinearBinaryPredictor(Host, ref weights, bias);
}

/// <summary>
/// Wraps a trained <see cref="LinearBinaryPredictor"/> into a binary prediction transformer
/// bound to the training schema and this trainer's feature column name.
/// </summary>
/// <param name="model">The trained predictor.</param>
/// <param name="trainSchema">The schema of the data the model was trained on.</param>
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema)
{
    return new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name);
}

[TlcModule.EntryPoint(Name = "Trainers.AveragedPerceptronBinaryClassifier",
Desc = Summary,
UserName = UserNameValue,
Expand Down
24 changes: 21 additions & 3 deletions src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@

namespace Microsoft.ML.Runtime.Learners
{
using Microsoft.ML.Core.Data;
using TPredictor = LinearBinaryPredictor;

/// <summary>
/// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf
/// </summary>
public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, TPredictor>
public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, BinaryPredictionTransformer<LinearBinaryPredictor>, LinearBinaryPredictor>
{
public const string LoadNameValue = "LinearSVM";
public const string ShortName = "svm";
Expand Down Expand Up @@ -83,13 +84,22 @@ public sealed class Arguments : OnlineLinearArguments
protected override bool NeedCalibration => true;

public LinearSvm(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
{
Contracts.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), UserErrorPositive);
Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);

OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
};
}

public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

protected override SchemaShape.Column[] OutputColumns { get; }

protected override void CheckLabel(RoleMappedData data)
{
Expand All @@ -105,6 +115,11 @@ protected override Float Margin(ref VBuffer<Float> feat)
return Bias + VectorUtils.DotProduct(ref feat, ref Weights) * WeightsScale;
}

/// <summary>
/// Describes the label column expected by this trainer: a scalar column of
/// <see cref="BoolType"/> carrying the given name.
/// </summary>
/// <param name="labelColumn">Name of the label column.</param>
/// <returns>The schema-shape description of the label column.</returns>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
    => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);

protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
{
base.InitCore(ch, numFeatures, predictor);
Expand Down Expand Up @@ -237,5 +252,8 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
}

/// <summary>
/// Wraps a trained <see cref="LinearBinaryPredictor"/> into a binary prediction transformer
/// bound to the training schema and this trainer's feature column name.
/// </summary>
/// <param name="model">The trained predictor.</param>
/// <param name="trainSchema">The schema of the data the model was trained on.</param>
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema)
{
    return new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@

namespace Microsoft.ML.Runtime.Learners
{
using Microsoft.ML.Core.Data;
Copy link
Contributor

@Zruty0 Zruty0 Sep 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using [](start = 4, length = 5)

consolidate #Closed

using TPredictor = LinearRegressionPredictor;
Copy link
Contributor

@Zruty0 Zruty0 Sep 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TPredictor [](start = 10, length = 10)

let's remove this #Closed


/// <include file='doc.xml' path='doc/members/member[@name="OGD"]/*' />
public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<OnlineGradientDescentTrainer.Arguments, TPredictor>
public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<OnlineGradientDescentTrainer.Arguments, RegressionPredictionTransformer<LinearRegressionPredictor>, LinearRegressionPredictor>
{
internal const string LoadNameValue = "OnlineGradientDescent";
internal const string UserNameValue = "Stochastic Gradient Descent (Regression)";
Expand All @@ -53,19 +54,26 @@ public Arguments()
}

/// <summary>
/// Initializes the trainer: creates the loss function from <paramref name="args"/> and declares
/// the single Score output column.
/// </summary>
/// <param name="env">The host environment, also used to instantiate the loss function.</param>
/// <param name="args">The trainer arguments.</param>
public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
{
    LossFunction = args.LossFunction.CreateComponent(env);

    // A regression score is a single R4 value per row, so the column is declared as a scalar
    // rather than a vector.
    OutputColumns = new[]
    {
        new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false)
    };
}

public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
public override PredictionKind PredictionKind => PredictionKind.Regression;

protected override SchemaShape.Column[] OutputColumns { get; }

/// <summary>
/// Checks the label of the training data before training starts.
/// </summary>
/// <param name="data">The role-mapped training data whose label is checked.</param>
protected override void CheckLabel(RoleMappedData data)
{
// Delegates the actual label validation to the regression-label check on the data.
data.CheckRegressionLabel();
}

protected override TPredictor CreatePredictor()
protected override LinearRegressionPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);
VBuffer<Float> weights = default(VBuffer<Float>);
Expand All @@ -85,6 +93,11 @@ protected override TPredictor CreatePredictor()
return new LinearRegressionPredictor(Host, ref weights, bias);
}

/// <summary>
/// Describes the label column expected by this regression trainer: a scalar R4 column
/// carrying the given name.
/// </summary>
/// <param name="labelColumn">Name of the label column.</param>
/// <returns>The schema-shape description of the label column.</returns>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
    // The last argument is the is-key flag. A regression label is a plain floating-point value,
    // not a key, so the flag must be false.
    return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

[TlcModule.EntryPoint(Name = "Trainers.OnlineGradientDescentRegressor",
Desc = "Train a Online gradient descent perceptron.",
UserName = UserNameValue,
Expand All @@ -102,5 +115,8 @@ public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment en
() => new OnlineGradientDescentTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}

/// <summary>
/// Wraps a trained <see cref="LinearRegressionPredictor"/> into a regression prediction
/// transformer bound to the training schema and this trainer's feature column name.
/// </summary>
/// <param name="model">The trained predictor.</param>
/// <param name="trainSchema">The schema of the data the model was trained on.</param>
protected override RegressionPredictionTransformer<LinearRegressionPredictor> MakeTransformer(LinearRegressionPredictor model, ISchema trainSchema)
    => new RegressionPredictionTransformer<LinearRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name);
}
}
28 changes: 22 additions & 6 deletions src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Globalization;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
Expand Down Expand Up @@ -41,11 +42,13 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
public int StreamingCacheSize = 1000000;
}

public abstract class OnlineLinearTrainer<TArguments, TPredictor> : TrainerBase<TPredictor>
public abstract class OnlineLinearTrainer<TArguments, TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
where TArguments : OnlineLinearArguments
where TPredictor : IPredictorProducing<Float>
{
protected readonly TArguments Args;
protected readonly string Name;

// Initialized by InitCore
protected int NumFeatures;
Expand Down Expand Up @@ -74,15 +77,16 @@ public abstract class OnlineLinearTrainer<TArguments, TPredictor> : TrainerBase<

protected virtual bool NeedCalibration => false;

protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name)
: base(env, name)
protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(name), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.InitialWeights))
{
Contracts.CheckValue(args, nameof(args));
Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive);
Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative);
Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive);

Args = args;
Name = name;
// REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
}
Expand Down Expand Up @@ -111,7 +115,7 @@ protected void ScaleWeightsIfNeeded()
ScaleWeights();
}

public override TPredictor Train(TrainContext context)
protected override TModel TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var initPredictor = context.InitialPredictor;
Expand Down Expand Up @@ -148,10 +152,22 @@ public override TPredictor Train(TrainContext context)
return CreatePredictor();
}

protected abstract TPredictor CreatePredictor();
protected abstract TModel CreatePredictor();

protected abstract void CheckLabel(RoleMappedData data);

/// <summary>
/// Describes the optional weight column as a scalar R4 column with the given name.
/// </summary>
/// <param name="weightColumn">Name of the weight column, or null when there is none.</param>
/// <returns>The column description, or null when <paramref name="weightColumn"/> is null.</returns>
private static SchemaShape.Column MakeWeightColumn(string weightColumn)
    => weightColumn == null
        ? null
        : new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);

/// <summary>
/// Describes the feature column as a vector R4 column with the given name.
/// </summary>
/// <param name="featureColumn">Name of the feature column.</param>
/// <returns>The schema-shape description of the feature column.</returns>
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
    => new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

protected virtual void TrainCore(IChannel ch, RoleMappedData data)
{
bool shuffle = Args.Shuffle;
Expand Down