
Conversion of prior and random trainers to estimators #876


Merged 20 commits on Sep 19, 2018.

Changes shown from 4 commits.
141 changes: 90 additions & 51 deletions src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
@@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
@@ -13,6 +11,8 @@
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Core.Data;
using System.Linq;
@Zruty0 (Contributor), Sep 13, 2018, on `System`:
sort namespaces #Closed


[assembly: LoadableClass(RandomTrainer.Summary, typeof(RandomTrainer), typeof(RandomTrainer.Arguments),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
@@ -38,20 +38,16 @@ namespace Microsoft.ML.Runtime.Learners
/// <summary>
/// A trainer that trains a predictor that returns random values
/// </summary>
public sealed class RandomTrainer : TrainerBase<RandomPredictor>

public sealed class RandomTrainer : TrainerBase<RandomPredictor>,
@sfilipi (Member), Sep 17, 2018, on `TrainerBase`:
I think this can go away too. It will eventually, to my understanding. @Zruty0 can confirm. #Resolved

Contributor Author, in reply:
If I remove that then I have to implement the ITrainer interface as well, which would basically mean implementing all of TrainerBase. This is because I am not deriving from TrainerEstimatorBase, which implements ITrainer.

Contributor, in reply:
Yes, this is an exception :)

ITrainerEstimator<BinaryPredictionTransformer<RandomPredictor>, RandomPredictor>
{
internal const string LoadNameValue = "RandomPredictor";
internal const string UserNameValue = "Random Predictor";
internal const string Summary = "A toy predictor that returns a random value.";

public class Arguments
public sealed class Arguments
@sfilipi (Member), Sep 17, 2018, on `public`:
as you go through those, try adding XML comments to everything public, as they display on the docs site: docs.microsoft.com #Resolved

{
// Some sample arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr")]
public Float LearningRate = (Float)1.0;

[Argument(ArgumentType.AtMostOnce, HelpText = "Some bool arg", ShortName = "boolarg")]
public bool BooleanArg = false;
}

public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
@@ -65,20 +61,40 @@ public RandomTrainer(IHostEnvironment env, Arguments args)
Host.CheckValue(args, nameof(args));
}

@Zruty0 (Contributor), Sep 10, 2018:
That is obviously not sufficient. You need to list the columns that you are going to output. #Closed

Contributor, in reply:
Same goes for the other trainer

public BinaryPredictionTransformer<RandomPredictor> Fit(IDataView input)
{
var cachedTrain = Info.WantCaching ? new CacheDataView(Host, input, prefetch: null) : input;

RoleMappedData trainRoles = new RoleMappedData(cachedTrain);
var pred = Train(new TrainContext(trainRoles));
return new BinaryPredictionTransformer<RandomPredictor>(Host, pred, cachedTrain.Schema, null);
}

public override RandomPredictor Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
return new RandomPredictor(Host, Host.Rand.Next());
}

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
var newColumn = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, isKey: false);
outColumns[DefaultColumnNames.Score] = newColumn;

return new SchemaShape(outColumns.Values);
}
}
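[Editor's note] The net effect of the conversion above is the standard estimator contract: `Fit` produces a prediction transformer, and `GetOutputSchema` lets a pipeline validate column shapes before any training runs. A minimal usage sketch, with the caveat that `env` and `data` (an `IHostEnvironment` and an `IDataView`) are assumed to exist and are not part of this diff:

```csharp
// Hypothetical usage of the converted trainer; env and data are assumptions.
var trainer = new RandomTrainer(env, new RandomTrainer.Arguments());

// Training yields a transformer that can score any compatible IDataView.
BinaryPredictionTransformer<RandomPredictor> model = trainer.Fit(data);
IDataView scored = model.Transform(data); // adds the scalar R4 "Score" column
```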

/// <summary>
/// The predictor implements the Predict() interface. The predictor returns a
/// uniform random probability and classification assignment.
/// </summary>
public sealed class RandomPredictor :
PredictorBase<Float>,
IDistPredictorProducing<Float, Float>,
PredictorBase<float>,
IDistPredictorProducing<float, float>,
IValueMapperDist,
ICanSaveModel
{
@@ -96,7 +112,7 @@ private static VersionInfo GetVersionInfo()
// Keep all the serializable state here.
private readonly int _seed;
private readonly object _instanceLock;
private readonly Random _random;
private readonly TauswortheHybrid _random;
@Zruty0 (Contributor), Sep 12, 2018, on `TauswortheHybrid`:
would IRandom suffice here? #Closed
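[Editor's note] The motivation for dropping `System.Random` here is reproducibility: `System.Random` does not guarantee a stable algorithm across framework versions, so a serialized predictor that re-seeds it on load could score differently on different runtimes, whereas the repo-owned `TauswortheHybrid` (obtained via `RandomUtils.Create`) is fully specified. A rough illustration of the property being bought, using a simple xorshift generator (this is NOT the actual TauswortheHybrid algorithm):

```csharp
// Illustrative only: a minimal generator whose sequence is fully determined
// by the seed on every runtime, unlike System.Random across versions.
internal sealed class XorShiftRandom
{
    private uint _state;

    public XorShiftRandom(uint seed) => _state = seed == 0 ? 2463534242u : seed;

    public float NextFloat()
    {
        // Marsaglia xorshift32 step.
        _state ^= _state << 13;
        _state ^= _state >> 17;
        _state ^= _state << 5;
        return (_state & 0xFFFFFF) / (float)0x1000000; // uniform in [0, 1)
    }
}
```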


public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
public ColumnType InputType { get; }
@@ -109,7 +125,7 @@ public RandomPredictor(IHostEnvironment env, int seed)
_seed = seed;

_instanceLock = new object();
_random = new Random(_seed);
_random = RandomUtils.Create(_seed);

InputType = new VectorType(NumberType.Float);
}
@@ -126,7 +142,7 @@ private RandomPredictor(IHostEnvironment env, ModelLoadContext ctx)
_seed = ctx.Reader.ReadInt32();

_instanceLock = new object();
_random = new Random(_seed);
_random = RandomUtils.Create(_seed);
}

public static RandomPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -154,24 +170,24 @@ protected override void SaveCore(ModelSaveContext ctx)

public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
{
Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
Contracts.Check(typeof(TOut) == typeof(Float));
Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
Contracts.Check(typeof(TOut) == typeof(float));

ValueMapper<VBuffer<Float>, Float> del = Map;
ValueMapper<VBuffer<float>, float> del = Map;
return (ValueMapper<TIn, TOut>)(Delegate)del;
}

public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
{
Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
Contracts.Check(typeof(TOut) == typeof(Float));
Contracts.Check(typeof(TDist) == typeof(Float));
Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
Contracts.Check(typeof(TOut) == typeof(float));
Contracts.Check(typeof(TDist) == typeof(float));

ValueMapper<VBuffer<Float>, Float, Float> del = MapDist;
ValueMapper<VBuffer<float>, float, float> del = MapDist;
return (ValueMapper<TIn, TOut, TDist>)(Delegate)del;
}

private Float PredictCore()
private float PredictCore()
{
// Predict can be called from different threads.
// Ensure your implementation is thread-safe
@@ -183,20 +199,20 @@ private Float PredictCore()
}
}

private void Map(ref VBuffer<Float> src, ref Float dst)
private void Map(ref VBuffer<float> src, ref float dst)
{
dst = PredictCore();
}

private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
private void MapDist(ref VBuffer<float> src, ref float score, ref float prob)
{
score = PredictCore();
prob = (score + 1) / 2;
}
}

// Learns the prior distribution for 0/1 class labels and just outputs that.
@Ivanidzo4ka (Contributor), Sep 17, 2018, on `//`:
Make it /// #Resolved

public sealed class PriorTrainer : TrainerBase<PriorPredictor>
public sealed class PriorTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PriorPredictor>, PriorPredictor>
{
internal const string LoadNameValue = "PriorPredictor";
internal const string UserNameValue = "Prior Predictor";
@@ -210,13 +226,27 @@ public sealed class Arguments
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
public override TrainerInfo Info => _info;

protected override SchemaShape.Column[] OutputColumns { get; }

public PriorTrainer(IHostEnvironment env, Arguments args)
@sfilipi (Member), Sep 17, 2018, on `public`:
private #Closed

Contributor Author, in reply:
Why should I make the constructor private?

: base(env, LoadNameValue)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), MakeFeatureColumn(DefaultColumnNames.Score), MakeFeatureColumn(DefaultColumnNames.Label), null)
{
@sfilipi (Member), Sep 17, 2018, on `{`:
keep the args check. #Resolved

Host.CheckValue(args, nameof(args));
OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, isKey: false)
};
}

public PriorTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight)
@Ivanidzo4ka (Contributor), Sep 17, 2018, on `public`:
Can it be private? I understand that you use TrainerEstimatorBase, which requires a feature column, but for this class it's pointless, and exposing it to the user would be potentially confusing. #Resolved

: base(host, feature, label, weight)
{
OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, isKey: false)
};
}

public override PriorPredictor Train(TrainContext context)
protected override PriorPredictor TrainModelCore(TrainContext context)
{
Contracts.CheckValue(context, nameof(context));
var data = context.TrainingSet;
@@ -234,16 +264,16 @@ public override PriorPredictor Train(TrainContext context)
using (var cursor = data.Data.GetRowCursor(c => c == col || c == colWeight))
{
var getLab = cursor.GetLabelFloatGetter(data);
var getWeight = colWeight >= 0 ? cursor.GetGetter<Float>(colWeight) : null;
Float lab = default(Float);
Float weight = 1;
var getWeight = colWeight >= 0 ? cursor.GetGetter<float>(colWeight) : null;
float lab = default(float);
@Ivanidzo4ka (Contributor), Sep 17, 2018, on `float`:
you can omit it. #Resolved
float weight = 1;
while (cursor.MoveNext())
{
getLab(ref lab);
if (getWeight != null)
{
getWeight(ref weight);
if (!(0 < weight && weight < Float.PositiveInfinity))
if (!(0 < weight && weight < float.PositiveInfinity))
continue;
}

@@ -255,14 +285,23 @@ public override PriorPredictor Train(TrainContext context)
}
}

Float prob = prob = pos + neg > 0 ? (Float)(pos / (pos + neg)) : Float.NaN;
float prob = pos + neg > 0 ? (float)(pos / (pos + neg)) : float.NaN;
return new PriorPredictor(Host, prob);
}
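[Editor's note] Stripped of the cursoring and weighting machinery, `TrainModelCore` reduces to a weighted positive fraction; plugging in the fold-1 counts from the baseline output below (134 positives, 228 negatives) reproduces its 0.3702 positive ratio. A standalone sketch with those counts as example inputs:

```csharp
// Standalone sketch of the prior computation (inputs are example values
// taken from the fold-1 baseline: 134 positive, 228 negative).
double pos = 134, neg = 228;
float prob = pos + neg > 0 ? (float)(pos / (pos + neg)) : float.NaN; // ~0.3702
float raw = 2 * prob - 1; // raw score in [-1, 1]; log-odds would also work
```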

protected override BinaryPredictionTransformer<PriorPredictor> MakeTransformer(PriorPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<PriorPredictor>(Host, model, trainSchema, FeatureColumn.Name);

private static SchemaShape.Column MakeFeatureColumn(string featureColumn) =>
new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

private static SchemaShape.Column MakeLabelColumn(string labelColumn) =>
new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
}

public sealed class PriorPredictor :
PredictorBase<Float>,
IDistPredictorProducing<Float, Float>,
PredictorBase<float>,
IDistPredictorProducing<float, float>,
IValueMapperDist,
ICanSaveModel
{
@@ -277,13 +316,13 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}

private readonly Float _prob;
private readonly Float _raw;
private readonly float _prob;
private readonly float _raw;

public PriorPredictor(IHostEnvironment env, Float prob)
public PriorPredictor(IHostEnvironment env, float prob)
: base(env, LoaderSignature)
{
Host.Check(!Float.IsNaN(prob));
Host.Check(!float.IsNaN(prob));

_prob = prob;
_raw = 2 * _prob - 1; // This could be other functions -- logodds for instance
@@ -298,7 +337,7 @@ private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx)
// Float: _prob

_prob = ctx.Reader.ReadFloat();
Host.CheckDecode(!Float.IsNaN(_prob));
Host.CheckDecode(!float.IsNaN(_prob));

_raw = 2 * _prob - 1;

@@ -321,7 +360,7 @@ protected override void SaveCore(ModelSaveContext ctx)
// *** Binary format ***
// Float: _prob

Contracts.Assert(!Float.IsNaN(_prob));
Contracts.Assert(!float.IsNaN(_prob));
ctx.Writer.Write(_prob);
}

@@ -333,29 +372,29 @@ public override PredictionKind PredictionKind

public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
{
Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
Contracts.Check(typeof(TOut) == typeof(Float));
Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
Contracts.Check(typeof(TOut) == typeof(float));

ValueMapper<VBuffer<Float>, Float> del = Map;
ValueMapper<VBuffer<float>, float> del = Map;
return (ValueMapper<TIn, TOut>)(Delegate)del;
}

public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
{
Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
Contracts.Check(typeof(TOut) == typeof(Float));
Contracts.Check(typeof(TDist) == typeof(Float));
Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
Contracts.Check(typeof(TOut) == typeof(float));
Contracts.Check(typeof(TDist) == typeof(float));

ValueMapper<VBuffer<Float>, Float, Float> del = MapDist;
ValueMapper<VBuffer<float>, float, float> del = MapDist;
return (ValueMapper<TIn, TOut, TDist>)(Delegate)del;
}

private void Map(ref VBuffer<Float> src, ref Float dst)
private void Map(ref VBuffer<float> src, ref float dst)
{
dst = _raw;
}

private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
private void MapDist(ref VBuffer<float> src, ref float score, ref float prob)
{
score = _raw;
prob = _prob;
@@ -0,0 +1,52 @@
maml.exe CV tr=PriorPredictor threads=- dout=%Output% data=%Data% seed=1
Not adding a normalizer.
Not training a calibrator because it is not needed.
Not adding a normalizer.
Not training a calibrator because it is not needed.
TEST POSITIVE RATIO: 0.3702 (134.0/(134.0+228.0))
Confusion table
||======================
PREDICTED || positive | negative | Recall
TRUTH ||======================
positive || 0 | 134 | 0.0000
negative || 0 | 228 | 1.0000
||======================
Precision || 0.0000 | 0.6298 |
OVERALL 0/1 ACCURACY: 0.629834
LOG LOSS/instance: 0.959786
Test-set entropy (prior Log-Loss/instance): 0.950799
LOG-LOSS REDUCTION (RIG): -0.945203
AUC: 0.500000
TEST POSITIVE RATIO: 0.3175 (107.0/(107.0+230.0))
Confusion table
||======================
PREDICTED || positive | negative | Recall
TRUTH ||======================
positive || 0 | 107 | 0.0000
negative || 0 | 230 | 1.0000
||======================
Precision || 0.0000 | 0.6825 |
OVERALL 0/1 ACCURACY: 0.682493
LOG LOSS/instance: 0.910421
Test-set entropy (prior Log-Loss/instance): 0.901650
LOG-LOSS REDUCTION (RIG): -0.972725
AUC: 0.500000

OVERALL RESULTS
---------------------------------------
AUC: 0.500000 (0.0000)
Accuracy: 0.656163 (0.0263)
Positive precision: 0.000000 (0.0000)
Positive recall: 0.000000 (0.0000)
Negative precision: 0.656163 (0.0263)
Negative recall: 1.000000 (0.0000)
Log-loss: 0.935104 (0.0247)
Log-loss reduction: -0.958964 (0.0138)
F1 Score: NaN (NaN)
AUPRC: 0.418968 (0.0212)

---------------------------------------
Physical memory usage(MB): %Number%
Virtual memory usage(MB): %Number%
%DateTime% Time elapsed(s): %Number%

@@ -0,0 +1,4 @@
PriorPredictor
AUC Accuracy Positive precision Positive recall Negative precision Negative recall Log-loss Log-loss reduction F1 Score AUPRC Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings
0.5 0.656163 0 0 0.656163 1 0.935104 -0.958964 NaN 0.418968 PriorPredictor %Data% %Output% 99 0 0 maml.exe CV tr=PriorPredictor threads=- dout=%Output% data=%Data% seed=1
