Skip to content

Commit b468056

Browse files
author
Ivan Matantsev
committed
idv and decomposableTrain
1 parent d051138 commit b468056

File tree

4 files changed

+162
-0
lines changed

4 files changed

+162
-0
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,25 @@ public ApiScenariosTests(ITestOutputHelper output) : base(output)
1717
{
1818
}
1919

20+
public const string IrisDataPath = "iris.data";
2021
public const string SentimentDataPath = "wikipedia-detox-250-line-data.tsv";
2122
public const string SentimentTestPath = "wikipedia-detox-250-line-test.tsv";
2223

24+
public class IrisData
25+
{
26+
public float SepalLength;
27+
public float SepalWidth;
28+
public float PetalLength;
29+
public float PetalWidth;
30+
public string Label;
31+
}
32+
33+
public class IrisPrediction
34+
{
35+
public string PredictedLabel;
36+
public float[] Score;
37+
}
38+
2339
public class SentimentData
2440
{
2541
[ColumnName("Label")]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using Microsoft.ML.Runtime.Api;
2+
using Microsoft.ML.Runtime.Data;
3+
using Microsoft.ML.Runtime.Learners;
4+
using System.Linq;
5+
using Xunit;
6+
7+
namespace Microsoft.ML.Tests.Scenarios.Api
8+
{
9+
10+
public partial class ApiScenariosTests
11+
{
12+
/// <summary>
13+
/// Decomposable train and predict: Train on Iris multiclass problem, which will require
14+
/// a transform on labels. Be able to reconstitute the pipeline for a prediction only task,
15+
/// which will essentially "drop" the transform over labels, while retaining the property
16+
/// that the predicted label for this has a key-type, the probability outputs for the classes
17+
/// have the class labels as slot names, etc. This should be do-able without ugly compromises like,
18+
/// say, injecting a dummy label.
19+
/// </summary>
20+
[Fact]
21+
void DecomposableTrainAndPredictcs()
22+
{
23+
var dataPath = GetDataPath(IrisDataPath);
24+
using (var env = new TlcEnvironment())
25+
{
26+
var loader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
27+
var term = new TermTransform(env, loader, "Label");
28+
var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
29+
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });
30+
31+
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
32+
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");
33+
34+
// Auto-normalization.
35+
NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
36+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
37+
38+
var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
39+
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
40+
41+
// Cut of term transform from pipeline.
42+
var new_scorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
43+
var keyToValue = new KeyToValueTransform(env, new_scorer, "PredictedLabel");
44+
var model = env.CreatePredictionEngine<IrisData, IrisPrediction>(keyToValue);
45+
46+
var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
47+
var testData = testLoader.AsEnumerable<IrisData>(env, false);
48+
foreach (var input in testData.Take(20))
49+
{
50+
var prediction = model.Predict(input);
51+
Assert.True(prediction.PredictedLabel == input.Label);
52+
}
53+
}
54+
}
55+
}
56+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using Microsoft.ML.Runtime.Data;
2+
using Microsoft.ML.Runtime.Data.IO;
3+
using Microsoft.ML.Runtime.Learners;
4+
using Xunit;
5+
6+
namespace Microsoft.ML.Tests.Scenarios.Api
7+
{
8+
public partial class ApiScenariosTests
9+
{
10+
/// <summary>
11+
/// File-based saving of data: Come up with transform pipeline. Transform training and
12+
/// test data, and save the featurized data to some file, using the .idv format.
13+
/// Train and evaluate multiple models over that pre-featurized data. (Useful for
14+
/// sweeping scenarios, where you are training many times on the same data,
15+
/// and don't necessarily want to transform it every single time.)
16+
/// </summary>
17+
[Fact]
18+
void FileBasedSavingOfData()
19+
{
20+
var dataPath = GetDataPath(SentimentDataPath);
21+
var testDataPath = GetDataPath(SentimentTestPath);
22+
23+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
24+
{
25+
// Pipeline
26+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
27+
28+
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
29+
var saver = new BinarySaver(env, new BinarySaver.Arguments());
30+
using (var ch = env.Start("SaveData"))
31+
using (var file = env.CreateOutputFile("i.idv"))
32+
{
33+
DataSaverUtils.SaveDataView(ch, saver, trans, file);
34+
}
35+
36+
var binData = new BinaryLoader(env, new BinaryLoader.Arguments(), new MultiFileSource("i.idv"));
37+
var trainRoles = new RoleMappedData(binData, label: "Label", feature: "Features");
38+
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
39+
{
40+
NumThreads = 1
41+
});
42+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
43+
44+
DeleteOutputPath("temp.idv");
45+
}
46+
}
47+
}
48+
}

test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,48 @@ private static TextTransform.Arguments MakeSentimentTextTransformArgs(bool norma
114114
};
115115
}
116116

117+
private static TextLoader.Arguments MakeIrisTextLoaderArgs()
118+
{
119+
120+
return new TextLoader.Arguments()
121+
{
122+
Separator = "comma",
123+
HasHeader = true,
124+
Column = new[]
125+
{
126+
new TextLoader.Column()
127+
{
128+
Name = "SepalLength",
129+
Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
130+
Type = DataKind.R4
131+
},
132+
new TextLoader.Column()
133+
{
134+
Name = "SepalWidth",
135+
Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
136+
Type = DataKind.R4
137+
},
138+
new TextLoader.Column()
139+
{
140+
Name = "PetalLength",
141+
Source = new [] { new TextLoader.Range() { Min=2, Max=2} },
142+
Type = DataKind.R4
143+
},
144+
new TextLoader.Column()
145+
{
146+
Name = "PetalWidth",
147+
Source = new [] { new TextLoader.Range() { Min=3, Max=3} },
148+
Type = DataKind.R4
149+
},
150+
new TextLoader.Column()
151+
{
152+
Name = "Label",
153+
Source = new [] { new TextLoader.Range() { Min=4, Max=4} },
154+
Type = DataKind.Text
155+
}
156+
}
157+
};
158+
}
117159
private static TextLoader.Arguments MakeSentimentTextLoaderArgs()
118160
{
119161
return new TextLoader.Arguments()

0 commit comments

Comments
 (0)