Skip to content

Commit 0b638bf

Browse files
authored
Configurable Threshold for binary models (#2969)
1 parent d38a35e commit 0b638bf

File tree

3 files changed

+84
-32
lines changed

3 files changed

+84
-32
lines changed

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
5252
[BestFriend]
5353
private protected ISchemaBindableMapper BindableMapper;
5454
[BestFriend]
55-
private protected DataViewSchema TrainSchema;
55+
internal DataViewSchema TrainSchema;
5656

5757
/// <summary>
5858
/// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an

src/Microsoft.ML.Data/TrainCatalog.cs

+8
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,14 @@ public IReadOnlyList<CrossValidationResult<CalibratedBinaryClassificationMetrics
256256
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
257257
}
258258

259+
/// <summary>
/// Returns a binary prediction transformer that uses <paramref name="threshold"/> in place of the
/// threshold currently set on <paramref name="model"/>.
/// </summary>
/// <typeparam name="TModel">Type of the model wrapped by the transformer; constrained to reference types.</typeparam>
/// <param name="model">Transformer whose model, training schema, feature column name and threshold column are reused.</param>
/// <param name="threshold">Threshold value passed to the returned transformer.</param>
/// <returns>
/// <paramref name="model"/> itself when its threshold already equals <paramref name="threshold"/>;
/// otherwise a new <see cref="BinaryPredictionTransformer{TModel}"/> constructed from the same parts.
/// </returns>
public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold)
260+
where TModel : class
261+
{
262+
// Nothing to change: hand back the same instance rather than building a copy.
if (model.Threshold == threshold)
263+
return model;
264+
return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn);
265+
}
266+
259267
/// <summary>
260268
/// The list of trainers for performing binary classification.
261269
/// </summary>

test/Microsoft.ML.Functional.Tests/Prediction.cs

+75-31
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,30 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
6+
using System.Collections.Generic;
7+
using Microsoft.ML.Calibrators;
8+
using Microsoft.ML.Data;
9+
using Microsoft.ML.Functional.Tests.Datasets;
510
using Microsoft.ML.RunTests;
611
using Microsoft.ML.TestFramework;
12+
using Microsoft.ML.Trainers;
713
using Xunit;
14+
using Xunit.Abstractions;
815

916
namespace Microsoft.ML.Functional.Tests
1017
{
11-
public class PredictionScenarios
18+
public class PredictionScenarios : BaseTestClass
1219
{
20+
/// <summary>
/// Creates the test class, forwarding the test output helper to <see cref="BaseTestClass"/>.
/// </summary>
public PredictionScenarios(ITestOutputHelper output) : base(output)
21+
{
22+
}
23+
24+
/// <summary>
/// Output type passed to <c>CreatePredictionEngine</c> by the tests in this class;
/// exposes the score and the predicted label of a single prediction.
/// </summary>
class Prediction
25+
{
26+
public float Score { get; set; }
27+
public bool PredictedLabel { get; set; }
28+
}
1329
/// <summary>
1430
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
1531
/// and through the test evaluator gets a PR curve, then, based on the PR curve, picks a new threshold
@@ -19,36 +35,64 @@ public class PredictionScenarios
1935
[Fact]
2036
public void ReconfigurablePrediction()
2137
{
22-
var mlContext = new MLContext(seed: 789);
23-
24-
// Get the dataset, create a train and test
25-
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(),
26-
hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator)
27-
.Load(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
28-
var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2);
29-
30-
// Create a pipeline to train on the housing data
31-
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
32-
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
33-
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
34-
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
35-
.Append(mlContext.Regression.Trainers.Ols());
36-
37-
var model = pipeline.Fit(split.TrainSet);
38-
39-
var scoredTest = model.Transform(split.TestSet);
40-
var metrics = mlContext.Regression.Evaluate(scoredTest);
41-
42-
Common.AssertMetrics(metrics);
43-
44-
// Todo #2465: Allow the setting of threshold and thresholdColumn for scoring.
45-
// This is no longer possible in the API
46-
//var newModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability);
47-
//var newScoredTest = newModel.Transform(pipeline.Transform(testData));
48-
//var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest);
49-
// And the Threshold and ThresholdColumn properties are not settable.
50-
//var predictor = model.LastTransformer;
51-
//predictor.Threshold = 0.01; // Not possible
38+
var mlContext = new MLContext(seed: 1);
39+
40+
var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename),
41+
hasHeader: TestDatasets.Sentiment.fileHasHeader,
42+
separatorChar: TestDatasets.Sentiment.fileSeparator);
43+
44+
// Create a training pipeline.
45+
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
46+
.AppendCacheCheckpoint(mlContext)
47+
.Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
48+
new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }));
49+
50+
// Train the model.
51+
var model = pipeline.Fit(data);
52+
var engine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(model);
53+
var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
54+
// Score is 0.64 so predicted label is true.
55+
Assert.True(pr.PredictedLabel);
56+
Assert.True(pr.Score > 0);
57+
var transformers = new List<ITransformer>();
58+
foreach (var transform in model)
59+
{
60+
if (transform != model.LastTransformer)
61+
transformers.Add(transform);
62+
}
63+
transformers.Add(mlContext.BinaryClassification.ChangeModelThreshold(model.LastTransformer, 0.7f));
64+
var newModel = new TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>(transformers.ToArray());
65+
var newEngine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(newModel);
66+
pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
67+
// Score is still 0.64, but since the threshold is now 0.7 instead of 0, the predicted label is false.
68+
69+
Assert.False(pr.PredictedLabel);
70+
Assert.False(pr.Score > 0.7);
5271
}
72+
73+
/// <summary>
/// Trains a logistic regression binary classifier directly (no transformer chain), then uses
/// <c>ChangeModelThreshold</c> to obtain a second model with threshold -2 and checks that the same
/// data point gets a different predicted label from the two models while its score stays non-positive.
/// </summary>
[Fact]
74+
public void ReconfigurablePredictionNoPipeline()
75+
{
76+
var mlContext = new MLContext(seed: 1);
77+
78+
var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
79+
var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression(
80+
new Trainers.LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });
81+
var model = pipeline.Fit(data);
82+
var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f);
83+
// Fixed seed so the same data point is generated on every run.
var rnd = new Random(1);
84+
var randomDataPoint = TypeTestData.GetRandomInstance(rnd);
85+
var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model);
86+
var pr = engine.Predict(randomDataPoint);
87+
// Score is -1.38, so the predicted label is false.
88+
Assert.False(pr.PredictedLabel);
89+
Assert.True(pr.Score <= 0);
90+
var newEngine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(newModel);
91+
pr = newEngine.Predict(randomDataPoint);
92+
// Score is still -1.38, but since the threshold is now -2 instead of 0, the predicted label is true.
93+
Assert.True(pr.PredictedLabel);
94+
Assert.True(pr.Score <= 0);
95+
}
96+
5397
}
5498
}

0 commit comments

Comments
 (0)