Skip to content

Commit f85cd9c

Browse files
author
Ivan Matantsev
committed
ReconfigurablePrediction
and shorter execution
1 parent fffa4eb commit f85cd9c

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/CrossValidation.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ void CrossValidation()
3434
// Train.
3535
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
3636
{
37-
NumThreads = 1
37+
NumThreads = 1,
38+
ConvergenceTolerance = 1f
3839
});
3940

4041
// Auto-caching.

test/Microsoft.ML.Tests/Scenarios/Api/Evaluation.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
using Microsoft.ML.Runtime.Data;
22
using Microsoft.ML.Runtime.Learners;
3-
using Microsoft.ML.Runtime.Api;
4-
using System;
5-
using System.Collections.Generic;
6-
using System.Text;
73
using Xunit;
84
using Microsoft.ML.Models;
95

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using Microsoft.ML.Models;
2+
using Microsoft.ML.Runtime.Data;
3+
using Microsoft.ML.Runtime.Learners;
4+
using Xunit;
5+
6+
namespace Microsoft.ML.Tests.Scenarios.Api
7+
{
8+
public partial class ApiScenariosTests
9+
{
10+
/// <summary>
11+
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
12+
/// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold
13+
/// and configures the scorer (or more precisely instantiates a new scorer over the same predictor)
14+
/// with some threshold derived from that.
15+
/// </summary>
16+
[Fact]
17+
void ReconfigurablePrediction()
18+
{
19+
var dataPath = GetDataPath(SentimentDataPath);
20+
var testDataPath = GetDataPath(SentimentTestPath);
21+
22+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
23+
{
24+
// Pipeline
25+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
26+
27+
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
28+
29+
// Train
30+
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
31+
{
32+
NumThreads = 1
33+
});
34+
35+
var cached = new CacheDataView(env, trans, prefetch: null);
36+
var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
37+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
38+
var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
39+
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
40+
41+
var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true);
42+
43+
var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });
44+
var metricsDict = evaluator.Evaluate(dataEval);
45+
46+
var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
47+
48+
var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, null);
49+
var mapper = bindable.Bind(env, trainRoles.Schema);
50+
var newScorer = new BinaryClassifierScorer(env, new BinaryClassifierScorer.Arguments { Threshold = 0.01f, ThresholdColumn = DefaultColumnNames.Probability },
51+
scoreRoles.Data, mapper, trainRoles.Schema);
52+
53+
dataEval = new RoleMappedData(newScorer, label: "Label", feature: "Features", opt: true);
54+
var new_evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { Threshold = 0.01f, UseRawScoreThreshold = false });
55+
metricsDict = new_evaluator.Evaluate(dataEval);
56+
var new_metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
57+
}
58+
}
59+
}
60+
}

0 commit comments

Comments
 (0)