Skip to content

AP, LinearSVM, OGD as estimators #849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 8, 2018
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ public interface IEstimator<out TTransformer>
/// <summary>
/// Train and return a transformer.
/// </summary>
TTransformer Fit(IDataView input);
TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null);
Copy link
Contributor

@TomFinley TomFinley Sep 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDataView validationData = null, IPredictor initialPredictor = null) [](start = 41, length = 69)

I am not certain I welcome this move. It seems like it is repeating the mistake of IIncrementalValidatingTrainer, which would be bad enough if it were just for trainers alone, but it seems to be for absolutely every estimator. #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree


In reply to: 215803874 [](ancestors = 215803874)


/// <summary>
/// Schema propagation for estimators.
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public EstimatorChain()
LastEstimator = null;
}

public TransformerChain<TLastTransformer> Fit(IDataView input)
public TransformerChain<TLastTransformer> Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
{
// REVIEW: before fitting, run schema propagation.
// Currently, it throws.
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public DelegateEstimator(IEstimator<TTransformer> estimator, Action<TTransformer
_onFit = onFit;
}

public TTransformer Fit(IDataView input)
public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
{
var trans = _est.Fit(input);
_onFit(trans);
Expand All @@ -102,9 +102,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
}

/// <summary>
/// Given an estimator, return a wrapping object that will call a delegate once <see cref="IEstimator{TTransformer}.Fit(IDataView)"/>
/// Given an estimator, return a wrapping object that will call a delegate once <see cref="IEstimator{TTransformer}.Fit(IDataView, IDataView, IPredictor)"/>
/// is called. It is often important for an estimator to return information about what was fit, which is why the
/// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method returns a specifically typed object, rather than just a general
/// <see cref="IEstimator{TTransformer}.Fit(IDataView, IDataView, IPredictor)"/> method returns a specifically typed object, rather than just a general
/// <see cref="ITransformer"/>. However, at the same time, <see cref="IEstimator{TTransformer}"/> are often formed into pipelines
/// with many objects, so we may need to build a chain of estimators via <see cref="EstimatorChain{TLastTransformer}"/> where the
/// estimator for which we want to get the transformer is buried somewhere in this chain. For that scenario, we can through this
Expand All @@ -113,7 +113,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
/// <typeparam name="TTransformer">The type of <see cref="ITransformer"/> returned by <paramref name="estimator"/></typeparam>
/// <param name="estimator">The estimator to wrap</param>
/// <param name="onFit">The delegate that is called with the resulting <typeparamref name="TTransformer"/> instances once
/// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> is called. Because <see cref="IEstimator{TTransformer}.Fit(IDataView)"/>
/// <see cref="IEstimator{TTransformer}.Fit(IDataView, IDataView, IPredictor)"/> is called. Because <see cref="IEstimator{TTransformer}.Fit(IDataView, IDataView, IPredictor)"/>
/// may be called multiple times, this delegate may also be called multiple times.</param>
/// <returns>A wrapping estimator that calls the indicated delegate whenever fit is called</returns>
public static IEstimator<TTransformer> WithOnFitDelegate<TTransformer>(this IEstimator<TTransformer> estimator, Action<TTransformer> onFit)
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.ML.Runtime.Data
{
/// <summary>
/// The trivial implementation of <see cref="IEstimator{TTransformer}"/> that already has
/// the transformer and returns it on every call to <see cref="Fit(IDataView)"/>.
/// the transformer and returns it on every call to <see cref="Fit(IDataView, IDataView, IPredictor)"/>.
///
/// Concrete implementations still have to provide the schema propagation mechanism, since
/// there is no easy way to infer it from the transformer.
Expand All @@ -28,7 +28,7 @@ protected TrivialEstimator(IHost host, TTransformer transformer)
Transformer = transformer;
}

public TTransformer Fit(IDataView input) => Transformer;
/// <summary>
/// Trivial fit: the training data — and the optional validation data and initial
/// predictor — are ignored, and the transformer supplied at construction is returned as-is.
/// </summary>
public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => Transformer;

public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
Expand Down
8 changes: 7 additions & 1 deletion src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
/// <summary>
/// A standard string to use in errors or warnings by subclasses, to communicate the idea that no valid
/// instances were able to be found.
/// </summary>
protected const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";

/// <summary>
/// The feature column that the trainer expects.
/// </summary>
Expand Down Expand Up @@ -61,7 +67,7 @@ public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.
WeightColumn = weight;
}

public TTransformer Fit(IDataView input) => TrainTransformer(input);
/// <summary>
/// Trains on <paramref name="input"/> and returns the fitted transformer, forwarding the
/// optional validation data and initial predictor to <c>TrainTransformer</c>.
/// </summary>
public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(input, validationData, initialPredictor);

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public CopyColumnsEstimator(IHostEnvironment env, params (string source, string
_columns = columns;
}

public CopyColumnsTransform Fit(IDataView input)
public CopyColumnsTransform Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
{
// Invoke schema validation.
GetOutputSchema(SchemaShape.Create(input.Schema));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/Normalizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ public Normalizer(IHostEnvironment env, params ColumnBase[] columns)
_columns = columns.ToArray();
}

public NormalizerTransformer Fit(IDataView input)
public NormalizerTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
{
_host.CheckValue(input, nameof(input));
return NormalizerTransformer.Train(_host, input, _columns);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/TermEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public TermEstimator(IHostEnvironment env, params TermTransform.ColumnInfo[] col
_columns = columns;
}

public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns);
/// <summary>
/// Builds a <see cref="TermTransform"/> over the input data using the configured columns.
/// The optional validation data and initial predictor are not used here.
/// </summary>
public TermTransform Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => new TermTransform(_host, input, _columns);

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,7 @@ protected override void CheckLabel(RoleMappedData examples, out int weightSetCou
protected override BinaryPredictionTransformer<TScalarPredictor> MakeTransformer(TScalarPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<TScalarPredictor>(Host, model, trainSchema, FeatureColumn.Name);

public BinaryPredictionTransformer<TScalarPredictor> Train(IDataView trainData, IDataView validationData) => TrainTransformer(trainData, validationData);
/// <summary>
/// Trains on <paramref name="trainData"/>, optionally using validation data and an initial
/// predictor, and returns the resulting binary prediction transformer.
/// </summary>
public BinaryPredictionTransformer<TScalarPredictor> Train(IDataView trainData, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(trainData, validationData, initialPredictor);
}

public sealed class StochasticGradientDescentClassificationTrainer :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,7 @@ protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

public override PredictionKind PredictionKind {
get { return PredictionKind.BinaryClassification; }
}
/// <summary>This predictor performs binary classification.</summary>
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

/// <summary>
/// Combine a bunch of models into one by averaging parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Numeric;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Core.Data;

// TODO: Check if it works properly if Averaged is set to false

Expand Down Expand Up @@ -52,12 +53,12 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
public Float AveragedTolerance = (Float)1e-2;
}

public abstract class AveragedLinearTrainer<TArguments, TPredictor> : OnlineLinearTrainer<TArguments, TPredictor>
where TArguments : AveragedLinearArguments
where TPredictor : IPredictorProducing<Float>
public abstract class AveragedLinearTrainer<TTransformer, TModel> : OnlineLinearTrainer<TTransformer, TModel>
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
protected readonly new AveragedLinearArguments Args;
protected IScalarOutputLoss LossFunction;

protected Float Gain;

// For computing averaged weights and bias (if needed)
Expand All @@ -74,15 +75,18 @@ public abstract class AveragedLinearTrainer<TArguments, TPredictor> : OnlineLine
// We'll keep a few things global to prevent garbage collection
protected int NumNoUpdates;

protected AveragedLinearTrainer(TArguments args, IHostEnvironment env, string name)
: base(args, env, name)
/// <summary>
/// Initializes the averaged linear trainer, validating the averaging-related user
/// arguments before delegating the remaining setup to the base online linear trainer.
/// </summary>
/// <param name="args">Trainer arguments; learning rate must be positive, recency gain and
/// averaged tolerance non-negative, and L2 regularizer weight in [0, 0.5).</param>
/// <param name="env">Host environment.</param>
/// <param name="name">Registration name forwarded to the base trainer.</param>
/// <param name="label">Shape of the expected label column.</param>
protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
: base(args, env, name, label)
{
Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);

// Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
Contracts.CheckUserArg(0 <= args.L2RegularizerWeight && args.L2RegularizerWeight < 0.5, nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)");
Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative);
Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative);

Args = args;
}

protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.Numeric;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Core.Data;

[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
Expand All @@ -29,14 +30,15 @@ namespace Microsoft.ML.Runtime.Learners
// - Feature normalization. By default, rescaling between min and max values for every feature
// - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration.
/// <include file='doc.xml' path='doc/members/member[@name="AP"]/*' />
public sealed class AveragedPerceptronTrainer :
AveragedLinearTrainer<AveragedPerceptronTrainer.Arguments, LinearBinaryPredictor>
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor> , LinearBinaryPredictor>
{
public const string LoadNameValue = "AveragedPerceptron";
internal const string UserNameValue = "Averaged Perceptron";
internal const string ShortName = "ap";
internal const string Summary = "Averaged Perceptron Binary Classifier.";

internal new readonly Arguments Args;
Copy link
Member Author

@sfilipi sfilipi Sep 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal [](start = 8, length = 8)

this should truly be private, but our analyzer wants the private properties to be lowercased. This I should change, or add a separate rule for 'private new'. Thoughts? #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just private readonly Arguments _args ?


In reply to: 215814336 [](ancestors = 215814336)


public class Arguments : AveragedLinearArguments
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
Expand All @@ -49,22 +51,37 @@ public class Arguments : AveragedLinearArguments
public int MaxCalibrationExamples = 1000000;
}

protected override bool NeedCalibration => true;

/// <summary>
/// Initializes the averaged perceptron trainer: creates the loss function component from
/// the arguments and declares the three output columns (Score, Probability, PredictedLabel).
/// </summary>
/// <param name="env">Host environment, also used to instantiate the loss function.</param>
/// <param name="args">Trainer arguments, including the label column name and loss function.</param>
public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
{
Args = args;
LossFunction = Args.LossFunction.CreateComponent(env);

OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
};
}

public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
/// <summary>This trainer performs binary classification.</summary>
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

/// <summary>True: the raw perceptron scores need downstream calibration to become probabilities.</summary>
protected override bool NeedCalibration => true;

/// <summary>Output column shapes (Score, Probability, PredictedLabel), populated in the constructor.</summary>
protected override SchemaShape.Column[] OutputColumns { get; }

/// <summary>
/// Asserts the example data is non-null and verifies it carries a label suitable for
/// binary classification.
/// </summary>
protected override void CheckLabel(RoleMappedData data)
{
Contracts.AssertValue(data);
data.CheckBinaryLabel();
}

/// <summary>
/// Builds the schema shape describing the expected label column: a scalar Boolean
/// column with the supplied name.
/// </summary>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
    => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);

protected override LinearBinaryPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);
Expand All @@ -87,6 +104,9 @@ protected override LinearBinaryPredictor CreatePredictor()
return new LinearBinaryPredictor(Host, ref weights, bias);
}

/// <summary>
/// Wraps the trained <see cref="LinearBinaryPredictor"/> in a binary prediction
/// transformer bound to this trainer's feature column.
/// </summary>
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name);

[TlcModule.EntryPoint(Name = "Trainers.AveragedPerceptronBinaryClassifier",
Desc = Summary,
UserName = UserNameValue,
Expand Down
28 changes: 25 additions & 3 deletions src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@

namespace Microsoft.ML.Runtime.Learners
{
using Microsoft.ML.Core.Data;
using TPredictor = LinearBinaryPredictor;

/// <summary>
/// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf
/// </summary>
public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, TPredictor>
public sealed class LinearSvm : OnlineLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor>, LinearBinaryPredictor>
{
public const string LoadNameValue = "LinearSVM";
public const string ShortName = "svm";
Expand All @@ -41,6 +42,8 @@ public sealed class LinearSvm : OnlineLinearTrainer<LinearSvm.Arguments, TPredic
+ "and all the negative examples are on the other. After this mapping, quadratic programming is used to find the separating hyperplane that maximizes the "
+ "margin, i.e., the minimal distance between it and the instances.";

internal new readonly Arguments Args;

public sealed class Arguments : OnlineLinearArguments
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer constant", ShortName = "lambda", SortOrder = 50)]
Expand Down Expand Up @@ -83,13 +86,24 @@ public sealed class Arguments : OnlineLinearArguments
protected override bool NeedCalibration => true;

/// <summary>
/// Initializes the linear SVM trainer, validating the PEGASOS-specific arguments
/// (lambda and batch size must be positive) and declaring the three output columns
/// (Score, Probability, PredictedLabel).
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="args">Trainer arguments, including the label column name.</param>
public LinearSvm(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
{
Contracts.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), UserErrorPositive);
Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);

Args = args;

OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
};
}

public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
/// <summary>Linear SVM is a binary classifier.</summary>
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

/// <summary>Output column shapes (Score, Probability, PredictedLabel), populated in the constructor.</summary>
protected override SchemaShape.Column[] OutputColumns { get; }

protected override void CheckLabel(RoleMappedData data)
{
Expand All @@ -105,6 +119,11 @@ protected override Float Margin(ref VBuffer<Float> feat)
return Bias + VectorUtils.DotProduct(ref feat, ref Weights) * WeightsScale;
}

/// <summary>
/// Builds the schema shape describing the expected label column: a scalar Boolean
/// column with the supplied name.
/// </summary>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
    => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);

protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
{
base.InitCore(ch, numFeatures, predictor);
Expand Down Expand Up @@ -237,5 +256,8 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
}

/// <summary>
/// Wraps the trained <see cref="LinearBinaryPredictor"/> in a binary prediction
/// transformer bound to this trainer's feature column.
/// </summary>
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name);
}
}
Loading