Skip to content

Conversion of prior and random trainers to estimators #876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 19, 2018
Merged
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Core.Data;

[assembly: LoadableClass(RandomTrainer.Summary, typeof(RandomTrainer), typeof(RandomTrainer.Arguments),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
Expand All @@ -38,38 +39,37 @@ namespace Microsoft.ML.Runtime.Learners
/// <summary>
/// A trainer that trains a predictor that returns random values
/// </summary>
public sealed class RandomTrainer : TrainerBase<RandomPredictor>

public sealed class RandomTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<RandomPredictor>, RandomPredictor>
Copy link
Contributor

@Zruty0 Zruty0 Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TrainerEstimatorBase [](start = 40, length = 20)

It doesn't help you to derive from this base class, and it also hurts: you now have to take label, feature and weight columns, but you don't actually need them at all.

So, don't derive from TrainerEstimatorBase and just implement ITrainerEstimator #Closed

{
internal const string LoadNameValue = "RandomPredictor";
internal const string UserNameValue = "Random Predictor";
internal const string Summary = "A toy predictor that returns a random value.";

public class Arguments
{
// Some sample arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr")]
public Float LearningRate = (Float)1.0;

[Argument(ArgumentType.AtMostOnce, HelpText = "Some bool arg", ShortName = "boolarg")]
public bool BooleanArg = false;
}

public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
public override TrainerInfo Info => _info;

public RandomTrainer(IHostEnvironment env, Arguments args)
Copy link
Contributor

@Zruty0 Zruty0 Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RandomTrainer [](start = 15, length = 13)

You should not remove this constructor. Once you start writing the test you will discover that it is necessary.

Add a second constructor instead. Both here and for Prior #Closed

: base(env, LoadNameValue)
protected override SchemaShape.Column[] OutputColumns => throw new NotImplementedException();

Copy link
Contributor

@Zruty0 Zruty0 Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is obviously not sufficient. You need to list the columns that you are going to output. #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same goes for the other trainer


In reply to: 216508453 [](ancestors = 216508453)

public RandomTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight)
: base(host, feature, label, weight)
{
Host.CheckValue(args, nameof(args));
}

public override RandomPredictor Train(TrainContext context)
protected override RandomPredictor TrainModelCore(TrainContext trainContext)
{
Host.CheckValue(context, nameof(context));
Host.CheckValue(trainContext, nameof(trainContext));
return new RandomPredictor(Host, Host.Rand.Next());
}

protected override BinaryPredictionTransformer<RandomPredictor> MakeTransformer(RandomPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<RandomPredictor>(Host, model, trainSchema, FeatureColumn.Name);
}

/// <summary>
Expand Down Expand Up @@ -196,7 +196,7 @@ private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
}

// Learns the prior distribution for 0/1 class labels and just outputs that.
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// [](start = 4, length = 2)

Make it ///

#Resolved

public sealed class PriorTrainer : TrainerBase<PriorPredictor>
public sealed class PriorTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PriorPredictor>, PriorPredictor>
{
internal const string LoadNameValue = "PriorPredictor";
internal const string UserNameValue = "Prior Predictor";
Expand All @@ -210,13 +210,14 @@ public sealed class Arguments
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
public override TrainerInfo Info => _info;

public PriorTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
protected override SchemaShape.Column[] OutputColumns { get; }

public PriorTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public [](start = 8, length = 6)

Can it be private?
I understand what you use TrainerEstimatorBase which requires feature column, but for this class it's pointless, and exposing it to user would be potentially confusing. #Resolved

: base(host, feature, label, weight)
{
Host.CheckValue(args, nameof(args));
}

public override PriorPredictor Train(TrainContext context)
protected override PriorPredictor TrainModelCore(TrainContext context)
{
Contracts.CheckValue(context, nameof(context));
var data = context.TrainingSet;
Expand Down Expand Up @@ -258,6 +259,10 @@ public override PriorPredictor Train(TrainContext context)
Float prob = prob = pos + neg > 0 ? (Float)(pos / (pos + neg)) : Float.NaN;
return new PriorPredictor(Host, prob);
}

protected override BinaryPredictionTransformer<PriorPredictor> MakeTransformer(PriorPredictor model, ISchema trainSchema)
=> new BinaryPredictionTransformer<PriorPredictor>(Host, model, trainSchema, FeatureColumn.Name);

}

public sealed class PriorPredictor :
Expand Down