|
| 1 | +using Microsoft.ML.Models; |
| 2 | +using Microsoft.ML.Runtime.Data; |
| 3 | +using Microsoft.ML.Runtime.Learners; |
| 4 | +using Xunit; |
| 5 | + |
| 6 | +namespace Microsoft.ML.Tests.Scenarios.Api |
| 7 | +{ |
| 8 | + public partial class ApiScenariosTests |
| 9 | + { |
| 10 | + /// <summary> |
| 11 | + /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier, |
| 12 | + /// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold |
| 13 | + /// and configures the scorer (or more precisely instantiates a new scorer over the same predictor) |
| 14 | + /// with some threshold derived from that. |
| 15 | + /// </summary> |
| 16 | + [Fact] |
| 17 | + void ReconfigurablePrediction() |
| 18 | + { |
| 19 | + var dataPath = GetDataPath(SentimentDataPath); |
| 20 | + var testDataPath = GetDataPath(SentimentTestPath); |
| 21 | + |
| 22 | + using (var env = new TlcEnvironment(seed: 1, conc: 1)) |
| 23 | + { |
| 24 | + // Pipeline |
| 25 | + var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); |
| 26 | + |
| 27 | + var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); |
| 28 | + |
| 29 | + // Train |
| 30 | + var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments |
| 31 | + { |
| 32 | + NumThreads = 1 |
| 33 | + }); |
| 34 | + |
| 35 | + var cached = new CacheDataView(env, trans, prefetch: null); |
| 36 | + var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); |
| 37 | + var predictor = trainer.Train(new Runtime.TrainContext(trainRoles)); |
| 38 | + var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); |
| 39 | + IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema); |
| 40 | + |
| 41 | + var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true); |
| 42 | + |
| 43 | + var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { }); |
| 44 | + var metricsDict = evaluator.Evaluate(dataEval); |
| 45 | + |
| 46 | + var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0]; |
| 47 | + |
| 48 | + var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, null); |
| 49 | + var mapper = bindable.Bind(env, trainRoles.Schema); |
| 50 | + var newScorer = new BinaryClassifierScorer(env, new BinaryClassifierScorer.Arguments { Threshold = 0.01f, ThresholdColumn = DefaultColumnNames.Probability }, |
| 51 | + scoreRoles.Data, mapper, trainRoles.Schema); |
| 52 | + |
| 53 | + dataEval = new RoleMappedData(newScorer, label: "Label", feature: "Features", opt: true); |
| 54 | + var new_evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { Threshold = 0.01f, UseRawScoreThreshold = false }); |
| 55 | + metricsDict = new_evaluator.Evaluate(dataEval); |
| 56 | + var new_metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0]; |
| 57 | + } |
| 58 | + } |
| 59 | + } |
| 60 | +} |
0 commit comments