Skip to content

Commit 6733515

Browse files
authored
Made 'StopWordsRemover' in TextFeaturizer configurable again. (#2962)
1 parent 665a366 commit 6733515

File tree

6 files changed

+111
-21
lines changed

6 files changed

+111
-21
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public static void Example()
3636
KeepPunctuations = false,
3737
KeepNumbers = false,
3838
OutputTokens = true,
39-
Language = TextFeaturizingEstimator.Language.English, // supports English, French, German, Dutch, Italian, Spanish, Japanese
39+
StopWordsRemoverOptions = new StopWordsRemovingEstimator.Options() { Language = TextFeaturizingEstimator.Language.English }, // supports English, French, German, Dutch, Italian, Spanish, Japanese
4040
}, "SentimentText");
4141

4242
// The transformed data for both pipelines.

src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs

+27
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,22 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
489489
/// </summary>
490490
public sealed class StopWordsRemovingEstimator : TrivialEstimator<StopWordsRemovingTransformer>
491491
{
492+
/// <summary>
493+
/// Options for the stop words remover that removes a language-specific list of stop words (most common words) already defined in the system.
494+
/// </summary>
495+
public sealed class Options : IStopWordsRemoverOptions
496+
{
497+
/// <summary>
498+
/// Language of the text dataset. 'English' is default.
499+
/// </summary>
500+
public TextFeaturizingEstimator.Language Language;
501+
502+
/// <summary>
/// Creates the options with <see cref="Language"/> set to <see cref="TextFeaturizingEstimator.DefaultLanguage"/>.
/// </summary>
public Options()
503+
{
504+
Language = TextFeaturizingEstimator.DefaultLanguage;
505+
}
506+
}
507+
492508
/// <summary>
493509
/// Describes how the transformer handles one column pair.
494510
/// </summary>
@@ -1065,6 +1081,17 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
10651081
/// </summary>
10661082
public sealed class CustomStopWordsRemovingEstimator : TrivialEstimator<CustomStopWordsRemovingTransformer>
10671083
{
1084+
/// <summary>
1085+
/// Options for the stop words remover that removes a custom, user-provided list of stop words.
1086+
/// </summary>
1087+
public sealed class Options : IStopWordsRemoverOptions
1088+
{
1089+
/// <summary>
1090+
/// List of stop words to remove.
1091+
/// </summary>
1092+
public string[] StopWords;
1093+
}
1094+
10681095
internal const string ExpectedColumnType = "vector of Text type";
10691096

10701097
/// <summary>

src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs

+74-14
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
namespace Microsoft.ML.Transforms.Text
2525
{
2626
using CaseMode = TextNormalizingEstimator.CaseMode;
27+
using StopWordsCol = StopWordsRemovingTransformer.Column;
28+
29+
/// <summary>
30+
/// Common type for the options of the supported stop words removers. The interface declares no members;
/// it is implemented by <see cref="StopWordsRemovingEstimator.Options"/> and <see cref="CustomStopWordsRemovingEstimator.Options"/>.
31+
/// </summary>
32+
public interface IStopWordsRemoverOptions { }
33+
2734
// A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are counts
2835
// of (word or character) ngrams in a given text. It offers ngram hashing (finding the ngram token string name to feature
2936
// integer index mapping through hashing) as an option.
@@ -93,10 +100,56 @@ public sealed class Options : TransformInputBase
93100
internal Column Columns;
94101

95102
[Argument(ArgumentType.AtMostOnce, HelpText = "Dataset language or 'AutoDetect' to detect language per row.", ShortName = "lang", SortOrder = 3)]
96-
public Language Language = DefaultLanguage;
103+
internal Language Language = DefaultLanguage;
104+
105+
[Argument(ArgumentType.Multiple, Name = "StopWordsRemover", HelpText = "Stopwords remover.", ShortName = "remover", NullName = "<None>", SortOrder = 4)]
106+
internal IStopWordsRemoverFactory StopWordsRemover;
97107

98-
[Argument(ArgumentType.Multiple, HelpText = "Use stop remover or not.", ShortName = "remover", SortOrder = 4)]
99-
public bool UsePredefinedStopWordRemover = false;
108+
/// <summary>
109+
/// The underlying state of <see cref="StopWordsRemover"/> and <see cref="StopWordsRemoverOptions"/>.
110+
/// </summary>
111+
private IStopWordsRemoverOptions _stopWordsRemoverOptions;
112+
113+
/// <summary>
114+
/// Option to set the type of stop words remover to use.
115+
/// The following options are available:
116+
/// <list type="bullet">
117+
/// <item>
118+
/// <description>The <see cref="StopWordsRemovingEstimator.Options"/> removes the language specific list of stop words from the input.</description>
119+
/// </item>
120+
/// <item>
121+
/// <description>The <see cref="CustomStopWordsRemovingEstimator.Options"/> uses user provided list of stop words.</description>
122+
/// </item>
123+
/// </list>
124+
/// Setting this to 'null' does not remove stop words from the input.
/// Assigning a value also overwrites the internal <see cref="StopWordsRemover"/> factory, and,
/// for <see cref="StopWordsRemovingEstimator.Options"/>, the <see cref="Language"/> field.
125+
/// </summary>
126+
public IStopWordsRemoverOptions StopWordsRemoverOptions
127+
{
128+
get { return _stopWordsRemoverOptions; }
129+
set
130+
{
131+
_stopWordsRemoverOptions = value;
132+
// Stays null when the value is null or is not one of the two recognized option types,
// in which case StopWordsRemover is reset to null below.
IStopWordsRemoverFactory options = null;
133+
if (_stopWordsRemoverOptions != null)
134+
{
135+
if (_stopWordsRemoverOptions is StopWordsRemovingEstimator.Options)
136+
{
137+
// Predefined remover: the chosen language is copied onto this Options instance.
options = new PredefinedStopWordsRemoverFactory();
138+
Language = (_stopWordsRemoverOptions as StopWordsRemovingEstimator.Options).Language;
139+
}
140+
else if (_stopWordsRemoverOptions is CustomStopWordsRemovingEstimator.Options)
141+
{
142+
// Custom remover: the list is passed both as an array and as a single comma-joined string.
// NOTE(review): StopWords has no initializer, so it is null unless the caller sets it;
// string.Join would then throw ArgumentNullException here - confirm whether a guard is wanted.
var stopwords = (_stopWordsRemoverOptions as CustomStopWordsRemovingEstimator.Options).StopWords;
143+
options = new CustomStopWordsRemovingTransformer.LoaderArguments()
144+
{
145+
Stopwords = stopwords,
146+
Stopword = string.Join(",", stopwords)
147+
};
148+
}
149+
}
150+
StopWordsRemover = options;
151+
}
152+
}
100153

101154
[Argument(ArgumentType.AtMostOnce, HelpText = "Casing text using the rules of the invariant culture.", Name="TextCase", ShortName = "case", SortOrder = 5)]
102155
public CaseMode CaseMode = TextNormalizingEstimator.Defaults.Mode;
@@ -202,6 +255,7 @@ public Options()
202255

203256
// These parameters are hardcoded for now.
204257
// REVIEW: expose them once sub-transforms are estimators.
258+
private IStopWordsRemoverFactory _stopWordsRemover;
205259
private TermLoaderArguments _dictionary;
206260
private INgramExtractorFactoryFactory _wordFeatureExtractor;
207261
private INgramExtractorFactoryFactory _charFeatureExtractor;
@@ -219,7 +273,7 @@ private sealed class TransformApplierParams
219273

220274
public readonly NormFunction Norm;
221275
public readonly Language Language;
222-
public readonly bool UsePredefinedStopWordRemover;
276+
public readonly IStopWordsRemoverFactory StopWordsRemover;
223277
public readonly CaseMode TextCase;
224278
public readonly bool KeepDiacritics;
225279
public readonly bool KeepPunctuations;
@@ -251,7 +305,9 @@ internal LpNormNormalizingEstimatorBase.NormFunction LpNorm
251305

252306
// These properties encode the logic needed to determine which transforms to apply.
253307
#region NeededTransforms
254-
public bool NeedsWordTokenizationTransform { get { return WordExtractorFactory != null || UsePredefinedStopWordRemover || OutputTextTokens; } }
308+
public bool NeedsWordTokenizationTransform { get { return WordExtractorFactory != null || NeedsRemoveStopwordsTransform || OutputTextTokens; } }
309+
310+
public bool NeedsRemoveStopwordsTransform { get { return StopWordsRemover != null; } }
255311

256312
public bool NeedsNormalizeTransform
257313
{
@@ -297,7 +353,7 @@ public TransformApplierParams(TextFeaturizingEstimator parent)
297353
CharExtractorFactory = parent._charFeatureExtractor?.CreateComponent(host, parent._dictionary);
298354
Norm = parent.OptionalSettings.Norm;
299355
Language = parent.OptionalSettings.Language;
300-
UsePredefinedStopWordRemover = parent.OptionalSettings.UsePredefinedStopWordRemover;
356+
StopWordsRemover = parent._stopWordsRemover;
301357
TextCase = parent.OptionalSettings.CaseMode;
302358
KeepDiacritics = parent.OptionalSettings.KeepDiacritics;
303359
KeepPunctuations = parent.OptionalSettings.KeepPunctuations;
@@ -339,6 +395,7 @@ internal TextFeaturizingEstimator(IHostEnvironment env, string name, IEnumerable
339395
if (options != null)
340396
OptionalSettings = options;
341397

398+
_stopWordsRemover = null;
342399
_dictionary = null;
343400
_wordFeatureExtractor = OptionalSettings.WordFeatureExtractorFactory;
344401
_charFeatureExtractor = OptionalSettings.CharFeatureExtractorFactory;
@@ -401,21 +458,23 @@ public ITransformer Fit(IDataView input)
401458
view = new WordTokenizingEstimator(h, xfCols).Fit(view).Transform(view);
402459
}
403460

404-
if (tparams.UsePredefinedStopWordRemover)
461+
if (tparams.NeedsRemoveStopwordsTransform)
405462
{
406463
Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
407-
var xfCols = new StopWordsRemovingEstimator.ColumnOptions[wordTokCols.Length];
464+
var xfCols = new StopWordsCol[wordTokCols.Length];
408465
var dstCols = new string[wordTokCols.Length];
409466
for (int i = 0; i < wordTokCols.Length; i++)
410467
{
411-
var tempName = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform");
412-
var col = new StopWordsRemovingEstimator.ColumnOptions(tempName, wordTokCols[i], tparams.StopwordsLanguage);
413-
dstCols[i] = tempName;
414-
tempCols.Add(tempName);
468+
var col = new StopWordsCol();
469+
col.Source = wordTokCols[i];
470+
col.Name = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform");
471+
dstCols[i] = col.Name;
472+
tempCols.Add(col.Name);
473+
col.Language = tparams.StopwordsLanguage;
415474

416475
xfCols[i] = col;
417476
}
418-
view = new StopWordsRemovingEstimator(h, xfCols).Fit(view).Transform(view);
477+
view = tparams.StopWordsRemover.CreateComponent(h, view, xfCols);
419478
wordTokCols = dstCols;
420479
}
421480

@@ -442,7 +501,7 @@ public ITransformer Fit(IDataView input)
442501
if (tparams.CharExtractorFactory != null)
443502
{
444503
{
445-
var srcCols = tparams.UsePredefinedStopWordRemover ? wordTokCols : textCols;
504+
var srcCols = tparams.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
446505
charTokCols = new string[srcCols.Length];
447506
var xfCols = new (string outputColumnName, string inputColumnName)[srcCols.Length];
448507
for (int i = 0; i < srcCols.Length; i++)
@@ -567,6 +626,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
567626
internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView data)
568627
{
569628
var estimator = new TextFeaturizingEstimator(env, args.Columns.Name, args.Columns.Source ?? new[] { args.Columns.Name }, args);
629+
estimator._stopWordsRemover = args.StopWordsRemover;
570630
estimator._dictionary = args.Dictionary;
571631
// Review: I don't think the following two lines are needed.
572632
estimator._wordFeatureExtractor = args.WordFeatureExtractorFactory;

test/BaselineOutput/Common/EntryPoints/core_manifest.json

+7-4
Original file line numberDiff line numberDiff line change
@@ -22358,16 +22358,19 @@
2235822358
"Default": "English"
2235922359
},
2236022360
{
22361-
"Name": "UsePredefinedStopWordRemover",
22362-
"Type": "Bool",
22363-
"Desc": "Use stop remover or not.",
22361+
"Name": "StopWordsRemover",
22362+
"Type": {
22363+
"Kind": "Component",
22364+
"ComponentKind": "StopWordsRemover"
22365+
},
22366+
"Desc": "Stopwords remover.",
2236422367
"Aliases": [
2236522368
"remover"
2236622369
],
2236722370
"Required": false,
2236822371
"SortOrder": 4.0,
2236922372
"IsNullable": false,
22370-
"Default": false
22373+
"Default": null
2237122374
},
2237222375
{
2237322376
"Name": "TextCase",

test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public void TrainSentiment()
102102
{
103103
OutputTokens = true,
104104
KeepPunctuations = false,
105-
UsePredefinedStopWordRemover = true,
105+
StopWordsRemoverOptions = new StopWordsRemovingEstimator.Options(),
106106
Norm = TextFeaturizingEstimator.NormFunction.None,
107107
CharFeatureExtractor = null,
108108
WordFeatureExtractor = null,

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ public void EntryPointPipelineEnsembleText()
972972
{
973973
data = new TextFeaturizingEstimator(Env, "Features", new List<string> { "Text" },
974974
new TextFeaturizingEstimator.Options {
975-
UsePredefinedStopWordRemover = true,
975+
StopWordsRemoverOptions = new StopWordsRemovingEstimator.Options(),
976976
}).Fit(data).Transform(data);
977977
}
978978
else

0 commit comments

Comments
 (0)