Skip to content

Commit 70d3fb4

Browse files
author
Pete Luferenko
committed
Adding 'aspirational examples' and some more baseline scenarios
1 parent 1999de8 commit 70d3fb4

File tree

5 files changed

+145
-1
lines changed

5 files changed

+145
-1
lines changed

test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
<PropertyGroup>
33
<TargetFramework>netcoreapp2.0</TargetFramework>
44
</PropertyGroup>
5+
<ItemGroup>
6+
<Compile Remove="Scenarios\Api\AspirationalExamples.cs" />
7+
</ItemGroup>
58

69
<ItemGroup>
710
<ProjectReference Include="..\..\src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
@@ -23,4 +26,8 @@
2326
<NativeAssemblyReference Include="FastTreeNative" />
2427
<NativeAssemblyReference Include="LdaNative" />
2528
</ItemGroup>
29+
30+
<ItemGroup>
31+
<None Include="Scenarios\Api\AspirationalExamples.cs" />
32+
</ItemGroup>
2633
</Project>
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Microsoft.ML.Tests.Scenarios.Api
6+
{
7+
public class AspirationalExamples
8+
{
9+
public class IrisPrediction
10+
{
11+
public string PredictedLabel;
12+
}
13+
14+
public class IrisData
15+
{
16+
public float[] Features;
17+
}
18+
19+
public void FirstExperienceWithML()
20+
{
21+
// This is the 'getting started with ML' example, how we see it in our new API.
22+
// It currently doesn't compile, let alone work, but we still can discuss and improve the syntax.
23+
24+
// Load the data into the system.
25+
string dataPath = "iris-data.txt";
26+
var data = TextReader.ReadFile(dataPath, c => (Label: c.LoadString(0), Features: c.LoadFloat(1, 4)));
27+
28+
// Assign numeric values to text in the "Label" column, because only
29+
// numbers can be processed during model training.
30+
var transformer = data.MakeTransformer(row => (Label: row.Label.Dictionarize(), row.Features));
31+
var trainingData = transformer.Transform(data);
32+
33+
// Train a multiclass linear classifier.
34+
var learner = new StochasticDualCoordinateAscentClassifier();
35+
var classifier = learner.Train(trainingData);
36+
37+
// Obtain some predictions.
38+
var predictionEngine = new PredictionEngine<float[], string>(classifier, inputColumn: "Features", outputColumn: "PredictedLabel");
39+
string prediction = predictionEngine.Predict(new[] { 3.3f, 1.6f, 0.2f, 5.1f });
40+
}
41+
}
42+
}

test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public partial class ApiScenariosTests
2020
/// (e.g., the prediction does not happen over a file as it did during training).
2121
/// </summary>
2222
[Fact]
23-
public void SimpleTrainAnPredict()
23+
public void SimpleTrainAndPredict()
2424
{
2525
var dataPath = GetDataPath(SentimentDataPath);
2626
var testDataPath = GetDataPath(SentimentTestPath);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using Microsoft.ML.Runtime;
2+
using Microsoft.ML.Runtime.Data;
3+
using Microsoft.ML.Runtime.FastTree;
4+
using Microsoft.ML.Runtime.Learners;
5+
using Xunit;
6+
7+
namespace Microsoft.ML.Tests.Scenarios.Api
8+
{
9+
public partial class ApiScenariosTests
10+
{
11+
/// <summary>
12+
/// Train with initial predictor: Similar to the simple train scenario, but also accept a pre-trained initial model.
13+
/// The scenario might be one of the online linear learners that can take advantage of this, e.g., averaged perceptron.
14+
/// </summary>
15+
[Fact]
16+
public void TrainWithInitialPredictor()
17+
{
18+
var dataPath = GetDataPath(SentimentDataPath);
19+
20+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
21+
{
22+
// Pipeline
23+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
24+
25+
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
26+
var trainData = trans;
27+
28+
var cachedTrain = new CacheDataView(env, trainData, prefetch: null);
29+
// Train the first predictor.
30+
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
31+
{
32+
NumThreads = 1
33+
});
34+
var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
35+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
36+
37+
// Train the second predictor on the same data.
38+
var secondTrainer = new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments());
39+
var finalPredictor = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: predictor));
40+
}
41+
}
42+
}
43+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using Microsoft.ML.Runtime.Data;
2+
using Microsoft.ML.Runtime.FastTree;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
using Xunit;
7+
using Microsoft.ML.Runtime.Api;
8+
using System.Linq;
9+
10+
namespace Microsoft.ML.Tests.Scenarios.Api
11+
{
12+
public partial class ApiScenariosTests
13+
{
14+
/// <summary>
15+
/// Train with validation set: Similar to the simple train scenario, but also support a validation set.
16+
/// The learner might be trees with early stopping.
17+
/// </summary>
18+
[Fact]
19+
public void TrainWithValidationSet()
20+
{
21+
var dataPath = GetDataPath(SentimentDataPath);
22+
var validationDataPath = GetDataPath(SentimentTestPath);
23+
24+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
25+
{
26+
// Pipeline
27+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
28+
29+
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
30+
var trainData = trans;
31+
32+
// Apply the same transformations on the validation set.
33+
// Sadly, there is no way to easily apply the same loader to different data, so we either have
34+
// to create another loader, or to save the loader to model file and then reload.
35+
36+
// A new one is not always feasible, but this time it is.
37+
var validLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(validationDataPath));
38+
var validData = ApplyTransformUtils.ApplyAllTransformsToData(env, trainData, validLoader);
39+
40+
// Cache both datasets.
41+
var cachedTrain = new CacheDataView(env, trainData, prefetch: null);
42+
var cachedValid = new CacheDataView(env, validData, prefetch: null);
43+
44+
// Train.
45+
var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments());
46+
var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
47+
var validRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
48+
trainer.Train(new Runtime.TrainContext(trainRoles, validRoles));
49+
}
50+
}
51+
}
52+
}

0 commit comments

Comments
 (0)