2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
+ using System ;
6
+ using System . Collections . Generic ;
7
+ using Microsoft . ML . Calibrators ;
8
+ using Microsoft . ML . Data ;
9
+ using Microsoft . ML . Functional . Tests . Datasets ;
5
10
using Microsoft . ML . RunTests ;
6
11
using Microsoft . ML . TestFramework ;
12
+ using Microsoft . ML . Trainers ;
7
13
using Xunit ;
14
+ using Xunit . Abstractions ;
8
15
9
16
namespace Microsoft . ML . Functional . Tests
10
17
{
11
- public class PredictionScenarios
18
+ public class PredictionScenarios : BaseTestClass
12
19
{
20
+ public PredictionScenarios ( ITestOutputHelper output ) : base ( output )
21
+ {
22
+ }
23
+
24
+ class Prediction
25
+ {
26
+ public float Score { get ; set ; }
27
+ public bool PredictedLabel { get ; set ; }
28
+ }
13
29
/// <summary>
14
30
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
15
31
/// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold
@@ -19,36 +35,64 @@ public class PredictionScenarios
19
35
[ Fact ]
20
36
public void ReconfigurablePrediction ( )
21
37
{
22
- var mlContext = new MLContext ( seed : 789 ) ;
23
-
24
- // Get the dataset, create a train and test
25
- var data = mlContext . Data . CreateTextLoader ( TestDatasets . housing . GetLoaderColumns ( ) ,
26
- hasHeader : TestDatasets . housing . fileHasHeader , separatorChar : TestDatasets . housing . fileSeparator )
27
- . Load ( BaseTestClass . GetDataPath ( TestDatasets . housing . trainFilename ) ) ;
28
- var split = mlContext . Data . TrainTestSplit ( data , testFraction : 0.2 ) ;
29
-
30
- // Create a pipeline to train on the housing data
31
- var pipeline = mlContext . Transforms . Concatenate ( "Features" , new string [ ] {
32
- "CrimesPerCapita" , "PercentResidental" , "PercentNonRetail" , "CharlesRiver" , "NitricOxides" , "RoomsPerDwelling" ,
33
- "PercentPre40s" , "EmploymentDistance" , "HighwayDistance" , "TaxRate" , "TeacherRatio" } )
34
- . Append ( mlContext . Transforms . CopyColumns ( "Label" , "MedianHomeValue" ) )
35
- . Append ( mlContext . Regression . Trainers . Ols ( ) ) ;
36
-
37
- var model = pipeline . Fit ( split . TrainSet ) ;
38
-
39
- var scoredTest = model . Transform ( split . TestSet ) ;
40
- var metrics = mlContext . Regression . Evaluate ( scoredTest ) ;
41
-
42
- Common . AssertMetrics ( metrics ) ;
43
-
44
- // Todo #2465: Allow the setting of threshold and thresholdColumn for scoring.
45
- // This is no longer possible in the API
46
- //var newModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability);
47
- //var newScoredTest = newModel.Transform(pipeline.Transform(testData));
48
- //var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest);
49
- // And the Threshold and ThresholdColumn properties are not settable.
50
- //var predictor = model.LastTransformer;
51
- //predictor.Threshold = 0.01; // Not possible
38
+ var mlContext = new MLContext ( seed : 1 ) ;
39
+
40
+ var data = mlContext . Data . LoadFromTextFile < TweetSentiment > ( GetDataPath ( TestDatasets . Sentiment . trainFilename ) ,
41
+ hasHeader : TestDatasets . Sentiment . fileHasHeader ,
42
+ separatorChar : TestDatasets . Sentiment . fileSeparator ) ;
43
+
44
+ // Create a training pipeline.
45
+ var pipeline = mlContext . Transforms . Text . FeaturizeText ( "Features" , "SentimentText" )
46
+ . AppendCacheCheckpoint ( mlContext )
47
+ . Append ( mlContext . BinaryClassification . Trainers . LogisticRegression (
48
+ new LogisticRegressionBinaryTrainer . Options { NumberOfThreads = 1 } ) ) ;
49
+
50
+ // Train the model.
51
+ var model = pipeline . Fit ( data ) ;
52
+ var engine = mlContext . Model . CreatePredictionEngine < TweetSentiment , Prediction > ( model ) ;
53
+ var pr = engine . Predict ( new TweetSentiment ( ) { SentimentText = "Good Bad job" } ) ;
54
+ // Score is 0.64 so predicted label is true.
55
+ Assert . True ( pr . PredictedLabel ) ;
56
+ Assert . True ( pr . Score > 0 ) ;
57
+ var transformers = new List < ITransformer > ( ) ;
58
+ foreach ( var transform in model )
59
+ {
60
+ if ( transform != model . LastTransformer )
61
+ transformers . Add ( transform ) ;
62
+ }
63
+ transformers . Add ( mlContext . BinaryClassification . ChangeModelThreshold ( model . LastTransformer , 0.7f ) ) ;
64
+ var newModel = new TransformerChain < BinaryPredictionTransformer < CalibratedModelParametersBase < LinearBinaryModelParameters , PlattCalibrator > > > ( transformers . ToArray ( ) ) ;
65
+ var newEngine = mlContext . Model . CreatePredictionEngine < TweetSentiment , Prediction > ( newModel ) ;
66
+ pr = newEngine . Predict ( new TweetSentiment ( ) { SentimentText = "Good Bad job" } ) ;
67
+ // Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false.
68
+
69
+ Assert . False ( pr . PredictedLabel ) ;
70
+ Assert . False ( pr . Score > 0.7 ) ;
52
71
}
72
+
73
+ [ Fact ]
74
+ public void ReconfigurablePredictionNoPipeline ( )
75
+ {
76
+ var mlContext = new MLContext ( seed : 1 ) ;
77
+
78
+ var data = mlContext . Data . LoadFromEnumerable ( TypeTestData . GenerateDataset ( ) ) ;
79
+ var pipeline = mlContext . BinaryClassification . Trainers . LogisticRegression (
80
+ new Trainers . LogisticRegressionBinaryTrainer . Options { NumberOfThreads = 1 } ) ;
81
+ var model = pipeline . Fit ( data ) ;
82
+ var newModel = mlContext . BinaryClassification . ChangeModelThreshold ( model , - 2.0f ) ;
83
+ var rnd = new Random ( 1 ) ;
84
+ var randomDataPoint = TypeTestData . GetRandomInstance ( rnd ) ;
85
+ var engine = mlContext . Model . CreatePredictionEngine < TypeTestData , Prediction > ( model ) ;
86
+ var pr = engine . Predict ( randomDataPoint ) ;
87
+ // Score is -1.38 so predicted label is false.
88
+ Assert . False ( pr . PredictedLabel ) ;
89
+ Assert . True ( pr . Score <= 0 ) ;
90
+ var newEngine = mlContext . Model . CreatePredictionEngine < TypeTestData , Prediction > ( newModel ) ;
91
+ pr = newEngine . Predict ( randomDataPoint ) ;
92
+ // Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true.
93
+ Assert . True ( pr . PredictedLabel ) ;
94
+ Assert . True ( pr . Score <= 0 ) ;
95
+ }
96
+
53
97
}
54
98
}
0 commit comments