|
| 1 | +using Microsoft.ML.Models; |
| 2 | +using Microsoft.ML.Runtime.Data; |
| 3 | +using Microsoft.ML.Runtime.Learners; |
| 4 | +using System; |
| 5 | +using System.Collections.Generic; |
| 6 | +using Xunit; |
| 7 | + |
| 8 | +namespace Microsoft.ML.Tests.Scenarios.Api |
| 9 | +{ |
| 10 | + public partial class ApiScenariosTests |
| 11 | + { |
| 12 | + /// <summary> |
| 13 | + /// Cross-validation: Have a mechanism to do cross validation, that is, you come up with |
| 14 | + /// a data source (optionally with stratification column), come up with an instantiable transform |
| 15 | + /// and trainer pipeline, and it will handle (1) splitting up the data, (2) training the separate |
| 16 | + /// pipelines on in-fold data, (3) scoring on the out-fold data, (4) returning the set of |
| 17 | + /// evaluations and optionally trained pipes. (People always want metrics out of xfold, |
| 18 | + /// they sometimes want the actual models too.) |
| 19 | + /// </summary> |
| 20 | + [Fact] |
| 21 | + void CrossValidation() |
| 22 | + { |
| 23 | + var dataPath = GetDataPath(SentimentDataPath); |
| 24 | + var testDataPath = GetDataPath(SentimentTestPath); |
| 25 | + |
| 26 | + int numFolds = 5; |
| 27 | + using (var env = new TlcEnvironment(seed: 1, conc: 1)) |
| 28 | + { |
| 29 | + // Pipeline. |
| 30 | + var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); |
| 31 | + |
| 32 | + var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader); |
| 33 | + var random = new GenerateNumberTransform(env, trans, "StratificationColumn"); |
| 34 | + // Train. |
| 35 | + var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments |
| 36 | + { |
| 37 | + NumThreads = 1 |
| 38 | + }); |
| 39 | + |
| 40 | + // Auto-caching. |
| 41 | + IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, random, prefetch: null) : random; |
| 42 | + var metrics = new List<BinaryClassificationMetrics>(); |
| 43 | + for (int fold = 0; fold < numFolds; fold++) |
| 44 | + { |
| 45 | + var trainFilter = new RangeFilter(env, new RangeFilter.Arguments() |
| 46 | + { |
| 47 | + Column = "StratificationColumn", |
| 48 | + Min = (Double)fold / numFolds, |
| 49 | + Max = (Double)(fold + 1) / numFolds, |
| 50 | + Complement = true |
| 51 | + }, trainData); |
| 52 | + |
| 53 | + // Auto-normalization. |
| 54 | + var trainRoles = new RoleMappedData(trainFilter, label: "Label", feature: "Features"); |
| 55 | + NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer); |
| 56 | + |
| 57 | + var predictor = trainer.Train(new Runtime.TrainContext(trainRoles)); |
| 58 | + var testFilter = new RangeFilter(env, new RangeFilter.Arguments() |
| 59 | + { |
| 60 | + Column = "StratificationColumn", |
| 61 | + Min = (Double)fold / numFolds, |
| 62 | + Max = (Double)(fold + 1) / numFolds, |
| 63 | + Complement = false |
| 64 | + }, trainData); |
| 65 | + // Auto-normalization. |
| 66 | + var testRoles = new RoleMappedData(testFilter, label: "Label", feature: "Features"); |
| 67 | + NormalizeTransform.CreateIfNeeded(env, ref testRoles, trainer); |
| 68 | + |
| 69 | + IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema); |
| 70 | + |
| 71 | + BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { }); |
| 72 | + var dataEval = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true); |
| 73 | + var dict = eval.Evaluate(dataEval); |
| 74 | + var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]); |
| 75 | + metrics.AddRange(foldMetrics); |
| 76 | + } |
| 77 | + } |
| 78 | + } |
| 79 | + } |
| 80 | +} |
0 commit comments