diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index c4af3170ad..b4467b9f8b 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -52,7 +52,7 @@ public abstract class PredictionTransformerBase : IPredictionTransformer [BestFriend] private protected ISchemaBindableMapper BindableMapper; [BestFriend] - private protected DataViewSchema TrainSchema; + internal DataViewSchema TrainSchema; /// /// Whether a call to should succeed, on an diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index da615d7e32..944f3dd9e5 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -256,6 +256,14 @@ public IReadOnlyList ChangeModelThreshold(BinaryPredictionTransformer model, float threshold) + where TModel : class + { + if (model.Threshold == threshold) + return model; + return new BinaryPredictionTransformer(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn); + } + /// /// The list of trainers for performing binary classification. /// diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs index 4605f953bd..627e06e775 100644 --- a/test/Microsoft.ML.Functional.Tests/Prediction.cs +++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs @@ -2,14 +2,30 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.Collections.Generic; +using Microsoft.ML.Calibrators; +using Microsoft.ML.Data; +using Microsoft.ML.Functional.Tests.Datasets; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; +using Microsoft.ML.Trainers; using Xunit; +using Xunit.Abstractions; namespace Microsoft.ML.Functional.Tests { - public class PredictionScenarios + public class PredictionScenarios : BaseTestClass { + public PredictionScenarios(ITestOutputHelper output) : base(output) + { + } + + class Prediction + { + public float Score { get; set; } + public bool PredictedLabel { get; set; } + } /// /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier, /// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold @@ -19,36 +35,64 @@ public class PredictionScenarios [Fact] public void ReconfigurablePrediction() { - var mlContext = new MLContext(seed: 789); - - // Get the dataset, create a train and test - var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), - hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator) - .Load(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename)); - var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2); - - // Create a pipeline to train on the housing data - var pipeline = mlContext.Transforms.Concatenate("Features", new string[] { - "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", - "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"}) - .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue")) - .Append(mlContext.Regression.Trainers.Ols()); - - var model = pipeline.Fit(split.TrainSet); - - var scoredTest = model.Transform(split.TestSet); - var metrics = mlContext.Regression.Evaluate(scoredTest); - - Common.AssertMetrics(metrics); - - // Todo #2465: Allow the setting of threshold and thresholdColumn for scoring. - // This is no longer possible in the API - //var newModel = new BinaryPredictionTransformer>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability); - //var newScoredTest = newModel.Transform(pipeline.Transform(testData)); - //var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest); - // And the Threshold and ThresholdColumn properties are not settable. - //var predictor = model.LastTransformer; - //predictor.Threshold = 0.01; // Not possible + var mlContext = new MLContext(seed: 1); + + var data = mlContext.Data.LoadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), + hasHeader: TestDatasets.Sentiment.fileHasHeader, + separatorChar: TestDatasets.Sentiment.fileSeparator); + + // Create a training pipeline. + var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText") + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.BinaryClassification.Trainers.LogisticRegression( + new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 })); + + // Train the model. + var model = pipeline.Fit(data); + var engine = mlContext.Model.CreatePredictionEngine(model); + var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); + // Score is 0.64 so predicted label is true. + Assert.True(pr.PredictedLabel); + Assert.True(pr.Score > 0); + var transformers = new List(); + foreach (var transform in model) + { + if (transform != model.LastTransformer) + transformers.Add(transform); + } + transformers.Add(mlContext.BinaryClassification.ChangeModelThreshold(model.LastTransformer, 0.7f)); + var newModel = new TransformerChain>>(transformers.ToArray()); + var newEngine = mlContext.Model.CreatePredictionEngine(newModel); + pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); + // Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false. + + Assert.False(pr.PredictedLabel); + Assert.False(pr.Score > 0.7); } + + [Fact] + public void ReconfigurablePredictionNoPipeline() + { + var mlContext = new MLContext(seed: 1); + + var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset()); + var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression( + new Trainers.LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }); + var model = pipeline.Fit(data); + var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f); + var rnd = new Random(1); + var randomDataPoint = TypeTestData.GetRandomInstance(rnd); + var engine = mlContext.Model.CreatePredictionEngine(model); + var pr = engine.Predict(randomDataPoint); + // Score is -1.38 so predicted label is false. + Assert.False(pr.PredictedLabel); + Assert.True(pr.Score <= 0); + var newEngine = mlContext.Model.CreatePredictionEngine(newModel); + pr = newEngine.Predict(randomDataPoint); + // Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true. + Assert.True(pr.PredictedLabel); + Assert.True(pr.Score <= 0); + } + } }