Skip to content

Commit e00d19d

Browse files
authored
Added tests for text featurizer options (Part1). (#3006)
1 parent ce56462 commit e00d19d

File tree

2 files changed

+214
-2
lines changed

2 files changed

+214
-2
lines changed

src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs

+1-1
Original file line number | Diff line number | Diff line change
@@ -393,7 +393,7 @@ internal TextFeaturizingEstimator(IHostEnvironment env, string name, IEnumerable
393393
if (options != null)
394394
OptionalSettings = options;
395395

396-
_stopWordsRemover = null;
396+
_stopWordsRemover = OptionalSettings.StopWordsRemover;
397397
_dictionary = null;
398398
_wordFeatureExtractor = OptionalSettings.WordFeatureExtractorFactory;
399399
_charFeatureExtractor = OptionalSettings.CharFeatureExtractorFactory;

test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs

+213-1
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

55
using System;
66
using System.IO;
7+
using System.Text.RegularExpressions;
78
using Microsoft.ML;
89
using Microsoft.ML.Data;
910
using Microsoft.ML.Data.IO;
@@ -26,6 +27,217 @@ public TextFeaturizerTests(ITestOutputHelper helper)
2627
{
2728
}
2829

30+
/// <summary>
/// Row type shared by the text featurizer option tests in this file. It is used
/// both as the input type and as the prediction output type of the engines.
/// </summary>
private class TestClass
{
    /// <summary>Raw text; the tests pass "A" as the input column of FeaturizeText.</summary>
    public string A;

    /// <summary>
    /// Tokens read back after prediction; the tests set OutputTokensColumnName
    /// to "OutputTokens" so that the featurizer's tokens land in this field.
    /// </summary>
    public string[] OutputTokens;
}
35+
36+
/// <summary>
/// Featurizes two sentences with a default-constructed
/// <see cref="StopWordsRemovingEstimator.Options"/> and checks the tokens that
/// are left in the output tokens column.
/// </summary>
[Fact]
public void TextFeaturizerWithPredefinedStopWordRemoverTest()
{
    TestClass[] samples =
    {
        new TestClass { A = "This is some text with english stop words", OutputTokens = null },
        new TestClass { A = "No stop words", OutputTokens = null }
    };
    var input = ML.Data.LoadFromEnumerable(samples);

    var featurizerOptions = new TextFeaturizingEstimator.Options
    {
        StopWordsRemoverOptions = new StopWordsRemovingEstimator.Options(),
        OutputTokensColumnName = "OutputTokens"
    };
    var predictor = ML.Transforms.Text
        .FeaturizeText("OutputText", featurizerOptions, "A")
        .Fit(input)
        .CreatePredictionEngine<TestClass, TestClass>(ML);

    // First sentence: only the words that survive stop-word removal remain.
    var result = predictor.Predict(samples[0]);
    var joinedTokens = string.Join(" ", result.OutputTokens);
    Assert.Equal("text english stop words", joinedTokens);

    // Second sentence.
    result = predictor.Predict(samples[1]);
    joinedTokens = string.Join(" ", result.OutputTokens);
    Assert.Equal("stop words", joinedTokens);
}
53+
54+
/// <summary>
/// Featurizes two sentences with a custom stop-word list ("stop", "words") and
/// case normalization disabled, then checks that exactly those two words are
/// missing from the output tokens column.
/// </summary>
[Fact]
public void TextFeaturizerWithCustomStopWordRemoverTest()
{
    TestClass[] samples =
    {
        new TestClass { A = "This is some text with english stop words", OutputTokens = null },
        new TestClass { A = "No stop words", OutputTokens = null }
    };
    var input = ML.Data.LoadFromEnumerable(samples);

    var customStopWords = new CustomStopWordsRemovingEstimator.Options()
    {
        StopWords = new[] { "stop", "words" }
    };
    var featurizerOptions = new TextFeaturizingEstimator.Options
    {
        StopWordsRemoverOptions = customStopWords,
        OutputTokensColumnName = "OutputTokens",
        CaseMode = TextNormalizingEstimator.CaseMode.None
    };

    var predictor = ML.Transforms.Text
        .FeaturizeText("OutputText", featurizerOptions, "A")
        .Fit(input)
        .CreatePredictionEngine<TestClass, TestClass>(ML);

    // First sentence: original casing is kept, custom stop words are dropped.
    var result = predictor.Predict(samples[0]);
    var joinedTokens = string.Join(" ", result.OutputTokens);
    Assert.Equal("This is some text with english", joinedTokens);

    // Second sentence.
    result = predictor.Predict(samples[1]);
    joinedTokens = string.Join(" ", result.OutputTokens);
    Assert.Equal("No", joinedTokens);
}
79+
80+
/// <summary>
/// Fits a text featurizer with the requested <paramref name="caseMode"/> and
/// asserts that the output tokens of the first two rows of
/// <paramref name="data"/>, joined with single spaces, equal the input text
/// converted to that case.
/// </summary>
/// <param name="dataView">Data view the featurizer is fitted on.</param>
/// <param name="data">Rows to predict on; only elements 0 and 1 are used.</param>
/// <param name="caseMode">Case normalization mode under test.</param>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="caseMode"/> is not Upper, Lower or None.
/// </exception>
private void TestCaseMode(IDataView dataView, TestClass[] data, TextNormalizingEstimator.CaseMode caseMode)
{
    var options = new TextFeaturizingEstimator.Options()
    {
        CaseMode = caseMode,
        OutputTokensColumnName = "OutputTokens"
    };
    var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A");
    var model = pipeline.Fit(dataView);
    var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
    var prediction1 = engine.Predict(data[0]);
    var prediction2 = engine.Predict(data[1]);

    string expected1;
    string expected2;
    switch (caseMode)
    {
        case TextNormalizingEstimator.CaseMode.Upper:
            // Invariant casing keeps the expectation independent of the current
            // culture (e.g. 'i' upper-cases to a dotted capital under tr-TR).
            expected1 = data[0].A.ToUpperInvariant();
            expected2 = data[1].A.ToUpperInvariant();
            break;
        case TextNormalizingEstimator.CaseMode.Lower:
            expected1 = data[0].A.ToLowerInvariant();
            expected2 = data[1].A.ToLowerInvariant();
            break;
        case TextNormalizingEstimator.CaseMode.None:
            expected1 = data[0].A;
            expected2 = data[1].A;
            break;
        default:
            // Fail loudly instead of comparing the tokens against a null expectation.
            throw new ArgumentOutOfRangeException(nameof(caseMode), caseMode, "Unsupported case mode.");
    }

    Assert.Equal(expected1, string.Join(" ", prediction1.OutputTokens));
    Assert.Equal(expected2, string.Join(" ", prediction2.OutputTokens));
}
114+
115+
/// <summary>
/// Runs <c>TestCaseMode</c> over the same two sentences for each case mode.
/// Despite the method name, this covers Lower, Upper and None, in that order.
/// </summary>
[Fact]
public void TextFeaturizerWithUpperCaseTest()
{
    TestClass[] samples =
    {
        new TestClass { A = "This is some text with english stop words", OutputTokens = null },
        new TestClass { A = "No stop words", OutputTokens = null }
    };
    var input = ML.Data.LoadFromEnumerable(samples);

    var modesUnderTest = new[]
    {
        TextNormalizingEstimator.CaseMode.Lower,
        TextNormalizingEstimator.CaseMode.Upper,
        TextNormalizingEstimator.CaseMode.None
    };
    foreach (var mode in modesUnderTest)
    {
        TestCaseMode(input, samples, mode);
    }
}
126+
127+
128+
/// <summary>
/// Fits a text featurizer with the given <paramref name="keepNumbers"/> setting
/// (case normalization disabled) and checks the joined output tokens of the
/// first two rows of <paramref name="data"/>.
/// </summary>
/// <param name="dataView">Data view the featurizer is fitted on.</param>
/// <param name="data">Rows to predict on; only elements 0 and 1 are used.</param>
/// <param name="keepNumbers">Value assigned to the KeepNumbers option.</param>
private void TestKeepNumbers(IDataView dataView, TestClass[] data, bool keepNumbers)
{
    var featurizerOptions = new TextFeaturizingEstimator.Options
    {
        KeepNumbers = keepNumbers,
        CaseMode = TextNormalizingEstimator.CaseMode.None,
        OutputTokensColumnName = "OutputTokens"
    };
    var predictor = ML.Transforms.Text
        .FeaturizeText("OutputText", featurizerOptions, "A")
        .Fit(dataView)
        .CreatePredictionEngine<TestClass, TestClass>(ML);

    var firstResult = predictor.Predict(data[0]);
    var secondResult = predictor.Predict(data[1]);

    // With numbers kept, the text must round-trip unchanged; otherwise the
    // expectation is the first row's text with its digit runs stripped out.
    string firstExpected = keepNumbers
        ? data[0].A
        : data[0].A.Replace("123 ", "").Replace("425", "").Replace("25", "").Replace("23", "");

    Assert.Equal(firstExpected, string.Join(" ", firstResult.OutputTokens));
    // The second row is expected to come back unchanged in both modes.
    Assert.Equal(data[1].A, string.Join(" ", secondResult.OutputTokens));
}
153+
154+
/// <summary>
/// Runs <c>TestKeepNumbers</c> with KeepNumbers set to true and then to false,
/// on one sentence that contains numbers and one that does not.
/// </summary>
[Fact]
public void TextFeaturizerWithKeepNumbersTest()
{
    TestClass[] samples =
    {
        new TestClass { A = "This is some text with numbers 123 $425 25.23", OutputTokens = null },
        new TestClass { A = "No numbers", OutputTokens = null }
    };
    var input = ML.Data.LoadFromEnumerable(samples);

    foreach (var keepNumbers in new[] { true, false })
    {
        TestKeepNumbers(input, samples, keepNumbers);
    }
}
164+
165+
/// <summary>
/// Fits a text featurizer with the given <paramref name="keepPunctuations"/>
/// setting (case normalization disabled) and checks the joined output tokens
/// of the first two rows of <paramref name="data"/>.
/// </summary>
/// <param name="dataView">Data view the featurizer is fitted on.</param>
/// <param name="data">Rows to predict on; only elements 0 and 1 are used.</param>
/// <param name="keepPunctuations">Value assigned to the KeepPunctuations option.</param>
private void TestKeepPunctuations(IDataView dataView, TestClass[] data, bool keepPunctuations)
{
    var options = new TextFeaturizingEstimator.Options()
    {
        KeepPunctuations = keepPunctuations,
        CaseMode = TextNormalizingEstimator.CaseMode.None,
        OutputTokensColumnName = "OutputTokens"
    };
    var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A");
    var model = pipeline.Fit(dataView);
    var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
    var prediction1 = engine.Predict(data[0]);
    var prediction2 = engine.Predict(data[1]);

    if (keepPunctuations)
    {
        Assert.Equal(data[0].A, string.Join(" ", prediction1.OutputTokens));
        Assert.Equal(data[1].A, string.Join(" ", prediction2.OutputTokens));
    }
    else
    {
        // Strip the punctuation marks used by the test data. Inside a character
        // class '|' is a literal pipe, not alternation, so the members are listed
        // directly; '.' needs no escape inside [...].
        var expected = Regex.Replace(data[0].A, "[,_'\";.]", "");
        Assert.Equal(expected, string.Join(" ", prediction1.OutputTokens));
        Assert.Equal(data[1].A, string.Join(" ", prediction2.OutputTokens));
    }
}
191+
192+
/// <summary>
/// Runs <c>TestKeepPunctuations</c> with KeepPunctuations set to true and then
/// to false, on one sentence with punctuation marks and one without.
/// </summary>
[Fact]
public void TextFeaturizerWithKeepPunctuationsTest()
{
    TestClass[] samples =
    {
        new TestClass { A = "This, is; some_ ,text 'with\" punctuations.", OutputTokens = null },
        new TestClass { A = "No punctuations", OutputTokens = null }
    };
    var input = ML.Data.LoadFromEnumerable(samples);

    foreach (var keepPunctuations in new[] { true, false })
    {
        TestKeepPunctuations(input, samples, keepPunctuations);
    }
}
202+
203+
/// <summary>
/// Fits a text featurizer with the given <paramref name="keepDiacritics"/>
/// setting (case normalization disabled) and checks the joined output tokens
/// of the first two rows of <paramref name="data"/>.
/// </summary>
/// <param name="dataView">Data view the featurizer is fitted on.</param>
/// <param name="data">Rows to predict on; only elements 0 and 1 are used.</param>
/// <param name="keepDiacritics">Value assigned to the KeepDiacritics option.</param>
private void TestKeepDiacritics(IDataView dataView, TestClass[] data, bool keepDiacritics)
{
    var featurizerOptions = new TextFeaturizingEstimator.Options
    {
        KeepDiacritics = keepDiacritics,
        CaseMode = TextNormalizingEstimator.CaseMode.None,
        OutputTokensColumnName = "OutputTokens"
    };
    var predictor = ML.Transforms.Text
        .FeaturizeText("OutputText", featurizerOptions, "A")
        .Fit(dataView)
        .CreatePredictionEngine<TestClass, TestClass>(ML);

    var firstResult = predictor.Predict(data[0]);
    var secondResult = predictor.Predict(data[1]);

    // With diacritics kept, the text must round-trip unchanged; otherwise the
    // first row is expected as its plain-letter equivalent.
    string firstExpected = keepDiacritics
        ? data[0].A
        : "This is some text with diacritics";

    Assert.Equal(firstExpected, string.Join(" ", firstResult.OutputTokens));
    // The second row is expected to come back unchanged in both modes.
    Assert.Equal(data[1].A, string.Join(" ", secondResult.OutputTokens));
}
228+
229+
/// <summary>
/// Runs <c>TestKeepDiacritics</c> with KeepDiacritics set to true and then to
/// false, on one sentence with accented letters and one without.
/// </summary>
[Fact]
public void TextFeaturizerWithKeepDiacriticsTest()
{
    TestClass[] samples =
    {
        new TestClass { A = "Thîs îs sóme text with diácrîtîcs", OutputTokens = null },
        new TestClass { A = "No diacritics", OutputTokens = null }
    };
    var input = ML.Data.LoadFromEnumerable(samples);

    foreach (var keepDiacritics in new[] { true, false })
    {
        TestKeepDiacritics(input, samples, keepDiacritics);
    }
}
239+
240+
29241
[Fact]
30242
public void TextFeaturizerWorkout()
31243
{

0 commit comments

Comments
 (0)