diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs
index 423c69ee4b..e1ea2974b3 100644
--- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs
@@ -64,12 +64,16 @@ public sealed class Arguments : TransformInputBase
         public const string LoaderSignature = "CharToken";
         public const string UserName = "Character Tokenizer Transform";
 
+        // Keeps track of whether the model was saved with ver: 0x00010001.
+        private readonly bool _isSeparatorStartEnd;
+
         private static VersionInfo GetVersionInfo()
         {
             return new VersionInfo(
                 modelSignature: "CHARTOKN",
-                verWrittenCur: 0x00010001, // Initial
-                verReadableCur: 0x00010001,
+                //verWrittenCur: 0x00010001, // Initial
+                verWrittenCur: 0x00010002, // Updated to use the UnitSeparator <US> character instead of <STX>/<ETX> markers for vector inputs.
+                verReadableCur: 0x00010002,
                 verWeCanReadBack: 0x00010001,
                 loaderSignature: LoaderSignature);
         }
@@ -84,6 +88,7 @@ private static VersionInfo GetVersionInfo()
         private volatile string _keyValuesStr;
         private volatile int[] _keyValuesBoundaries;
 
+        private const ushort UnitSeparator = 0x1f;
         private const ushort TextStartMarker = 0x02;
         private const ushort TextEndMarker = 0x03;
         private const int TextMarkersCount = 2;
@@ -120,6 +125,8 @@ private CharTokenizeTransform(IHost host, ModelLoadContext ctx, IDataView input)
             // byte: _useMarkerChars value.
             _useMarkerChars = ctx.Reader.ReadBoolByte();
 
+            _isSeparatorStartEnd = ctx.Header.ModelVerReadable < 0x00010002 || ctx.Reader.ReadBoolByte();
+
             _type = GetOutputColumnType();
             SetMetadata();
         }
@@ -145,6 +152,7 @@ public override void Save(ModelSaveContext ctx)
             // byte: _useMarkerChars value.
             SaveBase(ctx);
             ctx.Writer.WriteBoolByte(_useMarkerChars);
+            ctx.Writer.WriteBoolByte(_isSeparatorStartEnd);
         }
 
         protected override ColumnType GetColumnTypeCore(int iinfo)
@@ -399,8 +407,8 @@ private ValueGetter<VBuffer<ushort>> MakeGetterVec(IRow input, int iinfo)
 
             var getSrc = GetSrcGetter<VBuffer<DvText>>(input, iinfo);
             var src = default(VBuffer<DvText>);
-            return
-                (ref VBuffer<ushort> dst) =>
+
+            ValueGetter<VBuffer<ushort>> getterWithStartEndSep = (ref VBuffer<ushort> dst) =>
                 {
                     getSrc(ref src);
 
@@ -438,6 +446,67 @@ private ValueGetter<VBuffer<ushort>> MakeGetterVec(IRow input, int iinfo)
 
                     dst = new VBuffer<ushort>(len, values, dst.Indices);
                 };
+
+            ValueGetter<VBuffer<ushort>> getterWithUnitSep = (ref VBuffer<ushort> dst) =>
+                {
+                    getSrc(ref src);
+
+                    int len = 0;
+
+                    for (int i = 0; i < src.Count; i++)
+                    {
+                        if (src.Values[i].HasChars)
+                        {
+                            len += src.Values[i].Length;
+
+                            if (i > 0)
+                                len += 1; // account for the UnitSeparator that will be inserted
+                        }
+                    }
+
+                    if (_useMarkerChars)
+                        len += TextMarkersCount;
+
+                    var values = dst.Values;
+                    if (len > 0)
+                    {
+                        if (Utils.Size(values) < len)
+                            values = new ushort[len];
+
+                        int index = 0;
+
+                        // The VBuffer can be the result of either concatenating text columns together
+                        // or applying the word tokenizer before the char tokenizer in TextTransform.
+                        //
+                        // Treat the VBuffer as a single text stream: prepend and append the start and
+                        // end markers only once, i.e. at the start and at the end of the vector,
+                        // and insert a UnitSeparator between every piece of text in the vector.
+                        if (_useMarkerChars)
+                            values[index++] = TextStartMarker;
+
+                        for (int i = 0; i < src.Count; i++)
+                        {
+                            if (!src.Values[i].HasChars)
+                                continue;
+
+                            if (i > 0)
+                                values[index++] = UnitSeparator;
+
+                            for (int ich = 0; ich < src.Values[i].Length; ich++)
+                            {
+                                values[index++] = src.Values[i][ich];
+                            }
+                        }
+
+                        if (_useMarkerChars)
+                            values[index++] = TextEndMarker;
+
+                        Contracts.Assert(index == len);
+                    }
+
+                    dst = new VBuffer<ushort>(len, values, dst.Indices);
+                };
+            return _isSeparatorStartEnd ? getterWithStartEndSep : getterWithUnitSep;
         }
     }
 }
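For models saved with the new version, the getter added above treats a vector of text as a single stream: the optional start/end markers (0x02/0x03) are emitted once for the whole vector instead of once per slot, and a UnitSeparator (0x1f) is placed between the characters of adjacent slots, while models saved with ver: 0x00010001 keep the old behavior through the `_isSeparatorStartEnd` flag. The following is a minimal standalone sketch of that key sequence, not the transform's code; the `CharTokenizeSketch` class and `Tokenize` method are invented for illustration.

```csharp
using System;
using System.Collections.Generic;

// Minimal standalone sketch (not the transform's code) of how the new getter emits
// character keys for a vector of text slots: start/end markers once per vector,
// a UnitSeparator between slots, empty slots skipped.
internal static class CharTokenizeSketch
{
    private const ushort UnitSeparator = 0x1f;
    private const ushort TextStartMarker = 0x02;
    private const ushort TextEndMarker = 0x03;

    public static ushort[] Tokenize(string[] slots, bool useMarkerChars)
    {
        var keys = new List<ushort>();
        if (useMarkerChars)
            keys.Add(TextStartMarker);

        for (int i = 0; i < slots.Length; i++)
        {
            if (string.IsNullOrEmpty(slots[i]))
                continue;

            // Mirrors the diff: a separator is placed in front of every slot after the first.
            if (i > 0)
                keys.Add(UnitSeparator);

            foreach (char ch in slots[i])
                keys.Add(ch); // char -> ushort is an implicit conversion
        }

        if (useMarkerChars)
            keys.Add(TextEndMarker);

        return keys.ToArray();
    }

    private static void Main()
    {
        // A word-tokenized input ["hi", "ml"] becomes: 0x02 'h' 'i' 0x1f 'm' 'l' 0x03.
        Console.WriteLine(string.Join(" ", Tokenize(new[] { "hi", "ml" }, useMarkerChars: true)));
        // Prints: 2 104 105 31 109 108 3
    }
}
```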
diff --git a/src/Microsoft.ML.Transforms/Text/TextTransform.cs b/src/Microsoft.ML.Transforms/Text/TextTransform.cs
index c839e78556..3f13dd7612 100644
--- a/src/Microsoft.ML.Transforms/Text/TextTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextTransform.cs
@@ -262,6 +262,30 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
                 view = new ConcatTransform(h, new ConcatTransform.Arguments() { Column = xfCols }, view);
             }
 
+            if (tparams.NeedsNormalizeTransform)
+            {
+                var xfCols = new TextNormalizerCol[textCols.Length];
+                string[] dstCols = new string[textCols.Length];
+                for (int i = 0; i < textCols.Length; i++)
+                {
+                    dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer");
+                    tempCols.Add(dstCols[i]);
+                    xfCols[i] = new TextNormalizerCol() { Source = textCols[i], Name = dstCols[i] };
+                }
+
+                view = new TextNormalizerTransform(h,
+                    new TextNormalizerArgs()
+                    {
+                        Column = xfCols,
+                        KeepDiacritics = tparams.KeepDiacritics,
+                        KeepNumbers = tparams.KeepNumbers,
+                        KeepPunctuations = tparams.KeepPunctuations,
+                        TextCase = tparams.TextCase
+                    }, view);
+
+                textCols = dstCols;
+            }
+
             if (tparams.NeedsWordTokenizationTransform)
             {
                 var xfCols = new DelimitedTokenizeTransform.Column[textCols.Length];
@@ -281,34 +305,6 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
                 view = new DelimitedTokenizeTransform(h, new DelimitedTokenizeTransform.Arguments() { Column = xfCols }, view);
             }
 
-            if (tparams.NeedsNormalizeTransform)
-            {
-                string[] srcCols = wordTokCols == null ? textCols : wordTokCols;
-                var xfCols = new TextNormalizerCol[srcCols.Length];
-                string[] dstCols = new string[srcCols.Length];
-                for (int i = 0; i < srcCols.Length; i++)
-                {
-                    dstCols[i] = GenerateColumnName(view.Schema, srcCols[i], "TextNormalizer");
-                    tempCols.Add(dstCols[i]);
-                    xfCols[i] = new TextNormalizerCol() { Source = srcCols[i], Name = dstCols[i] };
-                }
-
-                view = new TextNormalizerTransform(h,
-                    new TextNormalizerArgs()
-                    {
-                        Column = xfCols,
-                        KeepDiacritics = tparams.KeepDiacritics,
-                        KeepNumbers = tparams.KeepNumbers,
-                        KeepPunctuations = tparams.KeepPunctuations,
-                        TextCase = tparams.TextCase
-                    }, view);
-
-                if (wordTokCols != null)
-                    wordTokCols = dstCols;
-                else
-                    textCols = dstCols;
-            }
-
             if (tparams.NeedsRemoveStopwordsTransform)
             {
                 Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
@@ -360,7 +356,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
             if (tparams.CharExtractorFactory != null)
             {
                 {
-                    var srcCols = wordTokCols ?? textCols;
+                    var srcCols = tparams.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
                     charTokCols = new string[srcCols.Length];
                    var xfCols = new CharTokenizeTransform.Column[srcCols.Length];
                    for (int i = 0; i < srcCols.Length; i++)
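The net effect of the hunks above is a reordering of the pipeline that TextTransform.Create builds: text normalization now runs first on the raw text columns, word tokenization consumes the normalized text, optional stop-word removal follows, and the character tokenizer reads the word-token columns only when stop-word removal ran, otherwise the normalized text columns (the `srcCols` change). The fragment below is a hypothetical stand-in for that ordering using plain string operations; it is not the ML.NET transforms, and the `TextPipelineOrderSketch` class and its steps are invented for illustration.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in (not the ML.NET transforms) that only mirrors the stage order
// established above: normalize -> word-tokenize -> remove stop words -> char-tokenize,
// with the char stage reading word tokens only after stop-word removal.
internal static class TextPipelineOrderSketch
{
    private static void Main()
    {
        string text = "The QUICK, brown fox!";
        string[] stopWords = { "the", "a", "an" };
        bool removeStopWords = true;

        // 1. TextNormalizerTransform now runs first, on the raw text column.
        string normalized = new string(text.ToLowerInvariant()
            .Where(c => !char.IsPunctuation(c)).ToArray());

        // 2. DelimitedTokenizeTransform splits the already-normalized text.
        string[] words = normalized.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);

        // 3. Optional stop-word removal filters the word tokens.
        string[] kept = removeStopWords ? words.Where(w => !stopWords.Contains(w)).ToArray() : words;

        // 4. CharTokenizeTransform reads the word tokens when stop words were removed,
        //    otherwise the normalized text column (the srcCols change above); joining
        //    with U+001F stands in for the separator the char tokenizer now inserts.
        string charSource = removeStopWords ? string.Join("\u001f", kept) : normalized;
        ushort[] keys = charSource.Select(c => (ushort)c).ToArray();

        Console.WriteLine(string.Join(" ", keys));
    }
}
```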
diff --git a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
index 03fa8dfe29..d11811df6b 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
@@ -308,7 +308,7 @@ public void PipelineSweeperRoles()
             var trainAuc = bestPipeline.PerformanceSummary.TrainingMetricValue;
             var testAuc = bestPipeline.PerformanceSummary.MetricValue;
             Assert.True((0.94 < trainAuc) && (trainAuc < 0.95));
-            Assert.True((0.83 < testAuc) && (testAuc < 0.84));
+            Assert.True((0.815 < testAuc) && (testAuc < 0.825));
 
             var results = runner.GetOutput("ResultsOut");
             Assert.NotNull(results);
diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
index b104570ca7..9c4283df45 100644
--- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
@@ -75,47 +75,47 @@ public void CrossValidateSentimentModelTest()
 
             //Avergae of all folds.
             var metrics = cv.BinaryClassificationMetrics[0];
-            Assert.Equal(0.57023626091422708, metrics.Accuracy, 4);
-            Assert.Equal(0.54960689910161487, metrics.Auc, 1);
-            Assert.Equal(0.67048277219704255, metrics.Auprc, 2);
+            Assert.Equal(0.603235747303544, metrics.Accuracy, 4);
+            Assert.Equal(0.58811318075483943, metrics.Auc, 4);
+            Assert.Equal(0.70302385499183984, metrics.Auprc, 4);
             Assert.Equal(0, metrics.Entropy, 3);
-            Assert.Equal(0.68942642723130532, metrics.F1Score, 4);
-            Assert.Equal(0.97695909611968434, metrics.LogLoss, 3);
-            Assert.Equal(-3.050726259114541, metrics.LogLossReduction, 3);
-            Assert.Equal(0.37553879310344829, metrics.NegativePrecision, 3);
-            Assert.Equal(0.25683962264150945, metrics.NegativeRecall, 3);
-            Assert.Equal(0.63428539173628362, metrics.PositivePrecision, 3);
-            Assert.Equal(0.75795196364816619, metrics.PositiveRecall);
+            Assert.Equal(0.71751777634130576, metrics.F1Score, 4);
+            Assert.Equal(0.95263103280238037, metrics.LogLoss, 4);
+            Assert.Equal(-0.39971801589876232, metrics.LogLossReduction, 4);
+            Assert.Equal(0.43965517241379309, metrics.NegativePrecision, 4);
+            Assert.Equal(0.26627358490566039, metrics.NegativeRecall, 4);
+            Assert.Equal(0.64937737441958632, metrics.PositivePrecision, 4);
+            Assert.Equal(0.8027426160337553, metrics.PositiveRecall);
             Assert.Null(metrics.ConfusionMatrix);
 
             //Std. Deviation.
             metrics = cv.BinaryClassificationMetrics[1];
-            Assert.Equal(0.039933230611196011, metrics.Accuracy, 4);
-            Assert.Equal(0.021066177821462407, metrics.Auc, 1);
-            Assert.Equal(0.045842033921572725, metrics.Auprc, 2);
+            Assert.Equal(0.057781201848998764, metrics.Accuracy, 4);
+            Assert.Equal(0.04249579360413544, metrics.Auc, 4);
+            Assert.Equal(0.086083866074815427, metrics.Auprc, 4);
             Assert.Equal(0, metrics.Entropy, 3);
-            Assert.Equal(0.030085767890644915, metrics.F1Score, 4);
-            Assert.Equal(0.032906777175141941, metrics.LogLoss, 3);
-            Assert.Equal(0.86311349745170118, metrics.LogLossReduction, 3);
-            Assert.Equal(0.030711206896551647, metrics.NegativePrecision, 3);
-            Assert.Equal(0.068160377358490579, metrics.NegativeRecall, 3);
-            Assert.Equal(0.051761119891622735, metrics.PositivePrecision, 3);
-            Assert.Equal(0.0015417072379052127, metrics.PositiveRecall);
+            Assert.Equal(0.04718810601163604, metrics.F1Score, 4);
+            Assert.Equal(0.063839715206238851, metrics.LogLoss, 4);
+            Assert.Equal(4.1937544629633878, metrics.LogLossReduction, 4);
+            Assert.Equal(0.060344827586206781, metrics.NegativePrecision, 4);
+            Assert.Equal(0.058726415094339748, metrics.NegativeRecall, 4);
+            Assert.Equal(0.057144364710848418, metrics.PositivePrecision, 4);
+            Assert.Equal(0.030590717299577637, metrics.PositiveRecall);
             Assert.Null(metrics.ConfusionMatrix);
 
             //Fold 1.
             metrics = cv.BinaryClassificationMetrics[2];
-            Assert.Equal(0.53030303030303028, metrics.Accuracy, 4);
-            Assert.Equal(0.52854072128015284, metrics.Auc, 1);
-            Assert.Equal(0.62464073827546951, metrics.Auprc, 2);
+            Assert.Equal(0.54545454545454541, metrics.Accuracy, 4);
+            Assert.Equal(0.54561738715070451, metrics.Auc, 4);
+            Assert.Equal(0.61693998891702417, metrics.Auprc, 4);
             Assert.Equal(0, metrics.Entropy, 3);
-            Assert.Equal(0.65934065934065933, metrics.F1Score, 4);
-            Assert.Equal(1.0098658732948276, metrics.LogLoss, 3);
-            Assert.Equal(-3.9138397565662424, metrics.LogLossReduction, 3);
-            Assert.Equal(0.34482758620689657, metrics.NegativePrecision, 3);
-            Assert.Equal(0.18867924528301888, metrics.NegativeRecall, 3);
-            Assert.Equal(0.58252427184466016, metrics.PositivePrecision, 3);
-            Assert.Equal(0.759493670886076, metrics.PositiveRecall);
+            Assert.Equal(0.67032967032967028, metrics.F1Score, 4);
+            Assert.Equal(1.0164707480086188, metrics.LogLoss, 4);
+            Assert.Equal(-4.59347247886215, metrics.LogLossReduction, 4);
+            Assert.Equal(0.37931034482758619, metrics.NegativePrecision, 4);
+            Assert.Equal(0.20754716981132076, metrics.NegativeRecall, 4);
+            Assert.Equal(0.59223300970873782, metrics.PositivePrecision, 4);
+            Assert.Equal(0.77215189873417722, metrics.PositiveRecall);
 
             var matrix = metrics.ConfusionMatrix;
             Assert.Equal(2, matrix.Order);
@@ -123,29 +123,29 @@ public void CrossValidateSentimentModelTest()
             Assert.Equal("positive", matrix.ClassNames[0]);
             Assert.Equal("negative", matrix.ClassNames[1]);
 
-            Assert.Equal(60, matrix[0, 0]);
-            Assert.Equal(60, matrix["positive", "positive"]);
-            Assert.Equal(19, matrix[0, 1]);
-            Assert.Equal(19, matrix["positive", "negative"]);
+            Assert.Equal(61, matrix[0, 0]);
+            Assert.Equal(61, matrix["positive", "positive"]);
+            Assert.Equal(18, matrix[0, 1]);
+            Assert.Equal(18, matrix["positive", "negative"]);
 
-            Assert.Equal(43, matrix[1, 0]);
-            Assert.Equal(43, matrix["negative", "positive"]);
-            Assert.Equal(10, matrix[1, 1]);
-            Assert.Equal(10, matrix["negative", "negative"]);
+            Assert.Equal(42, matrix[1, 0]);
+            Assert.Equal(42, matrix["negative", "positive"]);
+            Assert.Equal(11, matrix[1, 1]);
+            Assert.Equal(11, matrix["negative", "negative"]);
 
             //Fold 2.
             metrics = cv.BinaryClassificationMetrics[3];
-            Assert.Equal(0.61016949152542377, metrics.Accuracy, 4);
-            Assert.Equal(0.57067307692307689, metrics.Auc, 1);
-            Assert.Equal(0.71632480611861549, metrics.Auprc, 2);
+            Assert.Equal(0.66101694915254239, metrics.Accuracy, 4);
+            Assert.Equal(0.63060897435897434, metrics.Auc, 4);
+            Assert.Equal(0.7891077210666555, metrics.Auprc, 4);
             Assert.Equal(0, metrics.Entropy, 3);
-            Assert.Equal(0.71951219512195119, metrics.F1Score, 4);
-            Assert.Equal(0.94405231894454111, metrics.LogLoss, 3);
-            Assert.Equal(-2.1876127616628396, metrics.LogLossReduction, 3);
-            Assert.Equal(0.40625, metrics.NegativePrecision, 3);
+            Assert.Equal(0.76470588235294124, metrics.F1Score, 4);
+            Assert.Equal(0.88879131759614194, metrics.LogLoss, 4);
+            Assert.Equal(3.7940364470646255, metrics.LogLossReduction, 4);
+            Assert.Equal(0.5, metrics.NegativePrecision, 3);
             Assert.Equal(0.325, metrics.NegativeRecall, 3);
-            Assert.Equal(0.686046511627907, metrics.PositivePrecision, 3);
-            Assert.Equal(0.75641025641025639, metrics.PositiveRecall);
+            Assert.Equal(0.70652173913043481, metrics.PositivePrecision, 4);
+            Assert.Equal(0.83333333333333337, metrics.PositiveRecall);
 
             matrix = metrics.ConfusionMatrix;
             Assert.Equal(2, matrix.Order);
@@ -153,10 +153,10 @@ public void CrossValidateSentimentModelTest()
             Assert.Equal("positive", matrix.ClassNames[0]);
             Assert.Equal("negative", matrix.ClassNames[1]);
 
-            Assert.Equal(59, matrix[0, 0]);
-            Assert.Equal(59, matrix["positive", "positive"]);
-            Assert.Equal(19, matrix[0, 1]);
-            Assert.Equal(19, matrix["positive", "negative"]);
+            Assert.Equal(65, matrix[0, 0]);
+            Assert.Equal(65, matrix["positive", "positive"]);
+            Assert.Equal(13, matrix[0, 1]);
+            Assert.Equal(13, matrix["positive", "negative"]);
 
             Assert.Equal(27, matrix[1, 0]);
             Assert.Equal(27, matrix["negative", "positive"]);
@@ -180,11 +180,11 @@ private void ValidateBinaryMetricsLightGBM(BinaryClassificationMetrics metrics)
 
             Assert.Equal(.6111, metrics.Accuracy, 4);
             Assert.Equal(.8, metrics.Auc, 1);
-            Assert.Equal(.85, metrics.Auprc, 2);
+            Assert.Equal(0.88, metrics.Auprc, 2);
             Assert.Equal(1, metrics.Entropy, 3);
             Assert.Equal(.72, metrics.F1Score, 4);
-            Assert.Equal(.952, metrics.LogLoss, 3);
-            Assert.Equal(4.777, metrics.LogLossReduction, 3);
+            Assert.Equal(0.96456100297125325, metrics.LogLoss, 4);
+            Assert.Equal(3.5438997028746755, metrics.LogLossReduction, 4);
             Assert.Equal(1, metrics.NegativePrecision, 3);
             Assert.Equal(.222, metrics.NegativeRecall, 3);
             Assert.Equal(.562, metrics.PositivePrecision, 3);
@@ -211,16 +211,16 @@ private void ValidateBinaryMetricsLightGBM(BinaryClassificationMetrics metrics)
 
         private void ValidateBinaryMetrics(BinaryClassificationMetrics metrics)
         {
-            Assert.Equal(.5556, metrics.Accuracy, 4);
-            Assert.Equal(.8, metrics.Auc, 1);
-            Assert.Equal(.87, metrics.Auprc, 2);
+            Assert.Equal(0.6111, metrics.Accuracy, 4);
+            Assert.Equal(0.6667, metrics.Auc, 4);
+            Assert.Equal(0.8621, metrics.Auprc, 4);
             Assert.Equal(1, metrics.Entropy, 3);
-            Assert.Equal(.6923, metrics.F1Score, 4);
-            Assert.Equal(.969, metrics.LogLoss, 3);
-            Assert.Equal(3.083, metrics.LogLossReduction, 3);
-            Assert.Equal(1, metrics.NegativePrecision, 3);
-            Assert.Equal(.111, metrics.NegativeRecall, 3);
-            Assert.Equal(.529, metrics.PositivePrecision, 3);
+            Assert.Equal(0.72, metrics.F1Score, 2);
+            Assert.Equal(0.9689, metrics.LogLoss, 4);
+            Assert.Equal(3.1122, metrics.LogLossReduction, 4);
+            Assert.Equal(1, metrics.NegativePrecision, 1);
+            Assert.Equal(0.2222, metrics.NegativeRecall, 4);
+            Assert.Equal(0.5625, metrics.PositivePrecision, 4);
             Assert.Equal(1, metrics.PositiveRecall);
 
             var matrix = metrics.ConfusionMatrix;
@@ -234,10 +234,10 @@ private void ValidateBinaryMetrics(BinaryClassificationMetrics metrics)
             Assert.Equal(0, matrix[0, 1]);
             Assert.Equal(0, matrix["positive", "negative"]);
 
-            Assert.Equal(8, matrix[1, 0]);
-            Assert.Equal(8, matrix["negative", "positive"]);
-            Assert.Equal(1, matrix[1, 1]);
-            Assert.Equal(1, matrix["negative", "negative"]);
+            Assert.Equal(7, matrix[1, 0]);
+            Assert.Equal(7, matrix["negative", "positive"]);
+            Assert.Equal(2, matrix[1, 1]);
+            Assert.Equal(2, matrix["negative", "negative"]);
         }
 
         private LearningPipeline PreparePipeline()
@@ -344,7 +344,7 @@ private void ValidateExamples(PredictionModel