-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Conversion of prior and random trainers to estimators #876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
dda393a
602a581
6adb80b
a04ba3f
e083f1b
f07205f
e3e5d90
7211756
f05fe55
0221dfb
3dcb705
df28a35
f9dc9d1
14c47c0
d4c8f31
77392ee
886d514
005e53f
c661496
2042edd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
using Microsoft.ML.Runtime.Model; | ||
using Microsoft.ML.Runtime.Training; | ||
using Microsoft.ML.Runtime.Internal.Internallearn; | ||
using Microsoft.ML.Core.Data; | ||
|
||
[assembly: LoadableClass(RandomTrainer.Summary, typeof(RandomTrainer), typeof(RandomTrainer.Arguments), | ||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, | ||
|
@@ -38,38 +39,37 @@ namespace Microsoft.ML.Runtime.Learners | |
/// <summary> | ||
/// A trainer that trains a predictor that returns random values | ||
/// </summary> | ||
public sealed class RandomTrainer : TrainerBase<RandomPredictor> | ||
|
||
public sealed class RandomTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<RandomPredictor>, RandomPredictor> | ||
{ | ||
internal const string LoadNameValue = "RandomPredictor"; | ||
internal const string UserNameValue = "Random Predictor"; | ||
internal const string Summary = "A toy predictor that returns a random value."; | ||
|
||
public class Arguments | ||
{ | ||
// Some sample arguments | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr")] | ||
public Float LearningRate = (Float)1.0; | ||
|
||
[Argument(ArgumentType.AtMostOnce, HelpText = "Some bool arg", ShortName = "boolarg")] | ||
public bool BooleanArg = false; | ||
} | ||
|
||
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; | ||
|
||
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false); | ||
public override TrainerInfo Info => _info; | ||
|
||
public RandomTrainer(IHostEnvironment env, Arguments args) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
You should not remove this constructor. Once you start writing the test you will discover that it is necessary. Add a second constructor instead. Both here and for Prior #Closed |
||
: base(env, LoadNameValue) | ||
protected override SchemaShape.Column[] OutputColumns => throw new NotImplementedException(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is obviously not sufficient. You need to list the columns that you are going to output. #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
public RandomTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight) | ||
: base(host, feature, label, weight) | ||
{ | ||
Host.CheckValue(args, nameof(args)); | ||
} | ||
|
||
public override RandomPredictor Train(TrainContext context) | ||
protected override RandomPredictor TrainModelCore(TrainContext trainContext) | ||
{ | ||
Host.CheckValue(context, nameof(context)); | ||
Host.CheckValue(trainContext, nameof(trainContext)); | ||
return new RandomPredictor(Host, Host.Rand.Next()); | ||
} | ||
|
||
protected override BinaryPredictionTransformer<RandomPredictor> MakeTransformer(RandomPredictor model, ISchema trainSchema) | ||
=> new BinaryPredictionTransformer<RandomPredictor>(Host, model, trainSchema, FeatureColumn.Name); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -196,7 +196,7 @@ private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob) | |
} | ||
|
||
// Learns the prior distribution for 0/1 class labels and just outputs that. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Make it /// #Resolved |
||
public sealed class PriorTrainer : TrainerBase<PriorPredictor> | ||
public sealed class PriorTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PriorPredictor>, PriorPredictor> | ||
{ | ||
internal const string LoadNameValue = "PriorPredictor"; | ||
internal const string UserNameValue = "Prior Predictor"; | ||
|
@@ -210,13 +210,14 @@ public sealed class Arguments | |
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false); | ||
public override TrainerInfo Info => _info; | ||
|
||
public PriorTrainer(IHostEnvironment env, Arguments args) | ||
: base(env, LoadNameValue) | ||
protected override SchemaShape.Column[] OutputColumns { get; } | ||
|
||
public PriorTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Can it be private? |
||
: base(host, feature, label, weight) | ||
{ | ||
Host.CheckValue(args, nameof(args)); | ||
} | ||
|
||
public override PriorPredictor Train(TrainContext context) | ||
protected override PriorPredictor TrainModelCore(TrainContext context) | ||
{ | ||
Contracts.CheckValue(context, nameof(context)); | ||
var data = context.TrainingSet; | ||
|
@@ -258,6 +259,10 @@ public override PriorPredictor Train(TrainContext context) | |
Float prob = prob = pos + neg > 0 ? (Float)(pos / (pos + neg)) : Float.NaN; | ||
return new PriorPredictor(Host, prob); | ||
} | ||
|
||
protected override BinaryPredictionTransformer<PriorPredictor> MakeTransformer(PriorPredictor model, ISchema trainSchema) | ||
=> new BinaryPredictionTransformer<PriorPredictor>(Host, model, trainSchema, FeatureColumn.Name); | ||
|
||
} | ||
|
||
public sealed class PriorPredictor : | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't help you to derive from this base class, and it also hurts: you now have to take label, feature and weight columns, but you don't actually need them at all.
So, don't derive from
TrainerEstimatorBase
and just implementITrainerEstimator
#Closed