Skip to content

Commit 6299a1f

Browse files
author
Pete Luferenko
committed
Merge branch 'feature/api-examples' into feature/estimators
2 parents 6aeb7cc + 20e59a2 commit 6299a1f

File tree

3 files changed

+215
-0
lines changed

3 files changed

+215
-0
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using Microsoft.ML.Runtime.Api;
2+
using Microsoft.ML.Runtime.Data;
3+
using Microsoft.ML.TestFramework;
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Text;
7+
using Xunit.Abstractions;
8+
9+
namespace Microsoft.ML.Tests.Scenarios.Api
10+
{
11+
/// <summary>
12+
/// Common utility functions for API scenarios tests.
13+
/// </summary>
14+
public partial class ApiScenariosTests : BaseTestClass
15+
{
16+
public ApiScenariosTests(ITestOutputHelper output) : base(output)
17+
{
18+
}
19+
20+
public const string SentimentDataPath = "wikipedia-detox-250-line-data.tsv";
21+
public const string SentimentTestPath = "wikipedia-detox-250-line-test.tsv";
22+
23+
public class SentimentData
24+
{
25+
[ColumnName("Label")]
26+
public bool Sentiment;
27+
public string SentimentText;
28+
}
29+
30+
public class SentimentPrediction
31+
{
32+
[ColumnName("PredictedLabel")]
33+
public bool Sentiment;
34+
35+
public float Score;
36+
}
37+
}
38+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Data;
6+
using Microsoft.ML.Runtime.Learners;
7+
using Microsoft.ML.Runtime.Api;
8+
using Xunit;
9+
using System;
10+
using System.Linq;
11+
12+
namespace Microsoft.ML.Tests.Scenarios.Api
13+
{
14+
public partial class ApiScenariosTests
15+
{
16+
/// <summary>
17+
/// Start with a dataset in a text file. Run text featurization on text values.
18+
/// Train a linear model over that. (I am thinking sentiment classification.)
19+
/// Out of the result, produce some structure over which you can get predictions programmatically
20+
/// (e.g., the prediction does not happen over a file as it did during training).
21+
/// </summary>
22+
[Fact]
23+
public void SimpleTrainAnPredict()
24+
{
25+
var dataPath = GetDataPath(SentimentDataPath);
26+
var testDataPath = GetDataPath(SentimentTestPath);
27+
28+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
29+
{
30+
// Pipeline
31+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
32+
33+
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
34+
35+
// Train
36+
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
37+
{
38+
NumThreads = 1
39+
});
40+
41+
var cached = new CacheDataView(env, trans, prefetch: null);
42+
var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
43+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
44+
45+
var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
46+
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
47+
48+
// Create prediction engine and test predictions.
49+
var model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(scorer);
50+
51+
// Take a couple examples out of the test data and run predictions on top.
52+
var testLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
53+
var testData = testLoader.AsEnumerable<SentimentData>(env, false);
54+
foreach (var input in testData.Take(5))
55+
{
56+
var prediction = model.Predict(input);
57+
// Verify that predictions match and scores are separated from zero.
58+
Assert.Equal(input.Sentiment, prediction.Sentiment);
59+
Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
60+
}
61+
}
62+
}
63+
64+
private static TextTransform.Arguments MakeSentimentTextTransformArgs()
65+
{
66+
return new TextTransform.Arguments()
67+
{
68+
Column = new TextTransform.Column
69+
{
70+
Name = "Features",
71+
Source = new[] { "SentimentText" }
72+
},
73+
KeepDiacritics = false,
74+
KeepPunctuations = false,
75+
TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
76+
OutputTokens = true,
77+
StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
78+
VectorNormalizer = TextTransform.TextNormKind.L2,
79+
CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false },
80+
WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 2, AllLengths = true },
81+
};
82+
}
83+
84+
private static TextLoader.Arguments MakeSentimentTextLoaderArgs()
85+
{
86+
return new TextLoader.Arguments()
87+
{
88+
Separator = "tab",
89+
HasHeader = true,
90+
Column = new[]
91+
{
92+
new TextLoader.Column()
93+
{
94+
Name = "Label",
95+
Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
96+
Type = DataKind.BL
97+
},
98+
99+
new TextLoader.Column()
100+
{
101+
Name = "SentimentText",
102+
Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
103+
Type = DataKind.Text
104+
}
105+
}
106+
};
107+
}
108+
}
109+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using Microsoft.ML.Runtime.Api;
2+
using Microsoft.ML.Runtime.Data;
3+
using Microsoft.ML.Runtime.Learners;
4+
using Microsoft.ML.Runtime.Model;
5+
using System.Linq;
6+
using Xunit;
7+
8+
namespace Microsoft.ML.Tests.Scenarios.Api
9+
{
10+
public partial class ApiScenariosTests
11+
{
12+
/// <summary>
13+
/// Train, save/load model, predict:
14+
/// Serve the scenario where training and prediction happen in different processes (or even different machines).
15+
/// The actual test will not run in different processes, but will simulate the idea that the
16+
/// "communication pipe" is just a serialized model of some form.
17+
/// </summary>
18+
[Fact]
19+
public void TrainSaveModelAndPredict()
20+
{
21+
var dataPath = GetDataPath(SentimentDataPath);
22+
var testDataPath = GetDataPath(SentimentTestPath);
23+
24+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
25+
{
26+
// Pipeline
27+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
28+
29+
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
30+
31+
// Train
32+
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
33+
{
34+
NumThreads = 1
35+
});
36+
37+
var cached = new CacheDataView(env, trans, prefetch: null);
38+
var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
39+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
40+
41+
PredictionEngine<SentimentData, SentimentPrediction> model;
42+
using (var file = env.CreateTempFile())
43+
{
44+
// Save model.
45+
var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
46+
using (var ch = env.Start("saving"))
47+
TrainUtils.SaveModel(env, ch, file, predictor, scoreRoles);
48+
49+
// Load model.
50+
using (var fs = file.OpenReadStream())
51+
model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(fs);
52+
}
53+
54+
// Take a couple examples out of the test data and run predictions on top.
55+
var testLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
56+
var testData = testLoader.AsEnumerable<SentimentData>(env, false);
57+
foreach (var input in testData.Take(5))
58+
{
59+
var prediction = model.Predict(input);
60+
// Verify that predictions match and scores are separated from zero.
61+
Assert.Equal(input.Sentiment, prediction.Sentiment);
62+
Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
63+
}
64+
}
65+
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)