-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Ap, LinearSVM, OGD as estimators #849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
96fd88e
d47014e
23da0eb
101c2e8
58dbbac
0c176aa
6a01926
40fcd33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
using Microsoft.ML.Runtime.Learners; | ||
using Microsoft.ML.Runtime.Numeric; | ||
using Microsoft.ML.Runtime.Training; | ||
using Microsoft.ML.Core.Data; | ||
|
||
[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments), | ||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, | ||
|
@@ -29,14 +30,15 @@ namespace Microsoft.ML.Runtime.Learners | |
// - Feature normalization. By default, rescaling between min and max values for every feature | ||
// - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration. | ||
/// <include file='doc.xml' path='doc/members/member[@name="AP"]/*' /> | ||
public sealed class AveragedPerceptronTrainer : | ||
AveragedLinearTrainer<AveragedPerceptronTrainer.Arguments, LinearBinaryPredictor> | ||
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor> , LinearBinaryPredictor> | ||
{ | ||
public const string LoadNameValue = "AveragedPerceptron"; | ||
internal const string UserNameValue = "Averaged Perceptron"; | ||
internal const string ShortName = "ap"; | ||
internal const string Summary = "Averaged Perceptron Binary Classifier."; | ||
|
||
internal new readonly Arguments Args; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this should truly be private, but our analyzer wants the private properties to be lowercased. This i should change or add a separate rule for 'private new'. Thoughts? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
public class Arguments : AveragedLinearArguments | ||
{ | ||
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] | ||
|
@@ -49,22 +51,37 @@ public class Arguments : AveragedLinearArguments | |
public int MaxCalibrationExamples = 1000000; | ||
} | ||
|
||
protected override bool NeedCalibration => true; | ||
|
||
public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) | ||
: base(args, env, UserNameValue) | ||
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn)) | ||
{ | ||
Args = args; | ||
LossFunction = Args.LossFunction.CreateComponent(env); | ||
|
||
OutputColumns = new[] | ||
{ | ||
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), | ||
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), | ||
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) | ||
}; | ||
} | ||
|
||
public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } | ||
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; | ||
|
||
protected override bool NeedCalibration => true; | ||
|
||
protected override SchemaShape.Column[] OutputColumns { get; } | ||
|
||
protected override void CheckLabel(RoleMappedData data) | ||
{ | ||
Contracts.AssertValue(data); | ||
data.CheckBinaryLabel(); | ||
} | ||
|
||
private static SchemaShape.Column MakeLabelColumn(string labelColumn) | ||
{ | ||
return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); | ||
} | ||
|
||
protected override LinearBinaryPredictor CreatePredictor() | ||
{ | ||
Contracts.Assert(WeightsScale == 1); | ||
|
@@ -87,6 +104,9 @@ protected override LinearBinaryPredictor CreatePredictor() | |
return new LinearBinaryPredictor(Host, ref weights, bias); | ||
} | ||
|
||
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema) | ||
=> new BinaryPredictionTransformer<LinearBinaryPredictor>(Host, model, trainSchema, FeatureColumn.Name); | ||
|
||
[TlcModule.EntryPoint(Name = "Trainers.AveragedPerceptronBinaryClassifier", | ||
Desc = Summary, | ||
UserName = UserNameValue, | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not certain I welcome this move. It seems like it is repeating the mistake of
IIncrementalValidatingTrainer
, which would be bad enough if it was just for trainers alone but it seems to be for absolutely every estimator. #ResolvedThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree
In reply to: 215803874 [](ancestors = 215803874)