Skip to content

Commit 4fd8a9c

Browse files
briancyluiZruty0
authored andcommitted
Add new benchmarks to test\Microsoft.ML.Benchmarks (#722)
1 parent 841ba78 commit 4fd8a9c

File tree

3 files changed

+177
-3
lines changed

3 files changed

+177
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using BenchmarkDotNet.Attributes;
6+
using BenchmarkDotNet.Running;
7+
using Microsoft.ML.Runtime;
8+
using Microsoft.ML.Runtime.Api;
9+
using Microsoft.ML.Runtime.CommandLine;
10+
using Microsoft.ML.Runtime.Data;
11+
using Microsoft.ML.Runtime.Learners;
12+
13+
namespace Microsoft.ML.Benchmarks
14+
{
15+
public class KMeansAndLogisticRegressionBench
16+
{
17+
private static string s_dataPath;
18+
19+
[Benchmark]
20+
public IPredictor TrainKMeansAndLR() => TrainKMeansAndLRCore();
21+
22+
[GlobalSetup]
23+
public void Setup()
24+
{
25+
s_dataPath = Program.GetDataPath("adult.train");
26+
}
27+
28+
private static IPredictor TrainKMeansAndLRCore()
29+
{
30+
string dataPath = s_dataPath;
31+
32+
using (var env = new TlcEnvironment(seed: 1))
33+
{
34+
// Pipeline
35+
var loader = new TextLoader(env,
36+
new TextLoader.Arguments()
37+
{
38+
HasHeader = true,
39+
Separator = ",",
40+
Column = new[] {
41+
new TextLoader.Column()
42+
{
43+
Name = "Label",
44+
Source = new [] { new TextLoader.Range() { Min = 14, Max = 14} },
45+
Type = DataKind.R4
46+
},
47+
new TextLoader.Column()
48+
{
49+
Name = "CatFeatures",
50+
Source = new [] {
51+
new TextLoader.Range() { Min = 1, Max = 1 },
52+
new TextLoader.Range() { Min = 3, Max = 3 },
53+
new TextLoader.Range() { Min = 5, Max = 9 },
54+
new TextLoader.Range() { Min = 13, Max = 13 }
55+
},
56+
Type = DataKind.TX
57+
},
58+
new TextLoader.Column()
59+
{
60+
Name = "NumFeatures",
61+
Source = new [] {
62+
new TextLoader.Range() { Min = 0, Max = 0 },
63+
new TextLoader.Range() { Min = 2, Max = 2 },
64+
new TextLoader.Range() { Min = 4, Max = 4 },
65+
new TextLoader.Range() { Min = 10, Max = 12 }
66+
},
67+
Type = DataKind.R4
68+
}
69+
}
70+
}, new MultiFileSource(dataPath));
71+
72+
IDataTransform trans = CategoricalTransform.Create(env, new CategoricalTransform.Arguments
73+
{
74+
Column = new[]
75+
{
76+
new CategoricalTransform.Column { Name = "CatFeatures", Source = "CatFeatures" }
77+
}
78+
}, loader);
79+
80+
trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "NumFeatures");
81+
trans = new ConcatTransform(env, trans, "Features", "NumFeatures", "CatFeatures");
82+
trans = TrainAndScoreTransform.Create(env, new TrainAndScoreTransform.Arguments
83+
{
84+
Trainer = new SubComponent<ITrainer, SignatureTrainer>("KMeans", "k=100"),
85+
FeatureColumn = "Features"
86+
}, trans);
87+
trans = new ConcatTransform(env, trans, "Features", "Features", "Score");
88+
89+
// Train
90+
var trainer = new LogisticRegression(env, new LogisticRegression.Arguments() { EnforceNonNegativity = true, OptTol = 1e-3f });
91+
var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
92+
return trainer.Train(trainRoles);
93+
}
94+
}
95+
}
96+
}

test/Microsoft.ML.Benchmarks/Microsoft.ML.Benchmarks.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
<PackageReference Include="BenchmarkDotNet" Version="$(BenchmarkDotNetVersion)" />
1414
</ItemGroup>
1515
<ItemGroup>
16+
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
1617
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
1718
<ProjectReference Include="..\..\src\Microsoft.ML\Microsoft.ML.csproj" />
1819
</ItemGroup>

test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs

+80-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
using BenchmarkDotNet.Attributes;
66
using BenchmarkDotNet.Engines;
7-
using Microsoft.ML.Data;
87
using Microsoft.ML.Models;
8+
using Microsoft.ML.Runtime;
99
using Microsoft.ML.Runtime.Api;
10+
using Microsoft.ML.Runtime.Data;
11+
using Microsoft.ML.Runtime.Learners;
1012
using Microsoft.ML.Trainers;
1113
using Microsoft.ML.Transforms;
1214
using System;
@@ -19,6 +21,7 @@ public class StochasticDualCoordinateAscentClassifierBench
1921
internal static ClassificationMetrics s_metrics;
2022
private static PredictionModel<IrisData, IrisPrediction> s_trainedModel;
2123
private static string s_dataPath;
24+
private static string s_sentimentDataPath;
2225
private static IrisData[][] s_batches;
2326
private static readonly int[] s_batchSizes = new int[] { 1, 2, 5 };
2427
private readonly Random r = new Random(0);
@@ -35,10 +38,11 @@ public class StochasticDualCoordinateAscentClassifierBench
3538
public void Setup()
3639
{
3740
s_dataPath = Program.GetDataPath("iris.txt");
41+
s_sentimentDataPath = Program.GetDataPath("wikipedia-detox-250-line-data.tsv");
3842
s_trainedModel = TrainCore();
3943
IrisPrediction prediction = s_trainedModel.Predict(s_example);
4044

41-
var testData = new TextLoader(s_dataPath).CreateFrom<IrisData>(useHeader: true);
45+
var testData = new Data.TextLoader(s_dataPath).CreateFrom<IrisData>(useHeader: true);
4246
var evaluator = new ClassificationEvaluator();
4347
s_metrics = evaluator.Evaluate(s_trainedModel, testData);
4448

@@ -69,6 +73,9 @@ public void Setup()
6973
[Benchmark]
7074
public void PredictIrisBatchOf5() => Consume(s_trainedModel.Predict(s_batches[2]));
7175

76+
[Benchmark]
77+
public IPredictor TrainSentiment() => TrainSentimentCore();
78+
7279
private void Consume(IEnumerable<IrisPrediction> predictions)
7380
{
7481
foreach (var prediction in predictions)
@@ -79,7 +86,7 @@ private static PredictionModel<IrisData, IrisPrediction> TrainCore()
7986
{
8087
var pipeline = new LearningPipeline();
8188

82-
pipeline.Add(new TextLoader(s_dataPath).CreateFrom<IrisData>(useHeader: true));
89+
pipeline.Add(new Data.TextLoader(s_dataPath).CreateFrom<IrisData>(useHeader: true));
8390
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
8491
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
8592

@@ -89,6 +96,76 @@ private static PredictionModel<IrisData, IrisPrediction> TrainCore()
8996
return model;
9097
}
9198

99+
private static IPredictor TrainSentimentCore()
100+
{
101+
var dataPath = s_sentimentDataPath;
102+
using (var env = new TlcEnvironment(seed: 1))
103+
{
104+
// Pipeline
105+
var loader = new TextLoader(env,
106+
new TextLoader.Arguments()
107+
{
108+
AllowQuoting = false,
109+
AllowSparse = false,
110+
Separator = "tab",
111+
HasHeader = true,
112+
Column = new[]
113+
{
114+
new TextLoader.Column()
115+
{
116+
Name = "Label",
117+
Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
118+
Type = DataKind.Num
119+
},
120+
121+
new TextLoader.Column()
122+
{
123+
Name = "SentimentText",
124+
Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
125+
Type = DataKind.Text
126+
}
127+
}
128+
}, new MultiFileSource(dataPath));
129+
130+
var text = TextTransform.Create(env,
131+
new TextTransform.Arguments()
132+
{
133+
Column = new TextTransform.Column
134+
{
135+
Name = "WordEmbeddings",
136+
Source = new[] { "SentimentText" }
137+
},
138+
KeepDiacritics = false,
139+
KeepPunctuations = false,
140+
TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
141+
OutputTokens = true,
142+
StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
143+
VectorNormalizer = TextTransform.TextNormKind.None,
144+
CharFeatureExtractor = null,
145+
WordFeatureExtractor = null,
146+
}, loader);
147+
148+
var trans = new WordEmbeddingsTransform(env,
149+
new WordEmbeddingsTransform.Arguments()
150+
{
151+
Column = new WordEmbeddingsTransform.Column[1]
152+
{
153+
new WordEmbeddingsTransform.Column
154+
{
155+
Name = "Features",
156+
Source = "WordEmbeddings_TransformedText"
157+
}
158+
},
159+
ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
160+
}, text);
161+
162+
// Train
163+
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { MaxIterations = 20 });
164+
var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
165+
return trainer.Train(trainRoles);
166+
}
167+
}
168+
92169
public class IrisData
93170
{
94171
[Column("0")]

0 commit comments

Comments
 (0)