diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index e0f32bcaa1..6604961c57 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -31,6 +31,7 @@ private class TestClass { public string A; public string[] OutputTokens; + public float[] Features = null; } [Fact] @@ -41,7 +42,7 @@ public void TextFeaturizerWithPredefinedStopWordRemoverTest() var dataView = ML.Data.LoadFromEnumerable(data); var options = new TextFeaturizingEstimator.Options() { StopWordsRemoverOptions = new StopWordsRemovingEstimator.Options(), OutputTokensColumnName = "OutputTokens" }; - var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); var model = pipeline.Fit(dataView); var engine = model.CreatePredictionEngine(ML); var prediction = engine.Predict(data[0]); @@ -51,6 +52,95 @@ public void TextFeaturizerWithPredefinedStopWordRemoverTest() Assert.Equal("stop words", string.Join(" ", prediction.OutputTokens)); } + [Fact] + public void TextFeaturizerWithWordFeatureExtractorTest() + { + var data = new[] { new TestClass() { A = "This is some text in english", OutputTokens=null}, + new TestClass() { A = "This is another example", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + var options = new TextFeaturizingEstimator.Options() + { + WordFeatureExtractor = new WordBagEstimator.Options() { NgramLength = 1 }, + CharFeatureExtractor = null, + Norm = TextFeaturizingEstimator.NormFunction.None, + OutputTokensColumnName = "OutputTokens" + }; + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); + var model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + + var prediction = engine.Predict(data[0]); + Assert.Equal(data[0].A.ToLower(), string.Join(" ", prediction.OutputTokens)); + var expected = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f }; + Assert.Equal(expected, prediction.Features); + + prediction = engine.Predict(data[1]); + Assert.Equal(data[1].A.ToLower(), string.Join(" ", prediction.OutputTokens)); + expected = new float[] { 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f }; + Assert.Equal(expected, prediction.Features); + } + + [Fact] + public void TextFeaturizerWithCharFeatureExtractorTest() + { + var data = new[] { new TestClass() { A = "abc efg", OutputTokens=null}, + new TestClass() { A = "xyz", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + var options = new TextFeaturizingEstimator.Options() + { + WordFeatureExtractor = null, + CharFeatureExtractor = new WordBagEstimator.Options() { NgramLength = 1 }, + Norm = TextFeaturizingEstimator.NormFunction.None, + OutputTokensColumnName = "OutputTokens" + }; + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); + var model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + + var prediction = engine.Predict(data[0]); + Assert.Equal(data[0].A, string.Join(" ", prediction.OutputTokens)); + var expected = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f }; + Assert.Equal(expected, prediction.Features); + + prediction = engine.Predict(data[1]); + Assert.Equal(data[1].A, string.Join(" ", prediction.OutputTokens)); + expected = new float[] { 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f }; + Assert.Equal(expected, prediction.Features); + } + + [Fact] + public void TextFeaturizerWithL2NormTest() + { + var data = new[] { new TestClass() { A = "abc xyz", OutputTokens=null}, + new TestClass() { A = "xyz", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + var options = new TextFeaturizingEstimator.Options() + { + CharFeatureExtractor = new WordBagEstimator.Options() { NgramLength = 1}, + Norm = TextFeaturizingEstimator.NormFunction.L2, + OutputTokensColumnName = "OutputTokens" + }; + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); + var model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + + var prediction = engine.Predict(data[0]); + Assert.Equal(data[0].A, string.Join(" ", prediction.OutputTokens)); + var exp1 = 0.333333343f; + var exp2 = 0.707106769f; + var expected = new float[] { exp1, exp1, exp1, exp1, exp1, exp1, exp1, exp1, exp1, exp2, exp2 }; + Assert.Equal(expected, prediction.Features); + + prediction = engine.Predict(data[1]); + exp1 = 0.4472136f; + Assert.Equal(data[1].A, string.Join(" ", prediction.OutputTokens)); + expected = new float[] { exp1, 0.0f, 0.0f, 0.0f, 0.0f, exp1, exp1, exp1, exp1, 0.0f, 1.0f }; + Assert.Equal(expected, prediction.Features); + } + [Fact] public void TextFeaturizerWithCustomStopWordRemoverTest() { @@ -67,7 +157,7 @@ public void TextFeaturizerWithCustomStopWordRemoverTest() OutputTokensColumnName = "OutputTokens", CaseMode = TextNormalizingEstimator.CaseMode.None }; - var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); var model = pipeline.Fit(dataView); var engine = model.CreatePredictionEngine(ML); var prediction = engine.Predict(data[0]); @@ -84,7 +174,7 @@ private void TestCaseMode(IDataView dataView, TestClass[] data, TextNormalizingE CaseMode = caseMode, OutputTokensColumnName = "OutputTokens" }; - var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); var model = pipeline.Fit(dataView); var engine = model.CreatePredictionEngine(ML); var prediction1 = engine.Predict(data[0]); @@ -133,7 +223,7 @@ private void TestKeepNumbers(IDataView dataView, TestClass[] data, bool keepNumb CaseMode = TextNormalizingEstimator.CaseMode.None, OutputTokensColumnName = "OutputTokens" }; - var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); var model = pipeline.Fit(dataView); var engine = model.CreatePredictionEngine(ML); var prediction1 = engine.Predict(data[0]); @@ -170,7 +260,7 @@ private void TestKeepPunctuations(IDataView dataView, TestClass[] data, bool kee CaseMode = TextNormalizingEstimator.CaseMode.None, OutputTokensColumnName = "OutputTokens" }; - var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); var model = pipeline.Fit(dataView); var engine = model.CreatePredictionEngine(ML); var prediction1 = engine.Predict(data[0]); @@ -208,7 +298,7 @@ private void TestKeepDiacritics(IDataView dataView, TestClass[] data, bool keepD CaseMode = TextNormalizingEstimator.CaseMode.None, OutputTokensColumnName = "OutputTokens" }; - var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A"); var model = pipeline.Fit(dataView); var engine = model.CreatePredictionEngine(ML); var prediction1 = engine.Predict(data[0]);