-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Configurable Threshold for binary models #2969
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Calibrators; | ||
|
@@ -274,6 +275,40 @@ public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValid | |
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); | ||
} | ||
|
||
/// <summary> | ||
/// Change threshold for binary model. | ||
/// </summary> | ||
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam> | ||
/// <param name="chain">Chain of transformers.</param> | ||
/// <param name="threshold">New threshold.</param> | ||
/// <returns></returns> | ||
public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure if this should be a new function. Could we add a parameter, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, we can add that as parameter to binary trainer. Question is if you train your model, how you gonna change threshold? Retrain model? In reply to: 266062138 [](ancestors = 266062138) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Retrain looks fine to me. I really don't feel adding a helper function is a good idea. This is not a Transformer, so I expect it will become a orphan in the future. Like FFM, PFI and so on don't care about it because it's not a standard binary classifier. In reply to: 266088129 [](ancestors = 266088129,266062138) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Historically we have found that adding options to "all" trainers just invites inconsistency and is a nightmare from a maintainability perspective. For those reasons we no longer do that. So I strongly object to that. There is also the larger, more practical problem that choosing the right threshold is something that you can only really do once you have investigated it -- that is, it is very often a post training operation, not something you do pre-training. This sort of "composable" nature of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As I think about it more, there's something about this idea of getting So for example, certain regressor algorithms are parametric w.r.t. their labels (in fact, most are). But there's a problem with merely normalizing the label, because then the predicted label is according to that same scale. In So far from discouraging this pattern, I think we should do more of it. #Resolved |
||
where TModel : class | ||
{ | ||
if (chain.LastTransformer.Threshold == threshold) | ||
return chain; | ||
List<ITransformer> transformers = new List<ITransformer>(); | ||
var predictionTransformer = chain.LastTransformer; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't like the assumption that the predictor is the last one, it might not be. IMO the only API existing for this should be the second one. If we have to have this API, i think we should minimally take in the index of the predicitonTransformer, in the pipeline, and check whether that transformer is a binaryTransformer. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My idea was to provide helper function for user since work with transform chain is kinda painful, at least from my point. In reply to: 268206172 [](ancestors = 268206172,266034490) |
||
foreach (var transform in chain) | ||
{ | ||
if (transform != predictionTransformer) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Can we change this just a little please? I would prefer that we just add all transforms except the last unconditionally, which would be a fairly easy thing to do. Edit: Actually no @sfilipi is right, I think operating over chains is misguided now that I see her argument... #Resolved |
||
transformers.Add(transform); | ||
} | ||
|
||
transformers.Add(new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model, | ||
predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn, | ||
threshold, predictionTransformer.ThresholdColumn)); | ||
return new TransformerChain<BinaryPredictionTransformer<TModel>>(transformers.ToArray()); | ||
} | ||
|
||
public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. needs documentation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should put XML comments on all public members. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
where TModel : class | ||
{ | ||
if (model.Threshold == threshold) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
do you want to warn here? #WontFix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should provide the same warning that C# does when you have a variable like In reply to: 265862991 [](ancestors = 265862991) |
||
return model; | ||
return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumn, threshold, model.ThresholdColumn); | ||
} | ||
|
||
/// <summary> | ||
/// The list of trainers for performing binary classification. | ||
/// </summary> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,26 @@ | |
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using Microsoft.ML.Functional.Tests.Datasets; | ||
using Microsoft.ML.RunTests; | ||
using Microsoft.ML.TestFramework; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
||
namespace Microsoft.ML.Functional.Tests | ||
{ | ||
public class PredictionScenarios | ||
public class PredictionScenarios : BaseTestClass | ||
{ | ||
public PredictionScenarios(ITestOutputHelper output) : base(output) | ||
{ | ||
} | ||
|
||
class Answer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Prediction or DataWithPrediction #Resolved |
||
{ | ||
public float Score { get; set; } | ||
public bool PredictedLabel { get; set; } | ||
} | ||
/// <summary> | ||
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier, | ||
/// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold | ||
|
@@ -19,36 +31,57 @@ public class PredictionScenarios | |
[Fact] | ||
public void ReconfigurablePrediction() | ||
{ | ||
var mlContext = new MLContext(seed: 789); | ||
|
||
// Get the dataset, create a train and test | ||
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), | ||
hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator) | ||
.Load(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename)); | ||
var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2); | ||
|
||
// Create a pipeline to train on the housing data | ||
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] { | ||
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", | ||
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"}) | ||
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue")) | ||
.Append(mlContext.Regression.Trainers.Ols()); | ||
|
||
var model = pipeline.Fit(split.TrainSet); | ||
|
||
var scoredTest = model.Transform(split.TestSet); | ||
var metrics = mlContext.Regression.Evaluate(scoredTest); | ||
|
||
Common.AssertMetrics(metrics); | ||
|
||
// Todo #2465: Allow the setting of threshold and thresholdColumn for scoring. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thank you! #Resolved |
||
// This is no longer possible in the API | ||
//var newModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability); | ||
//var newScoredTest = newModel.Transform(pipeline.Transform(testData)); | ||
//var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest); | ||
// And the Threshold and ThresholdColumn properties are not settable. | ||
//var predictor = model.LastTransformer; | ||
//predictor.Threshold = 0.01; // Not possible | ||
var mlContext = new MLContext(seed: 1); | ||
|
||
var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we try not to load file everywhere? It will be faster to just use in-memory data. #WontFix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have standard test datasets saved to files that we use in tests. #ByDesign |
||
hasHeader: TestDatasets.Sentiment.fileHasHeader, | ||
separatorChar: TestDatasets.Sentiment.fileSeparator); | ||
|
||
// Create a training pipeline. | ||
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText") | ||
.AppendCacheCheckpoint(mlContext) | ||
.Append(mlContext.BinaryClassification.Trainers.LogisticRegression( | ||
new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 })); | ||
|
||
// Train the model. | ||
var model = pipeline.Fit(data); | ||
var engine = model.CreatePredictionEngine<TweetSentiment, Answer>(mlContext); | ||
var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); | ||
// Score is 0.64 so predicted label is true. | ||
Assert.True(pr.PredictedLabel); | ||
Assert.True(pr.Score > 0); | ||
var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, 0.7f); | ||
var newEngine = newModel.CreatePredictionEngine<TweetSentiment, Answer>(mlContext); | ||
pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); | ||
// Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false. | ||
|
||
Assert.False(pr.PredictedLabel); | ||
Assert.False(pr.Score > 0.7); | ||
} | ||
|
||
[Fact] | ||
public void ReconfigurablePredictionNoPipeline() | ||
{ | ||
var mlContext = new MLContext(seed: 1); | ||
|
||
var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset()); | ||
var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression( | ||
new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }); | ||
var model = pipeline.Fit(data); | ||
var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f); | ||
var rnd = new Random(1); | ||
var randomDataPoint = TypeTestData.GetRandomInstance(rnd); | ||
var engine = model.CreatePredictionEngine<TypeTestData, Answer>(mlContext); | ||
var pr = engine.Predict(randomDataPoint); | ||
// Score is -1.38 so predicted label is false. | ||
Assert.False(pr.PredictedLabel); | ||
Assert.True(pr.Score <= 0); | ||
var newEngine = newModel.CreatePredictionEngine<TypeTestData, Answer>(mlContext); | ||
pr = newEngine.Predict(randomDataPoint); | ||
// Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true. | ||
Assert.True(pr.PredictedLabel); | ||
Assert.True(pr.Score <= 0); | ||
} | ||
|
||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe? #WontFix