|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using Microsoft.ML.Runtime.Data; |
| 6 | +using Microsoft.ML.Runtime.Learners; |
| 7 | +using Microsoft.ML.Runtime.Api; |
| 8 | +using Xunit; |
| 9 | +using System; |
| 10 | +using System.Linq; |
| 11 | + |
| 12 | +namespace Microsoft.ML.Tests.Scenarios.Api |
| 13 | +{ |
public partial class ApiScenariosTests
{
    /// <summary>
    /// Start with a dataset in a text file. Run text featurization on text values.
    /// Train a linear model over that (sentiment classification is the example used here).
    /// Out of the result, produce some structure over which you can get predictions programmatically
    /// (e.g., the prediction does not happen over a file as it did during training).
    /// </summary>
    [Fact]
    public void SimpleTrainAnPredict()
    {
        var dataPath = GetDataPath(SentimentDataPath);
        var testDataPath = GetDataPath(SentimentTestPath);

        // seed and conc are both pinned to 1 (presumably to keep the run reproducible).
        using (var env = new TlcEnvironment(seed: 1, conc: 1))
        {
            // Pipeline: load the training file, then featurize the text column.
            var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

            var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

            // Train
            var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
            {
                NumThreads = 1
            });

            // Training runs over a cached view of the featurized data.
            var cached = new CacheDataView(env, trans, prefetch: null);
            var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
            var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

            // The scorer is bound to the uncached transform, using the training schema.
            var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
            IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

            // Create prediction engine and test predictions.
            var model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(scorer);

            // Take a couple examples out of the test data and run predictions on top.
            var testLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(testDataPath));
            var testData = testLoader.AsEnumerable<SentimentData>(env, false);
            foreach (var input in testData.Take(5))
            {
                var prediction = model.Predict(input);

                // Verify that the predicted label matches the input label and that the
                // score lies beyond +1 for positive inputs or below -1 for negative ones.
                Assert.Equal(input.Sentiment, prediction.Sentiment);
                Assert.True(
                    (input.Sentiment && prediction.Score > 1) ||
                    (!input.Sentiment && prediction.Score < -1));
            }
        }
    }

    /// <summary>
    /// Builds the <see cref="TextTransform"/> arguments used by the sentiment scenario:
    /// the "SentimentText" column is featurized into a "Features" column.
    /// </summary>
    /// <returns>
    /// Arguments with diacritics and punctuation not kept, lower-case normalization,
    /// the predefined stop-words remover, L2 vector normalization, a char n-gram extractor
    /// (NgramLength = 3, AllLengths = false) and a word n-gram extractor
    /// (NgramLength = 2, AllLengths = true). Tokens are also output.
    /// </returns>
    private static TextTransform.Arguments MakeSentimentTextTransformArgs()
    {
        return new TextTransform.Arguments()
        {
            Column = new TextTransform.Column
            {
                Name = "Features",
                Source = new[] { "SentimentText" }
            },
            KeepDiacritics = false,
            KeepPunctuations = false,
            TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
            OutputTokens = true,
            StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
            VectorNormalizer = TextTransform.TextNormKind.L2,
            CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false },
            WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 2, AllLengths = true },
        };
    }

    /// <summary>
    /// Builds the <see cref="TextLoader"/> arguments for the sentiment data files:
    /// separator "tab", with a header row, reading source column 0 as "Label"
    /// (<see cref="DataKind.BL"/>) and source column 1 as "SentimentText"
    /// (<see cref="DataKind.Text"/>).
    /// </summary>
    /// <returns>The loader arguments shared by the training and test loaders.</returns>
    private static TextLoader.Arguments MakeSentimentTextLoaderArgs()
    {
        return new TextLoader.Arguments()
        {
            Separator = "tab",
            HasHeader = true,
            Column = new[]
            {
                new TextLoader.Column()
                {
                    Name = "Label",
                    Source = new[] { new TextLoader.Range() { Min = 0, Max = 0 } },
                    Type = DataKind.BL
                },

                new TextLoader.Column()
                {
                    Name = "SentimentText",
                    Source = new[] { new TextLoader.Range() { Min = 1, Max = 1 } },
                    Type = DataKind.Text
                }
            }
        };
    }
}
| 109 | +} |
0 commit comments