4
4
5
5
using BenchmarkDotNet . Attributes ;
6
6
using BenchmarkDotNet . Engines ;
7
- using Microsoft . ML . Data ;
8
7
using Microsoft . ML . Models ;
8
+ using Microsoft . ML . Runtime ;
9
9
using Microsoft . ML . Runtime . Api ;
10
+ using Microsoft . ML . Runtime . Data ;
11
+ using Microsoft . ML . Runtime . Learners ;
10
12
using Microsoft . ML . Trainers ;
11
13
using Microsoft . ML . Transforms ;
12
14
using System ;
@@ -19,6 +21,7 @@ public class StochasticDualCoordinateAscentClassifierBench
19
21
internal static ClassificationMetrics s_metrics ;
20
22
private static PredictionModel < IrisData , IrisPrediction > s_trainedModel ;
21
23
private static string s_dataPath ;
24
+ private static string s_sentimentDataPath ;
22
25
private static IrisData [ ] [ ] s_batches ;
23
26
private static readonly int [ ] s_batchSizes = new int [ ] { 1 , 2 , 5 } ;
24
27
private readonly Random r = new Random ( 0 ) ;
@@ -35,10 +38,11 @@ public class StochasticDualCoordinateAscentClassifierBench
35
38
public void Setup ( )
36
39
{
37
40
s_dataPath = Program . GetDataPath ( "iris.txt" ) ;
41
+ s_sentimentDataPath = Program . GetDataPath ( "wikipedia-detox-250-line-data.tsv" ) ;
38
42
s_trainedModel = TrainCore ( ) ;
39
43
IrisPrediction prediction = s_trainedModel . Predict ( s_example ) ;
40
44
41
- var testData = new TextLoader ( s_dataPath ) . CreateFrom < IrisData > ( useHeader : true ) ;
45
+ var testData = new Data . TextLoader ( s_dataPath ) . CreateFrom < IrisData > ( useHeader : true ) ;
42
46
var evaluator = new ClassificationEvaluator ( ) ;
43
47
s_metrics = evaluator . Evaluate ( s_trainedModel , testData ) ;
44
48
@@ -69,6 +73,9 @@ public void Setup()
69
73
/// <summary>
/// Benchmark: predicts over the batch stored at index 2 of <c>s_batches</c>
/// with the trained Iris model and hands the result to <c>Consume</c>.
/// </summary>
[Benchmark]
public void PredictIrisBatchOf5()
{
    // NOTE(review): index 2 presumably maps to the batch size 5 entry of s_batchSizes — confirm in Setup.
    var predictions = s_trainedModel.Predict(s_batches[2]);
    Consume(predictions);
}
71
75
76
/// <summary>
/// Benchmark: runs <c>TrainSentimentCore</c> and returns the predictor it produces.
/// </summary>
/// <returns>The predictor returned by <c>TrainSentimentCore</c>.</returns>
[Benchmark]
public IPredictor TrainSentiment()
{
    return TrainSentimentCore();
}
78
+
72
79
private void Consume ( IEnumerable < IrisPrediction > predictions )
73
80
{
74
81
foreach ( var prediction in predictions )
@@ -79,7 +86,7 @@ private static PredictionModel<IrisData, IrisPrediction> TrainCore()
79
86
{
80
87
var pipeline = new LearningPipeline ( ) ;
81
88
82
- pipeline . Add ( new TextLoader ( s_dataPath ) . CreateFrom < IrisData > ( useHeader : true ) ) ;
89
+ pipeline . Add ( new Data . TextLoader ( s_dataPath ) . CreateFrom < IrisData > ( useHeader : true ) ) ;
83
90
pipeline . Add ( new ColumnConcatenator ( outputColumn : "Features" ,
84
91
"SepalLength" , "SepalWidth" , "PetalLength" , "PetalWidth" ) ) ;
85
92
@@ -89,6 +96,76 @@ private static PredictionModel<IrisData, IrisPrediction> TrainCore()
89
96
return model ;
90
97
}
91
98
99
/// <summary>
/// Builds the sentiment pipeline (text loader, text transform, word embeddings)
/// over the file at <c>s_sentimentDataPath</c> and trains an
/// <c>SdcaMultiClassTrainer</c> on the result.
/// </summary>
/// <returns>The predictor returned by <c>SdcaMultiClassTrainer.Train</c>.</returns>
private static IPredictor TrainSentimentCore()
{
    using (var environment = new TlcEnvironment(seed: 1))
    {
        // Loader: tab-separated input with a header row;
        // column 0 is read as the numeric "Label", column 1 as the text "SentimentText".
        var labelColumn = new TextLoader.Column()
        {
            Name = "Label",
            Source = new[] { new TextLoader.Range() { Min = 0, Max = 0 } },
            Type = DataKind.Num
        };
        var sentimentColumn = new TextLoader.Column()
        {
            Name = "SentimentText",
            Source = new[] { new TextLoader.Range() { Min = 1, Max = 1 } },
            Type = DataKind.Text
        };
        var loaderArgs = new TextLoader.Arguments()
        {
            AllowQuoting = false,
            AllowSparse = false,
            Separator = "tab",
            HasHeader = true,
            Column = new[] { labelColumn, sentimentColumn }
        };
        var source = new MultiFileSource(s_sentimentDataPath);
        var loader = new TextLoader(environment, loaderArgs, source);

        // Text transform: lower-cases, drops diacritics and punctuation, removes predefined
        // stop words, and emits tokens; both n-gram feature extractors are disabled.
        var textArgs = new TextTransform.Arguments()
        {
            Column = new TextTransform.Column
            {
                Name = "WordEmbeddings",
                Source = new[] { "SentimentText" }
            },
            KeepDiacritics = false,
            KeepPunctuations = false,
            TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
            OutputTokens = true,
            StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
            VectorNormalizer = TextTransform.TextNormKind.None,
            CharFeatureExtractor = null,
            WordFeatureExtractor = null,
        };
        var tokenized = TextTransform.Create(environment, textArgs, loader);

        // Word embeddings: maps the token column to "Features" using the Sswe pretrained model.
        // NOTE(review): "WordEmbeddings_TransformedText" is presumably the token column that
        // TextTransform produces when OutputTokens is set — confirm against TextTransform.
        var embeddingArgs = new WordEmbeddingsTransform.Arguments()
        {
            Column = new[]
            {
                new WordEmbeddingsTransform.Column
                {
                    Name = "Features",
                    Source = "WordEmbeddings_TransformedText"
                }
            },
            ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
        };
        var embedded = new WordEmbeddingsTransform(environment, embeddingArgs, tokenized);

        // Train
        var trainerArgs = new SdcaMultiClassTrainer.Arguments() { MaxIterations = 20 };
        var sdca = new SdcaMultiClassTrainer(environment, trainerArgs);
        var trainingData = new RoleMappedData(embedded, label: "Label", feature: "Features");
        return sdca.Train(trainingData);
    }
}
168
+
92
169
public class IrisData
93
170
{
94
171
[ Column ( "0" ) ]
0 commit comments