Skip to content

Conversion of prior and random trainers to estimators #876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,18 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer

public TModel Model { get; }

public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn = null)
Copy link
Contributor

@Zruty0 Zruty0 Sep 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

= null [](start = 108, length = 7)

Nope, that should still be a required param. It is just that 'null' should now be an acceptable value. #Closed

{
Contracts.CheckValue(host, nameof(host));
Contracts.CheckValueOrNull(featureColumn);
Host = host;
Host.CheckValue(trainSchema, nameof(trainSchema));

Model = model;
FeatureColumn = featureColumn;
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col) && (featureColumn != null))
Copy link
Contributor

@Zruty0 Zruty0 Sep 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

&& [](start = 75, length = 2)

swap the 2 conditions or you get a wrong error :) #Closed

throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
FeatureColumnType = trainSchema.GetColumnType(col);
FeatureColumnType = (featureColumn != null) ? trainSchema.GetColumnType(col) : null;

TrainSchema = trainSchema;
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
Expand Down Expand Up @@ -133,7 +134,7 @@ public sealed class BinaryPredictionTransformer<TModel> : PredictionTransformerB
public readonly string ThresholdColumn;
public readonly float Threshold;

public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn,
public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn = null,
float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Expand Down
143 changes: 92 additions & 51 deletions src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
Expand All @@ -13,6 +11,8 @@
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Core.Data;
using System.Linq;
Copy link
Contributor

@Zruty0 Zruty0 Sep 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

System [](start = 6, length = 6)

sort namespaces #Closed


[assembly: LoadableClass(RandomTrainer.Summary, typeof(RandomTrainer), typeof(RandomTrainer.Arguments),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
Expand All @@ -38,20 +38,16 @@ namespace Microsoft.ML.Runtime.Learners
/// <summary>
/// A trainer that trains a predictor that returns random values
/// </summary>
public sealed class RandomTrainer : TrainerBase<RandomPredictor>

public sealed class RandomTrainer : TrainerBase<RandomPredictor>,
Copy link
Member

@sfilipi sfilipi Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TrainerBase [](start = 40, length = 28)

I think this can go away too. It will eventually, to my understanding. @Zruty0 can confirm. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remove that then I have to implement the ITrainer interface as well, which would basically mean implementing all of TrainerBase.
This is because I am not deriving from TrainerEstimatorBase, which implements ITrainer.


In reply to: 218152453 [](ancestors = 218152453)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is an exception :)


In reply to: 218156771 [](ancestors = 218156771,218152453)

ITrainerEstimator<BinaryPredictionTransformer<RandomPredictor>, RandomPredictor>
{
internal const string LoadNameValue = "RandomPredictor";
internal const string UserNameValue = "Random Predictor";
internal const string Summary = "A toy predictor that returns a random value.";

public class Arguments
public sealed class Arguments
Copy link
Member

@sfilipi sfilipi Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public [](start = 8, length = 6)

as you go through those, try adding XML comments to everything public, as they display on the docs site: docs.microsoft.com #Resolved

{
// Some sample arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr")]
public Float LearningRate = (Float)1.0;

[Argument(ArgumentType.AtMostOnce, HelpText = "Some bool arg", ShortName = "boolarg")]
public bool BooleanArg = false;
}

public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
Expand All @@ -65,20 +61,40 @@ public RandomTrainer(IHostEnvironment env, Arguments args)
Host.CheckValue(args, nameof(args));
}

Copy link
Contributor

@Zruty0 Zruty0 Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is obviously not sufficient. You need to list the columns that you are going to output. #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same goes for the other trainer


In reply to: 216508453 [](ancestors = 216508453)

/// <summary>
/// Trains a <see cref="RandomPredictor"/> on <paramref name="input"/> and wraps it in a
/// <see cref="BinaryPredictionTransformer{TModel}"/>.
/// </summary>
/// <param name="input">The training data view. Must not be null.</param>
/// <returns>
/// A transformer built from the trained predictor and the schema of the training view
/// (the cached view when caching is requested), with no feature column.
/// </returns>
public BinaryPredictionTransformer<RandomPredictor> Fit(IDataView input)
{
    // Validate the argument up front, the same way Train and GetOutputSchema validate theirs,
    // so a null input is reported against this method's parameter name.
    Host.CheckValue(input, nameof(input));

    // Wrap the input in a cache only when the trainer info asks for caching.
    var cachedTrain = Info.WantCaching ? new CacheDataView(Host, input, prefetch: null) : input;

    var trainRoles = new RoleMappedData(cachedTrain);
    var pred = Train(new TrainContext(trainRoles));

    // The feature column is passed as null explicitly.
    return new BinaryPredictionTransformer<RandomPredictor>(Host, pred, cachedTrain.Schema, featureColumn: null);
}

/// <summary>
/// Creates a <see cref="RandomPredictor"/> seeded with the next value drawn from the host's
/// random number generator.
/// </summary>
/// <param name="context">The training context. It is checked for null and not otherwise read here.</param>
/// <returns>A new <see cref="RandomPredictor"/>.</returns>
public override RandomPredictor Train(TrainContext context)
{
    Host.CheckValue(context, nameof(context));

    int seed = Host.Rand.Next();
    var predictor = new RandomPredictor(Host, seed);
    return predictor;
}

/// <summary>
/// Computes the output schema shape: every input column, keyed by name, plus a scalar R4
/// score column that replaces any input column with the same name.
/// </summary>
/// <param name="inputSchema">The shape of the input schema. Must not be null.</param>
/// <returns>A new <see cref="SchemaShape"/> built from the resulting columns.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));

    // Index the incoming columns by name so the score column can be added or overwritten.
    var columnsByName = inputSchema.Columns.ToDictionary(column => column.Name);
    columnsByName[DefaultColumnNames.Score] = new SchemaShape.Column(
        DefaultColumnNames.Score,
        SchemaShape.Column.VectorKind.Scalar,
        NumberType.R4,
        isKey: false);

    var outputShape = new SchemaShape(columnsByName.Values);
    return outputShape;
}
}

/// <summary>
/// The predictor implements the Predict() interface. The predictor returns a
/// uniform random probability and classification assignment.
/// </summary>
public sealed class RandomPredictor :
PredictorBase<Float>,
IDistPredictorProducing<Float, Float>,
PredictorBase<float>,
IDistPredictorProducing<float, float>,
IValueMapperDist,
ICanSaveModel
{
Expand All @@ -96,7 +112,7 @@ private static VersionInfo GetVersionInfo()
// Keep all the serializable state here.
private readonly int _seed;
private readonly object _instanceLock;
private readonly Random _random;
private readonly IRandom _random;

public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
public ColumnType InputType { get; }
Expand All @@ -109,7 +125,7 @@ public RandomPredictor(IHostEnvironment env, int seed)
_seed = seed;

_instanceLock = new object();
_random = new Random(_seed);
_random = RandomUtils.Create(_seed);

InputType = new VectorType(NumberType.Float);
}
Expand All @@ -126,7 +142,9 @@ private RandomPredictor(IHostEnvironment env, ModelLoadContext ctx)
_seed = ctx.Reader.ReadInt32();

_instanceLock = new object();
_random = new Random(_seed);
_random = RandomUtils.Create(_seed);

InputType = new VectorType(NumberType.Float);
}

public static RandomPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
Expand Down Expand Up @@ -154,24 +172,24 @@ protected override void SaveCore(ModelSaveContext ctx)

public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
{
Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
Contracts.Check(typeof(TOut) == typeof(Float));
Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
Contracts.Check(typeof(TOut) == typeof(float));

ValueMapper<VBuffer<Float>, Float> del = Map;
ValueMapper<VBuffer<float>, float> del = Map;
return (ValueMapper<TIn, TOut>)(Delegate)del;
}

public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
{
Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
Contracts.Check(typeof(TOut) == typeof(Float));
Contracts.Check(typeof(TDist) == typeof(Float));
Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
Contracts.Check(typeof(TOut) == typeof(float));
Contracts.Check(typeof(TDist) == typeof(float));

ValueMapper<VBuffer<Float>, Float, Float> del = MapDist;
ValueMapper<VBuffer<float>, float, float> del = MapDist;
return (ValueMapper<TIn, TOut, TDist>)(Delegate)del;
}

private Float PredictCore()
private float PredictCore()
{
// Predict can be called from different threads.
// Ensure your implementation is thread-safe
Expand All @@ -183,20 +201,20 @@ private Float PredictCore()
}
}

private void Map(ref VBuffer<Float> src, ref Float dst)
private void Map(ref VBuffer<float> src, ref float dst)
{
dst = PredictCore();
}

private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
private void MapDist(ref VBuffer<float> src, ref float score, ref float prob)
{
score = PredictCore();
prob = (score + 1) / 2;
}
}

// Learns the prior distribution for 0/1 class labels and just outputs that.
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// [](start = 4, length = 2)

Make it ///

#Resolved

public sealed class PriorTrainer : TrainerBase<PriorPredictor>
public sealed class PriorTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PriorPredictor>, PriorPredictor>
{
internal const string LoadNameValue = "PriorPredictor";
internal const string UserNameValue = "Prior Predictor";
Expand All @@ -210,13 +228,27 @@ public sealed class Arguments
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
public override TrainerInfo Info => _info;

protected override SchemaShape.Column[] OutputColumns { get; }

public PriorTrainer(IHostEnvironment env, Arguments args)
Copy link
Member

@sfilipi sfilipi Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public [](start = 8, length = 6)

private #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should I make the constructor private?


In reply to: 218154524 [](ancestors = 218154524)

: base(env, LoadNameValue)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), MakeFeatureColumn(DefaultColumnNames.Score), MakeFeatureColumn(DefaultColumnNames.Label), null)
{
Copy link
Member

@sfilipi sfilipi Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{ [](start = 8, length = 1)

keep the args check. #Resolved

Host.CheckValue(args, nameof(args));
OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, isKey: false)
};
}

public override PriorPredictor Train(TrainContext context)
public PriorTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public [](start = 8, length = 6)

Can it be private?
I understand that you use TrainerEstimatorBase, which requires a feature column, but for this class it's pointless, and exposing it to the user would be potentially confusing. #Resolved

: base(host, feature, label, weight)
{
OutputColumns = new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, isKey: false)
};
}

protected override PriorPredictor TrainModelCore(TrainContext context)
{
Contracts.CheckValue(context, nameof(context));
var data = context.TrainingSet;
Expand All @@ -234,16 +266,16 @@ public override PriorPredictor Train(TrainContext context)
using (var cursor = data.Data.GetRowCursor(c => c == col || c == colWeight))
{
var getLab = cursor.GetLabelFloatGetter(data);
var getWeight = colWeight >= 0 ? cursor.GetGetter<Float>(colWeight) : null;
Float lab = default(Float);
Float weight = 1;
var getWeight = colWeight >= 0 ? cursor.GetGetter<float>(colWeight) : null;
float lab = default(float);
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float [](start = 36, length = 5)

you can omit it. #Resolved

float weight = 1;
while (cursor.MoveNext())
{
getLab(ref lab);
if (getWeight != null)
{
getWeight(ref weight);
if (!(0 < weight && weight < Float.PositiveInfinity))
if (!(0 < weight && weight < float.PositiveInfinity))
continue;
}

Expand All @@ -255,14 +287,23 @@ public override PriorPredictor Train(TrainContext context)
}
}

Float prob = prob = pos + neg > 0 ? (Float)(pos / (pos + neg)) : Float.NaN;
float prob = prob = pos + neg > 0 ? (float)(pos / (pos + neg)) : float.NaN;
return new PriorPredictor(Host, prob);
}

/// <summary>
/// Wraps a trained <see cref="PriorPredictor"/> in a binary prediction transformer that uses
/// the name of this trainer's feature column.
/// </summary>
/// <param name="model">The trained predictor.</param>
/// <param name="trainSchema">The schema of the training data.</param>
/// <returns>A new <see cref="BinaryPredictionTransformer{TModel}"/> over <paramref name="model"/>.</returns>
protected override BinaryPredictionTransformer<PriorPredictor> MakeTransformer(PriorPredictor model, ISchema trainSchema)
{
    string featureName = FeatureColumn.Name;
    return new BinaryPredictionTransformer<PriorPredictor>(Host, model, trainSchema, featureName);
}

/// <summary>
/// Builds the schema-shape description of a feature column: a non-key vector of R4 values.
/// </summary>
/// <param name="featureColumn">The name to give the column.</param>
/// <returns>The column description.</returns>
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
{
    return new SchemaShape.Column(
        featureColumn,
        SchemaShape.Column.VectorKind.Vector,
        NumberType.R4,
        isKey: false);
}

/// <summary>
/// Builds the schema-shape description of a label column: a non-key scalar of boolean type.
/// </summary>
/// <param name="labelColumn">The name to give the column.</param>
/// <returns>The column description.</returns>
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
    return new SchemaShape.Column(
        labelColumn,
        SchemaShape.Column.VectorKind.Scalar,
        BoolType.Instance,
        isKey: false);
}
}

public sealed class PriorPredictor :
PredictorBase<Float>,
IDistPredictorProducing<Float, Float>,
PredictorBase<float>,
IDistPredictorProducing<float, float>,
IValueMapperDist,
ICanSaveModel
{
Expand All @@ -277,13 +318,13 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}

private readonly Float _prob;
private readonly Float _raw;
private readonly float _prob;
private readonly float _raw;

public PriorPredictor(IHostEnvironment env, Float prob)
public PriorPredictor(IHostEnvironment env, float prob)
: base(env, LoaderSignature)
{
Host.Check(!Float.IsNaN(prob));
Host.Check(!float.IsNaN(prob));

_prob = prob;
_raw = 2 * _prob - 1; // This could be other functions -- logodds for instance
Expand All @@ -298,7 +339,7 @@ private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx)
// Float: _prob

_prob = ctx.Reader.ReadFloat();
Host.CheckDecode(!Float.IsNaN(_prob));
Host.CheckDecode(!float.IsNaN(_prob));

_raw = 2 * _prob - 1;

Expand All @@ -321,7 +362,7 @@ protected override void SaveCore(ModelSaveContext ctx)
// *** Binary format ***
// Float: _prob

Contracts.Assert(!Float.IsNaN(_prob));
Contracts.Assert(!float.IsNaN(_prob));
ctx.Writer.Write(_prob);
}

Expand All @@ -333,29 +374,29 @@ public override PredictionKind PredictionKind

public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
{
Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
Contracts.Check(typeof(TOut) == typeof(Float));
Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
Contracts.Check(typeof(TOut) == typeof(float));

ValueMapper<VBuffer<Float>, Float> del = Map;
ValueMapper<VBuffer<float>, float> del = Map;
return (ValueMapper<TIn, TOut>)(Delegate)del;
}

public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
{
Contracts.Check(typeof(TIn) == typeof(VBuffer<Float>));
Contracts.Check(typeof(TOut) == typeof(Float));
Contracts.Check(typeof(TDist) == typeof(Float));
Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
Contracts.Check(typeof(TOut) == typeof(float));
Contracts.Check(typeof(TDist) == typeof(float));

ValueMapper<VBuffer<Float>, Float, Float> del = MapDist;
ValueMapper<VBuffer<float>, float, float> del = MapDist;
return (ValueMapper<TIn, TOut, TDist>)(Delegate)del;
}

private void Map(ref VBuffer<Float> src, ref Float dst)
private void Map(ref VBuffer<float> src, ref float dst)
{
dst = _raw;
}

private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
private void MapDist(ref VBuffer<float> src, ref float score, ref float prob)
{
score = _raw;
prob = _prob;
Expand Down
Loading