
Commit 4ee52f8

Ivan Matantsev committed

cross validation

1 parent b468056 commit 4ee52f8

File tree

1 file changed: 80 additions, 0 deletions

@@ -0,0 +1,80 @@
using Microsoft.ML.Models;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using System;
using System.Collections.Generic;
using Xunit;

namespace Microsoft.ML.Tests.Scenarios.Api
{
    public partial class ApiScenariosTests
    {
        /// <summary>
        /// Cross-validation: Have a mechanism to do cross validation, that is, you come up with
        /// a data source (optionally with stratification column), come up with an instantiable transform
        /// and trainer pipeline, and it will handle (1) splitting up the data, (2) training the separate
        /// pipelines on in-fold data, (3) scoring on the out-fold data, (4) returning the set of
        /// evaluations and optionally trained pipes. (People always want metrics out of xfold,
        /// they sometimes want the actual models too.)
        /// </summary>
        [Fact]
        void CrossValidation()
        {
            var dataPath = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            int numFolds = 5;
            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline.
                var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
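                // GenerateNumberTransform appends a pseudo-random "StratificationColumn",
                // which the RangeFilters below use to carve the data into train/test folds.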
                var random = new GenerateNumberTransform(env, trans, "StratificationColumn");
                // Train.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                // Auto-caching.
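                // Cache the featurized data only when the trainer reports that it benefits from
                // caching, typically because it makes multiple passes over the training data.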
                IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, random, prefetch: null) : random;
                var metrics = new List<BinaryClassificationMetrics>();
                for (int fold = 0; fold < numFolds; fold++)
                {
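                    // In-fold training data: keep every row whose stratification value lies outside
                    // this fold's slice of the [0, 1] range (Complement = true excludes the slice).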
                    var trainFilter = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column = "StratificationColumn",
                        Min = (Double)fold / numFolds,
                        Max = (Double)(fold + 1) / numFolds,
                        Complement = true
                    }, trainData);

                    // Auto-normalization.
                    var trainRoles = new RoleMappedData(trainFilter, label: "Label", feature: "Features");
                    NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);

                    var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
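                    // Out-of-fold test data: the complementary filter (Complement = false) keeps only
                    // the rows inside this fold's slice, i.e. the rows held out from training.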
                    var testFilter = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column = "StratificationColumn",
                        Min = (Double)fold / numFolds,
                        Max = (Double)(fold + 1) / numFolds,
                        Complement = false
                    }, trainData);
                    // Auto-normalization.
                    var testRoles = new RoleMappedData(testFilter, label: "Label", feature: "Features");
                    NormalizeTransform.CreateIfNeeded(env, ref testRoles, trainer);

                    IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);
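
                    // Evaluate the scored held-out data and convert the overall metrics and confusion
                    // matrix into BinaryClassificationMetrics entries for this fold.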
                    BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });
                    var dataEval = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true);
                    var dict = eval.Evaluate(dataEval);
                    var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]);
                    metrics.AddRange(foldMetrics);
                }
            }
        }
    }
}
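
The loop above only collects one metrics entry per fold. As a minimal follow-up sketch (not part of this commit), assuming the legacy Microsoft.ML.Models.BinaryClassificationMetrics type exposes Auc and Accuracy properties, the per-fold results gathered in the metrics list could be summarized like this:

    // Hypothetical aggregation of the per-fold metrics collected above (not in this commit).
    double avgAuc = 0, avgAccuracy = 0;
    foreach (var m in metrics)
    {
        avgAuc += m.Auc;           // assumed property on BinaryClassificationMetrics
        avgAccuracy += m.Accuracy; // assumed property on BinaryClassificationMetrics
    }
    avgAuc /= metrics.Count;
    avgAccuracy /= metrics.Count;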
