Conversion of prior and random trainers to estimators #876
Changes from 4 commits
@@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

-using Float = System.Single;
-
using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
@@ -13,6 +11,8 @@
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Core.Data;
+using System.Linq;

[assembly: LoadableClass(RandomTrainer.Summary, typeof(RandomTrainer), typeof(RandomTrainer.Arguments),
    new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
@@ -38,20 +38,16 @@ namespace Microsoft.ML.Runtime.Learners
    /// <summary>
    /// A trainer that trains a predictor that returns random values
    /// </summary>
-   public sealed class RandomTrainer : TrainerBase<RandomPredictor>
+   public sealed class RandomTrainer : TrainerBase<RandomPredictor>,
+       ITrainerEstimator<BinaryPredictionTransformer<RandomPredictor>, RandomPredictor>

[Review comment] I think this can go away too. It will eventually, to my understanding. @Zruty0 can confirm. #Resolved
[Reply] If I remove that then I have to implement the ITrainer interface as well, which would basically mean implementing all of TrainerBase.

    {
        internal const string LoadNameValue = "RandomPredictor";
        internal const string UserNameValue = "Random Predictor";
        internal const string Summary = "A toy predictor that returns a random value.";

-       public class Arguments
+       public sealed class Arguments

[Review comment] As you go through those, try adding XML comments to everything public, as they display on the docs site: docs.microsoft.com #Resolved

        {
            // Some sample arguments
            [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr")]
            public Float LearningRate = (Float)1.0;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Some bool arg", ShortName = "boolarg")]
            public bool BooleanArg = false;
        }

        public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
@@ -65,20 +61,40 @@ public RandomTrainer(IHostEnvironment env, Arguments args)
            Host.CheckValue(args, nameof(args));
        }

[Review comment] That is obviously not sufficient. You need to list the columns that you are going to output. #Closed

+       public BinaryPredictionTransformer<RandomPredictor> Fit(IDataView input)
+       {
+           var cachedTrain = Info.WantCaching ? new CacheDataView(Host, input, prefetch: null) : input;
+
+           RoleMappedData trainRoles = new RoleMappedData(cachedTrain);
+           var pred = Train(new TrainContext(trainRoles));
+           return new BinaryPredictionTransformer<RandomPredictor>(Host, pred, cachedTrain.Schema, null);
+       }

        public override RandomPredictor Train(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
            return new RandomPredictor(Host, Host.Rand.Next());
        }

+       public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+       {
+           Host.CheckValue(inputSchema, nameof(inputSchema));
+
+           var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
+           var newColumn = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, isKey: false);
+           outColumns[DefaultColumnNames.Score] = newColumn;
+
+           return new SchemaShape(outColumns.Values);
+       }
    }

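Taken together, Fit and GetOutputSchema are what make RandomTrainer usable as an estimator in a pipeline: Fit trains on an IDataView and wraps the resulting predictor in a BinaryPredictionTransformer, while GetOutputSchema propagates the input columns and declares the added Score column. A minimal usage sketch follows; the `env`, `trainData`, and `testData` variables and the SchemaShape.Create call are assumptions for illustration, not part of this diff.

```csharp
// Hypothetical usage sketch, not part of this PR. Assumes an IHostEnvironment `env`
// and IDataView `trainData`/`testData` already exist; SchemaShape.Create is assumed
// to build a SchemaShape from an existing data view schema.
var trainer = new RandomTrainer(env, new RandomTrainer.Arguments());

// Static check before any training: the output shape is the input shape
// plus a scalar R4 "Score" column.
SchemaShape outputShape = trainer.GetOutputSchema(SchemaShape.Create(trainData.Schema));

// Train, then score: the returned transformer appends the Score column to any compatible view.
BinaryPredictionTransformer<RandomPredictor> model = trainer.Fit(trainData);
IDataView scored = model.Transform(testData);
```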
    /// <summary>
    /// The predictor implements the Predict() interface. The predictor returns a
    /// uniform random probability and classification assignment.
    /// </summary>
    public sealed class RandomPredictor :
-       PredictorBase<Float>,
-       IDistPredictorProducing<Float, Float>,
+       PredictorBase<float>,
+       IDistPredictorProducing<float, float>,
        IValueMapperDist,
        ICanSaveModel
    {
@@ -96,7 +112,7 @@ private static VersionInfo GetVersionInfo()
        // Keep all the serializable state here.
        private readonly int _seed;
        private readonly object _instanceLock;
-       private readonly Random _random;
+       private readonly TauswortheHybrid _random;

[Review comment] would …

        public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
        public ColumnType InputType { get; }
@@ -109,7 +125,7 @@ public RandomPredictor(IHostEnvironment env, int seed)
            _seed = seed;

            _instanceLock = new object();
-           _random = new Random(_seed);
+           _random = RandomUtils.Create(_seed);

            InputType = new VectorType(NumberType.Float);
        }
@@ -126,7 +142,7 @@ private RandomPredictor(IHostEnvironment env, ModelLoadContext ctx)
            _seed = ctx.Reader.ReadInt32();

            _instanceLock = new object();
-           _random = new Random(_seed);
+           _random = RandomUtils.Create(_seed);
        }

        public static RandomPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -154,24 +170,24 @@ protected override void SaveCore(ModelSaveContext ctx)

        public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
        {
-           Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
-           Contracts.Check(typeof(TOut) == typeof(Float));
+           Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
+           Contracts.Check(typeof(TOut) == typeof(float));

-           ValueMapper<VBuffer<Float>, Float> del = Map;
+           ValueMapper<VBuffer<float>, float> del = Map;
            return (ValueMapper<TIn, TOut>)(Delegate)del;
        }

        public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
        {
-           Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
-           Contracts.Check(typeof(TOut) == typeof(Float));
-           Contracts.Check(typeof(TDist) == typeof(Float));
+           Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
+           Contracts.Check(typeof(TOut) == typeof(float));
+           Contracts.Check(typeof(TDist) == typeof(float));

-           ValueMapper<VBuffer<Float>, Float, Float> del = MapDist;
+           ValueMapper<VBuffer<float>, float, float> del = MapDist;
            return (ValueMapper<TIn, TOut, TDist>)(Delegate)del;
        }

-       private Float PredictCore()
+       private float PredictCore()
        {
            // Predict can be called from different threads.
            // Ensure your implementation is thread-safe
@@ -183,20 +199,20 @@ private Float PredictCore()
            }
        }

-       private void Map(ref VBuffer<Float> src, ref Float dst)
+       private void Map(ref VBuffer<float> src, ref float dst)
        {
            dst = PredictCore();
        }

-       private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
+       private void MapDist(ref VBuffer<float> src, ref float score, ref float prob)
        {
            score = PredictCore();
            prob = (score + 1) / 2;
        }
    }

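MapDist turns the raw score into a probability with a simple affine map; since the result must land in [0, 1], this implies PredictCore yields scores in [-1, 1]. A tiny standalone illustration (not from the diff):

```csharp
// Illustration only: the same affine map MapDist applies, assuming scores in [-1, 1].
float score = 0.5f;
float prob = (score + 1) / 2;   // 0.75; score -1 maps to 0.0, score 0 to 0.5, score +1 to 1.0
```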
    // Learns the prior distribution for 0/1 class labels and just outputs that.

[Review comment] Make it /// #Resolved

-   public sealed class PriorTrainer : TrainerBase<PriorPredictor>
+   public sealed class PriorTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PriorPredictor>, PriorPredictor>
    {
        internal const string LoadNameValue = "PriorPredictor";
        internal const string UserNameValue = "Prior Predictor";
@@ -210,13 +226,27 @@ public sealed class Arguments
        private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
        public override TrainerInfo Info => _info;

+       protected override SchemaShape.Column[] OutputColumns { get; }
+
        public PriorTrainer(IHostEnvironment env, Arguments args)

[Review comment] private #Closed

-           : base(env, LoadNameValue)
+           : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), MakeFeatureColumn(DefaultColumnNames.Score), MakeFeatureColumn(DefaultColumnNames.Label), null)
        {

[Review comment] keep the args check. #Resolved

            Host.CheckValue(args, nameof(args));
+           OutputColumns = new[]
+           {
+               new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, isKey: false)
+           };
        }

+       public PriorTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight)

[Review comment] Can it be private?

+           : base(host, feature, label, weight)
+       {
+           OutputColumns = new[]
+           {
+               new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, isKey: false)
+           };
+       }

-       public override PriorPredictor Train(TrainContext context)
+       protected override PriorPredictor TrainModelCore(TrainContext context)
        {
            Contracts.CheckValue(context, nameof(context));
            var data = context.TrainingSet;
@@ -234,16 +264,16 @@ public override PriorPredictor Train(TrainContext context)
            using (var cursor = data.Data.GetRowCursor(c => c == col || c == colWeight))
            {
                var getLab = cursor.GetLabelFloatGetter(data);
-               var getWeight = colWeight >= 0 ? cursor.GetGetter<Float>(colWeight) : null;
-               Float lab = default(Float);
-               Float weight = 1;
+               var getWeight = colWeight >= 0 ? cursor.GetGetter<float>(colWeight) : null;
+               float lab = default(float);

[Review comment] you can omit it. #Resolved

+               float weight = 1;
                while (cursor.MoveNext())
                {
                    getLab(ref lab);
                    if (getWeight != null)
                    {
                        getWeight(ref weight);
-                       if (!(0 < weight && weight < Float.PositiveInfinity))
+                       if (!(0 < weight && weight < float.PositiveInfinity))
                            continue;
                    }

@@ -255,14 +285,23 @@ public override PriorPredictor Train(TrainContext context)
                }
            }

-           Float prob = prob = pos + neg > 0 ? (Float)(pos / (pos + neg)) : Float.NaN;
+           float prob = prob = pos + neg > 0 ? (float)(pos / (pos + neg)) : float.NaN;
            return new PriorPredictor(Host, prob);
        }

+       protected override BinaryPredictionTransformer<PriorPredictor> MakeTransformer(PriorPredictor model, ISchema trainSchema)
+           => new BinaryPredictionTransformer<PriorPredictor>(Host, model, trainSchema, FeatureColumn.Name);
+
+       private static SchemaShape.Column MakeFeatureColumn(string featureColumn) =>
+           new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+
+       private static SchemaShape.Column MakeLabelColumn(string labelColumn) =>
+           new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
    }

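For contrast with RandomTrainer above, which implements ITrainerEstimator directly, PriorTrainer derives from TrainerEstimatorBase and only fills in the pieces the base class cannot supply. The annotated outline below recaps those pieces as they appear in this diff; the comments are editorial glosses, and the base class's exact responsibilities (building the TrainContext, exposing Fit and GetOutputSchema) are inferred from this conversion rather than shown here.

```csharp
// (1) The schema delta the trainer promises: a scalar R4 "Score" column,
//     surfaced to the base class through the OutputColumns override.
protected override SchemaShape.Column[] OutputColumns { get; }

// (2) The training step itself, renamed from the old Train override; the base class
//     is presumed to call it with a TrainContext built from the IDataView passed to Fit.
protected override PriorPredictor TrainModelCore(TrainContext context) { /* count labels, return predictor */ }

// (3) How the trained predictor is wrapped for scoring, so the base class can
//     return a BinaryPredictionTransformer<PriorPredictor> from Fit.
protected override BinaryPredictionTransformer<PriorPredictor> MakeTransformer(PriorPredictor model, ISchema trainSchema)
    => new BinaryPredictionTransformer<PriorPredictor>(Host, model, trainSchema, FeatureColumn.Name);
```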
    public sealed class PriorPredictor :
-       PredictorBase<Float>,
-       IDistPredictorProducing<Float, Float>,
+       PredictorBase<float>,
+       IDistPredictorProducing<float, float>,
        IValueMapperDist,
        ICanSaveModel
    {
@@ -277,13 +316,13 @@ private static VersionInfo GetVersionInfo()
                loaderSignature: LoaderSignature);
        }

-       private readonly Float _prob;
-       private readonly Float _raw;
+       private readonly float _prob;
+       private readonly float _raw;

-       public PriorPredictor(IHostEnvironment env, Float prob)
+       public PriorPredictor(IHostEnvironment env, float prob)
            : base(env, LoaderSignature)
        {
-           Host.Check(!Float.IsNaN(prob));
+           Host.Check(!float.IsNaN(prob));

            _prob = prob;
            _raw = 2 * _prob - 1; // This could be other functions -- logodds for instance
@@ -298,7 +337,7 @@ private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx)
            // Float: _prob

            _prob = ctx.Reader.ReadFloat();
-           Host.CheckDecode(!Float.IsNaN(_prob));
+           Host.CheckDecode(!float.IsNaN(_prob));

            _raw = 2 * _prob - 1;

@@ -321,7 +360,7 @@ protected override void SaveCore(ModelSaveContext ctx)
            // *** Binary format ***
            // Float: _prob

-           Contracts.Assert(!Float.IsNaN(_prob));
+           Contracts.Assert(!float.IsNaN(_prob));
            ctx.Writer.Write(_prob);
        }

@@ -333,29 +372,29 @@ public override PredictionKind PredictionKind

        public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
        {
-           Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
-           Contracts.Check(typeof(TOut) == typeof(Float));
+           Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
+           Contracts.Check(typeof(TOut) == typeof(float));

-           ValueMapper<VBuffer<Float>, Float> del = Map;
+           ValueMapper<VBuffer<float>, float> del = Map;
            return (ValueMapper<TIn, TOut>)(Delegate)del;
        }

        public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
        {
-           Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
-           Contracts.Check(typeof(TOut) == typeof(Float));
-           Contracts.Check(typeof(TDist) == typeof(Float));
+           Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
+           Contracts.Check(typeof(TOut) == typeof(float));
+           Contracts.Check(typeof(TDist) == typeof(float));

-           ValueMapper<VBuffer<Float>, Float, Float> del = MapDist;
+           ValueMapper<VBuffer<float>, float, float> del = MapDist;
            return (ValueMapper<TIn, TOut, TDist>)(Delegate)del;
        }

-       private void Map(ref VBuffer<Float> src, ref Float dst)
+       private void Map(ref VBuffer<float> src, ref float dst)
        {
            dst = _raw;
        }

-       private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
+       private void MapDist(ref VBuffer<float> src, ref float score, ref float prob)
        {
            score = _raw;
            prob = _prob;
@@ -0,0 +1,52 @@
maml.exe CV tr=PriorPredictor threads=- dout=%Output% data=%Data% seed=1
Not adding a normalizer.
Not training a calibrator because it is not needed.
Not adding a normalizer.
Not training a calibrator because it is not needed.
TEST POSITIVE RATIO: 0.3702 (134.0/(134.0+228.0))
Confusion table
          ||======================
PREDICTED || positive | negative | Recall
TRUTH     ||======================
 positive ||        0 |      134 | 0.0000
 negative ||        0 |      228 | 1.0000
          ||======================
Precision ||   0.0000 |   0.6298 |
OVERALL 0/1 ACCURACY: 0.629834
LOG LOSS/instance: 0.959786
Test-set entropy (prior Log-Loss/instance): 0.950799
LOG-LOSS REDUCTION (RIG): -0.945203
AUC: 0.500000
TEST POSITIVE RATIO: 0.3175 (107.0/(107.0+230.0))
Confusion table
          ||======================
PREDICTED || positive | negative | Recall
TRUTH     ||======================
 positive ||        0 |      107 | 0.0000
 negative ||        0 |      230 | 1.0000
          ||======================
Precision ||   0.0000 |   0.6825 |
OVERALL 0/1 ACCURACY: 0.682493
LOG LOSS/instance: 0.910421
Test-set entropy (prior Log-Loss/instance): 0.901650
LOG-LOSS REDUCTION (RIG): -0.972725
AUC: 0.500000

OVERALL RESULTS
---------------------------------------
AUC: 0.500000 (0.0000)
Accuracy: 0.656163 (0.0263)
Positive precision: 0.000000 (0.0000)
Positive recall: 0.000000 (0.0000)
Negative precision: 0.656163 (0.0263)
Negative recall: 1.000000 (0.0000)
Log-loss: 0.935104 (0.0247)
Log-loss reduction: -0.958964 (0.0138)
F1 Score: NaN (NaN)
AUPRC: 0.418968 (0.0212)

---------------------------------------
Physical memory usage(MB): %Number%
Virtual memory usage(MB): %Number%
%DateTime% Time elapsed(s): %Number%
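These baseline numbers are what a constant prior classifier should produce, which makes them a useful end-to-end check on the new estimator path: the learned probability is below 0.5 in both folds, so _raw = 2 * _prob - 1 is negative, every example is predicted negative (positive recall 0, negative recall 1, AUC 0.5), and F1 is NaN because no positive is ever predicted. A quick sanity computation, under the assumption that each test fold is scored by a model whose prior came from the other fold:

```csharp
// Back-of-the-envelope check against the first reported fold (362 test examples, 134 positive).
double pos = 134, neg = 228;
double testPosRate = pos / (pos + neg);   // 0.3702, the "TEST POSITIVE RATIO" line
double accuracy = neg / (pos + neg);      // 0.6298, matches "OVERALL 0/1 ACCURACY: 0.629834"

double priorFromTrainFold = 0.3175;       // assumed _prob; _raw = 2 * 0.3175 - 1 < 0, so always predict negative
double logLoss = -(testPosRate * Math.Log(priorFromTrainFold, 2)
                 + (1 - testPosRate) * Math.Log(1 - priorFromTrainFold, 2));
// ~0.96 bits/instance, matching "LOG LOSS/instance: 0.959786"; the test-set entropy line
// (0.9508) is the same formula with the fold's own rate plugged in, hence slightly lower.
```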
@@ -0,0 +1,4 @@
PriorPredictor
AUC Accuracy Positive precision Positive recall Negative precision Negative recall Log-loss Log-loss reduction F1 Score AUPRC Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings
0.5 0.656163 0 0 0.656163 1 0.935104 -0.958964 NaN 0.418968 PriorPredictor %Data% %Output% 99 0 0 maml.exe CV tr=PriorPredictor threads=- dout=%Output% data=%Data% seed=1
[Review comment] sort namespaces #Closed