Configurable Threshold for binary models #2969


Merged
Changes from 2 commits
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
@@ -53,7 +53,7 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
[BestFriend]
private protected ISchemaBindableMapper BindableMapper;
[BestFriend]
private protected DataViewSchema TrainSchema;
internal DataViewSchema TrainSchema;

/// <summary>
/// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
35 changes: 35 additions & 0 deletions src/Microsoft.ML.Data/TrainCatalog.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Calibrators;
@@ -274,6 +275,40 @@ public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValid
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}

/// <summary>
/// Changes the decision threshold of a binary classification model.
/// </summary>
/// <typeparam name="TModel">An implementation of <see cref="IPredictorProducing{TResult}"/>.</typeparam>
/// <param name="chain">The chain of transformers ending in a binary prediction transformer.</param>
/// <param name="threshold">The new threshold.</param>
/// <returns>A new chain whose last transformer uses the given threshold.</returns>
public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold)
@wschin (Member) commented on Mar 15, 2019:

Suggested change
public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold)
public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeDecisionThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold)

Maybe? #WontFix

@wschin (Member) commented on Mar 15, 2019:

I am not sure if this should be a new function. Could we add a parameter, threshold, to all binary trainers? #Pending

Contributor Author replied:

Ok, we can add that as a parameter to the binary trainers. The question is: if you train your model, how are you going to change the threshold? Retrain the model?
I think this method has a right to live.



A Member replied:

Retraining looks fine to me. I really don't feel adding a helper function is a good idea. This is not a Transformer, so I expect it will become an orphan in the future. Things like FFM, PFI, and so on won't care about it, because it's not a standard binary classifier.



@TomFinley (Contributor) commented on Mar 22, 2019:

I am not sure if this should be a new function. Could we add a parameter, threshold, to all binary trainers? #Pending

Historically we have found that adding options to "all" trainers just invites inconsistency and is a nightmare from a maintainability perspective. For those reasons we no longer do that. So I strongly object to that. There is also the larger, more practical problem that choosing the right threshold is something that you can only really do once you have investigated it -- that is, it is very often a post training operation, not something you do pre-training.

This sort of "composable" nature of IDataView is actually, I think, something we need to reiterate, since it was the key to making our development efforts scale; and that composability is built around having simple, comprehensible units of computation, not big bundled components that try to do everything themselves. We already tried that way, and life was a lot worse and more inconsistent before we had it; reverting to the "old ways" of every conceivable functionality bundled into a single operation would just reintroduce the old problems that led us to move to many simple operators in the first place. #Resolved

@TomFinley (Contributor) commented on Mar 22, 2019:

As I think about it more, there's something about this idea of getting ITransformer implementors from existing ITransformer implementors that I find very appealing. Not just for this (which is a worthy use of this idea), but many other scenarios as well.

So for example, certain regressor algorithms are parametric w.r.t. their labels (in fact, most are). But there's a problem with merely normalizing the label, because then the predicted label is on that same scale. In sklearn you could accomplish this fairly easily via the inverse_transform method on their equivalent of what we call a normalizer, the StandardScaler. So imagine you could get from a NormalizerTransformer another NormalizerTransformer that provides the inverse offset and scaling for any affine normalization, and whatnot. That would be pretty nice, would it not?

So far from discouraging this pattern, I think we should do more of it. #Resolved
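The inverse-normalizer idea above can be illustrated with a small standalone sketch. The AffineNormalizer type and all of its members are hypothetical names invented for this example; they are not part of the ML.NET API.

```csharp
// Sketch of an affine normalizer and a derived inverse, illustrating the
// idea discussed above. Illustrative only -- not an ML.NET type.
public sealed class AffineNormalizer
{
    public float Scale { get; }
    public float Offset { get; }

    public AffineNormalizer(float scale, float offset)
    {
        Scale = scale;
        Offset = offset;
    }

    // Forward mapping: y = (x - offset) * scale.
    public float Transform(float x) => (x - Offset) * Scale;

    // A new normalizer that undoes this one: x = y / scale + offset,
    // expressed in the same (x - offset') * scale' form.
    public AffineNormalizer Inverse() => new AffineNormalizer(1f / Scale, -Offset * Scale);
}
```

Round-tripping a value through Transform and then the inverse's Transform recovers the original input, which is exactly what mapping predictions back to the original label scale requires.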

where TModel : class
{
if (chain.LastTransformer.Threshold == threshold)
return chain;
List<ITransformer> transformers = new List<ITransformer>();
var predictionTransformer = chain.LastTransformer;
@sfilipi (Member) commented on Mar 15, 2019:

chain.LastTransformer

I don't like the assumption that the predictor is the last one; it might not be.

IMO the only API existing for this should be the second one.

If we have to have this API, I think we should minimally take in the index of the predictionTransformer in the pipeline, and check whether that transformer is a binary transformer. #Resolved

A Contributor replied:

That's a good point @sfilipi. I think you're probably right about this.



Contributor Author replied:

My idea was to provide a helper function for users, since working with a transformer chain is kind of painful, at least from my point of view.
But I can keep just one method.



foreach (var transform in chain)
{
if (transform != predictionTransformer)
@TomFinley (Contributor) commented on Mar 22, 2019:

predictionTransformer

Can we change this just a little please? I would prefer that we just add all transforms except the last unconditionally, which would be a fairly easy thing to do.

Edit: Actually no, @sfilipi is right; I think operating over chains is misguided now that I see her argument... #Resolved

transformers.Add(transform);
}

transformers.Add(new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model,
predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn,
threshold, predictionTransformer.ThresholdColumn));
return new TransformerChain<BinaryPredictionTransformer<TModel>>(transformers.ToArray());
}

public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold)
Contributor Author commented:

needs documentation

A Member replied:

We should put XML comments on all public members.

where TModel : class
{
if (model.Threshold == threshold)
@sfilipi (Member) commented on Mar 15, 2019:

if (model.Threshold == threshold)

Do you want to warn here? #WontFix

A Contributor replied:

I think we should provide the same warning that C# does when you have a variable like int a = 5 and then assign 5 to it later.



return model;
return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumn, threshold, model.ThresholdColumn);
}

/// <summary>
/// The list of trainers for performing binary classification.
/// </summary>
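Given the additions above, a call site for the single-transformer overload might look like the following sketch. Only ChangeModelThreshold comes from this diff; trainData, trainer, and their types are illustrative stand-ins.

```csharp
// Sketch: raising the decision threshold of an already-trained binary
// model without retraining. Everything except ChangeModelThreshold is
// an illustrative stand-in.
var mlContext = new MLContext(seed: 1);

// Assume 'trainData' is an IDataView and 'trainer' is some binary trainer,
// so that 'model' is a BinaryPredictionTransformer over the trainer's model type.
var model = trainer.Fit(trainData);

// Returns a new transformer; the original 'model' is left untouched.
var stricterModel = mlContext.BinaryClassification.ChangeModelThreshold(model, 0.7f);
```

Because the method returns a new BinaryPredictionTransformer rather than mutating its input, both thresholds can be evaluated side by side, matching the post-training tuning workflow argued for in the review discussion.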
95 changes: 64 additions & 31 deletions test/Microsoft.ML.Functional.Tests/Prediction.cs
@@ -2,14 +2,26 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Functional.Tests.Datasets;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Functional.Tests
{
public class PredictionScenarios
public class PredictionScenarios : BaseTestClass
{
public PredictionScenarios(ITestOutputHelper output) : base(output)
{
}

class Answer
@sfilipi (Member) commented on Mar 15, 2019:

Answer

Prediction or DataWithPrediction #Resolved

{
public float Score { get; set; }
public bool PredictedLabel { get; set; }
}
/// <summary>
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
/// and through the test evaluator gets a PR curve, then, based on the PR curve, picks a new threshold
@@ -19,36 +31,57 @@ public class PredictionScenarios
[Fact]
public void ReconfigurablePrediction()
{
var mlContext = new MLContext(seed: 789);

// Get the dataset, create a train and test
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(),
hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator)
.Load(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2);

// Create a pipeline to train on the housing data
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
.Append(mlContext.Regression.Trainers.Ols());

var model = pipeline.Fit(split.TrainSet);

var scoredTest = model.Transform(split.TestSet);
var metrics = mlContext.Regression.Evaluate(scoredTest);

Common.AssertMetrics(metrics);

// Todo #2465: Allow the setting of threshold and thresholdColumn for scoring.
@rogancarr (Contributor) commented on Mar 25, 2019:

Todo #2465:

Thank you! #Resolved

// This is no longer possible in the API
//var newModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability);
//var newScoredTest = newModel.Transform(pipeline.Transform(testData));
//var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest);
// And the Threshold and ThresholdColumn properties are not settable.
//var predictor = model.LastTransformer;
//predictor.Threshold = 0.01; // Not possible
var mlContext = new MLContext(seed: 1);

var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename),
@wschin (Member) commented on Mar 15, 2019:

Can we try not to load file everywhere? It will be faster to just use in-memory data. #WontFix

@rogancarr (Contributor) commented on Mar 15, 2019:

We have standard test datasets saved to files that we use in tests. #ByDesign

hasHeader: TestDatasets.Sentiment.fileHasHeader,
separatorChar: TestDatasets.Sentiment.fileSeparator);

// Create a training pipeline.
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));

// Train the model.
var model = pipeline.Fit(data);
var engine = model.CreatePredictionEngine<TweetSentiment, Answer>(mlContext);
var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
// Score is 0.64 so predicted label is true.
Assert.True(pr.PredictedLabel);
Assert.True(pr.Score > 0);
var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, 0.7f);
var newEngine = newModel.CreatePredictionEngine<TweetSentiment, Answer>(mlContext);
pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
// Score is still 0.64, but since the threshold is now 0.7 instead of 0, the predicted label is false.

Assert.False(pr.PredictedLabel);
Assert.False(pr.Score > 0.7);
}

[Fact]
public void ReconfigurablePredictionNoPipeline()
{
var mlContext = new MLContext(seed: 1);

var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression(
new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 });
var model = pipeline.Fit(data);
var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f);
var rnd = new Random(1);
var randomDataPoint = TypeTestData.GetRandomInstance(rnd);
var engine = model.CreatePredictionEngine<TypeTestData, Answer>(mlContext);
var pr = engine.Predict(randomDataPoint);
// Score is -1.38 so predicted label is false.
Assert.False(pr.PredictedLabel);
Assert.True(pr.Score <= 0);
var newEngine = newModel.CreatePredictionEngine<TypeTestData, Answer>(mlContext);
pr = newEngine.Predict(randomDataPoint);
// Score is still -1.38, but since the threshold is now -2 instead of 0, the predicted label is true.
Assert.True(pr.PredictedLabel);
Assert.True(pr.Score <= 0);
}

}
}