From 868a72f25d58de022f038429cf68cf8a6e14413e Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Tue, 17 Jul 2018 14:54:54 -0700 Subject: [PATCH 1/6] Fixed the TextTransform bug where chargrams were being computed differently for different settings. --- .../Text/CharTokenizeTransform.cs | 33 ++++- .../TestAutoInference.cs | 2 +- .../Scenarios/SentimentPredictionTests.cs | 136 +++++++++--------- .../SentimentPredictionTests.cs | 2 +- 4 files changed, 97 insertions(+), 76 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs index 423c69ee4b..331537564c 100644 --- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs @@ -405,16 +405,21 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) getSrc(ref src); int len = 0; + for (int i = 0; i < src.Count; i++) { if (src.Values[i].HasChars) { len += src.Values[i].Length; - if (_useMarkerChars) - len += TextMarkersCount; + + if (i > 0) + len += 1; // add space character that will be added } } + if (_useMarkerChars) + len += TextMarkersCount; + var values = dst.Values; if (len > 0) { @@ -422,17 +427,33 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) values = new ushort[len]; int index = 0; + + // VBuffer can be a result of either concatenating text columns together + // or application of word tokenizer before char tokenizer. + // + // Considering VBuffer as a single text stream. + // Therefore, prepend and append start and end markers only once i.e. at the start and at end of vector. + // Insert spaces after every piece of text in the vector. 
+ if (_useMarkerChars) + values[index++] = TextStartMarker; + for (int i = 0; i < src.Count; i++) { if (!src.Values[i].HasChars) continue; - if (_useMarkerChars) - values[index++] = TextStartMarker; + + if (i > 0) + values[index++] = ' '; + for (int ich = 0; ich < src.Values[i].Length; ich++) + { values[index++] = src.Values[i][ich]; - if (_useMarkerChars) - values[index++] = TextEndMarker; + } } + + if (_useMarkerChars) + values[index++] = TextEndMarker; + Contracts.Assert(index == len); } diff --git a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs index 4d8ec880d1..b44492cb9b 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs @@ -352,7 +352,7 @@ public void EntryPointPipelineSweepRoles() var trainAuc = bestPipeline.PerformanceSummary.TrainingMetricValue; var testAuc = bestPipeline.PerformanceSummary.MetricValue; Assert.True((0.94 < trainAuc) && (trainAuc < 0.95)); - Assert.True((0.83 < testAuc) && (testAuc < 0.84)); + Assert.True((0.815 < testAuc) && (testAuc < 0.825)); var results = runner.GetOutput("ResultsOut"); Assert.NotNull(results); diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index b104570ca7..9c4283df45 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -75,47 +75,47 @@ public void CrossValidateSentimentModelTest() //Avergae of all folds. 
var metrics = cv.BinaryClassificationMetrics[0]; - Assert.Equal(0.57023626091422708, metrics.Accuracy, 4); - Assert.Equal(0.54960689910161487, metrics.Auc, 1); - Assert.Equal(0.67048277219704255, metrics.Auprc, 2); + Assert.Equal(0.603235747303544, metrics.Accuracy, 4); + Assert.Equal(0.58811318075483943, metrics.Auc, 4); + Assert.Equal(0.70302385499183984, metrics.Auprc, 4); Assert.Equal(0, metrics.Entropy, 3); - Assert.Equal(0.68942642723130532, metrics.F1Score, 4); - Assert.Equal(0.97695909611968434, metrics.LogLoss, 3); - Assert.Equal(-3.050726259114541, metrics.LogLossReduction, 3); - Assert.Equal(0.37553879310344829, metrics.NegativePrecision, 3); - Assert.Equal(0.25683962264150945, metrics.NegativeRecall, 3); - Assert.Equal(0.63428539173628362, metrics.PositivePrecision, 3); - Assert.Equal(0.75795196364816619, metrics.PositiveRecall); + Assert.Equal(0.71751777634130576, metrics.F1Score, 4); + Assert.Equal(0.95263103280238037, metrics.LogLoss, 4); + Assert.Equal(-0.39971801589876232, metrics.LogLossReduction, 4); + Assert.Equal(0.43965517241379309, metrics.NegativePrecision, 4); + Assert.Equal(0.26627358490566039, metrics.NegativeRecall, 4); + Assert.Equal(0.64937737441958632, metrics.PositivePrecision, 4); + Assert.Equal(0.8027426160337553, metrics.PositiveRecall); Assert.Null(metrics.ConfusionMatrix); //Std. Deviation. 
metrics = cv.BinaryClassificationMetrics[1]; - Assert.Equal(0.039933230611196011, metrics.Accuracy, 4); - Assert.Equal(0.021066177821462407, metrics.Auc, 1); - Assert.Equal(0.045842033921572725, metrics.Auprc, 2); + Assert.Equal(0.057781201848998764, metrics.Accuracy, 4); + Assert.Equal(0.04249579360413544, metrics.Auc, 4); + Assert.Equal(0.086083866074815427, metrics.Auprc, 4); Assert.Equal(0, metrics.Entropy, 3); - Assert.Equal(0.030085767890644915, metrics.F1Score, 4); - Assert.Equal(0.032906777175141941, metrics.LogLoss, 3); - Assert.Equal(0.86311349745170118, metrics.LogLossReduction, 3); - Assert.Equal(0.030711206896551647, metrics.NegativePrecision, 3); - Assert.Equal(0.068160377358490579, metrics.NegativeRecall, 3); - Assert.Equal(0.051761119891622735, metrics.PositivePrecision, 3); - Assert.Equal(0.0015417072379052127, metrics.PositiveRecall); + Assert.Equal(0.04718810601163604, metrics.F1Score, 4); + Assert.Equal(0.063839715206238851, metrics.LogLoss, 4); + Assert.Equal(4.1937544629633878, metrics.LogLossReduction, 4); + Assert.Equal(0.060344827586206781, metrics.NegativePrecision, 4); + Assert.Equal(0.058726415094339748, metrics.NegativeRecall, 4); + Assert.Equal(0.057144364710848418, metrics.PositivePrecision, 4); + Assert.Equal(0.030590717299577637, metrics.PositiveRecall); Assert.Null(metrics.ConfusionMatrix); //Fold 1. 
metrics = cv.BinaryClassificationMetrics[2]; - Assert.Equal(0.53030303030303028, metrics.Accuracy, 4); - Assert.Equal(0.52854072128015284, metrics.Auc, 1); - Assert.Equal(0.62464073827546951, metrics.Auprc, 2); + Assert.Equal(0.54545454545454541, metrics.Accuracy, 4); + Assert.Equal(0.54561738715070451, metrics.Auc, 4); + Assert.Equal(0.61693998891702417, metrics.Auprc, 4); Assert.Equal(0, metrics.Entropy, 3); - Assert.Equal(0.65934065934065933, metrics.F1Score, 4); - Assert.Equal(1.0098658732948276, metrics.LogLoss, 3); - Assert.Equal(-3.9138397565662424, metrics.LogLossReduction, 3); - Assert.Equal(0.34482758620689657, metrics.NegativePrecision, 3); - Assert.Equal(0.18867924528301888, metrics.NegativeRecall, 3); - Assert.Equal(0.58252427184466016, metrics.PositivePrecision, 3); - Assert.Equal(0.759493670886076, metrics.PositiveRecall); + Assert.Equal(0.67032967032967028, metrics.F1Score, 4); + Assert.Equal(1.0164707480086188, metrics.LogLoss, 4); + Assert.Equal(-4.59347247886215, metrics.LogLossReduction, 4); + Assert.Equal(0.37931034482758619, metrics.NegativePrecision, 4); + Assert.Equal(0.20754716981132076, metrics.NegativeRecall, 4); + Assert.Equal(0.59223300970873782, metrics.PositivePrecision, 4); + Assert.Equal(0.77215189873417722, metrics.PositiveRecall); var matrix = metrics.ConfusionMatrix; Assert.Equal(2, matrix.Order); @@ -123,29 +123,29 @@ public void CrossValidateSentimentModelTest() Assert.Equal("positive", matrix.ClassNames[0]); Assert.Equal("negative", matrix.ClassNames[1]); - Assert.Equal(60, matrix[0, 0]); - Assert.Equal(60, matrix["positive", "positive"]); - Assert.Equal(19, matrix[0, 1]); - Assert.Equal(19, matrix["positive", "negative"]); + Assert.Equal(61, matrix[0, 0]); + Assert.Equal(61, matrix["positive", "positive"]); + Assert.Equal(18, matrix[0, 1]); + Assert.Equal(18, matrix["positive", "negative"]); - Assert.Equal(43, matrix[1, 0]); - Assert.Equal(43, matrix["negative", "positive"]); - Assert.Equal(10, matrix[1, 1]); - 
Assert.Equal(10, matrix["negative", "negative"]); + Assert.Equal(42, matrix[1, 0]); + Assert.Equal(42, matrix["negative", "positive"]); + Assert.Equal(11, matrix[1, 1]); + Assert.Equal(11, matrix["negative", "negative"]); //Fold 2. metrics = cv.BinaryClassificationMetrics[3]; - Assert.Equal(0.61016949152542377, metrics.Accuracy, 4); - Assert.Equal(0.57067307692307689, metrics.Auc, 1); - Assert.Equal(0.71632480611861549, metrics.Auprc, 2); + Assert.Equal(0.66101694915254239, metrics.Accuracy, 4); + Assert.Equal(0.63060897435897434, metrics.Auc, 4); + Assert.Equal(0.7891077210666555, metrics.Auprc, 4); Assert.Equal(0, metrics.Entropy, 3); - Assert.Equal(0.71951219512195119, metrics.F1Score, 4); - Assert.Equal(0.94405231894454111, metrics.LogLoss, 3); - Assert.Equal(-2.1876127616628396, metrics.LogLossReduction, 3); - Assert.Equal(0.40625, metrics.NegativePrecision, 3); + Assert.Equal(0.76470588235294124, metrics.F1Score, 4); + Assert.Equal(0.88879131759614194, metrics.LogLoss, 4); + Assert.Equal(3.7940364470646255, metrics.LogLossReduction, 4); + Assert.Equal(0.5, metrics.NegativePrecision, 3); Assert.Equal(0.325, metrics.NegativeRecall, 3); - Assert.Equal(0.686046511627907, metrics.PositivePrecision, 3); - Assert.Equal(0.75641025641025639, metrics.PositiveRecall); + Assert.Equal(0.70652173913043481, metrics.PositivePrecision, 4); + Assert.Equal(0.83333333333333337, metrics.PositiveRecall); matrix = metrics.ConfusionMatrix; Assert.Equal(2, matrix.Order); @@ -153,10 +153,10 @@ public void CrossValidateSentimentModelTest() Assert.Equal("positive", matrix.ClassNames[0]); Assert.Equal("negative", matrix.ClassNames[1]); - Assert.Equal(59, matrix[0, 0]); - Assert.Equal(59, matrix["positive", "positive"]); - Assert.Equal(19, matrix[0, 1]); - Assert.Equal(19, matrix["positive", "negative"]); + Assert.Equal(65, matrix[0, 0]); + Assert.Equal(65, matrix["positive", "positive"]); + Assert.Equal(13, matrix[0, 1]); + Assert.Equal(13, matrix["positive", "negative"]); 
Assert.Equal(27, matrix[1, 0]); Assert.Equal(27, matrix["negative", "positive"]); @@ -180,11 +180,11 @@ private void ValidateBinaryMetricsLightGBM(BinaryClassificationMetrics metrics) Assert.Equal(.6111, metrics.Accuracy, 4); Assert.Equal(.8, metrics.Auc, 1); - Assert.Equal(.85, metrics.Auprc, 2); + Assert.Equal(0.88, metrics.Auprc, 2); Assert.Equal(1, metrics.Entropy, 3); Assert.Equal(.72, metrics.F1Score, 4); - Assert.Equal(.952, metrics.LogLoss, 3); - Assert.Equal(4.777, metrics.LogLossReduction, 3); + Assert.Equal(0.96456100297125325, metrics.LogLoss, 4); + Assert.Equal(3.5438997028746755, metrics.LogLossReduction, 4); Assert.Equal(1, metrics.NegativePrecision, 3); Assert.Equal(.222, metrics.NegativeRecall, 3); Assert.Equal(.562, metrics.PositivePrecision, 3); @@ -211,16 +211,16 @@ private void ValidateBinaryMetricsLightGBM(BinaryClassificationMetrics metrics) private void ValidateBinaryMetrics(BinaryClassificationMetrics metrics) { - Assert.Equal(.5556, metrics.Accuracy, 4); - Assert.Equal(.8, metrics.Auc, 1); - Assert.Equal(.87, metrics.Auprc, 2); + Assert.Equal(0.6111, metrics.Accuracy, 4); + Assert.Equal(0.6667, metrics.Auc, 4); + Assert.Equal(0.8621, metrics.Auprc, 4); Assert.Equal(1, metrics.Entropy, 3); - Assert.Equal(.6923, metrics.F1Score, 4); - Assert.Equal(.969, metrics.LogLoss, 3); - Assert.Equal(3.083, metrics.LogLossReduction, 3); - Assert.Equal(1, metrics.NegativePrecision, 3); - Assert.Equal(.111, metrics.NegativeRecall, 3); - Assert.Equal(.529, metrics.PositivePrecision, 3); + Assert.Equal(0.72, metrics.F1Score, 2); + Assert.Equal(0.9689, metrics.LogLoss, 4); + Assert.Equal(3.1122, metrics.LogLossReduction, 4); + Assert.Equal(1, metrics.NegativePrecision, 1); + Assert.Equal(0.2222, metrics.NegativeRecall, 4); + Assert.Equal(0.5625, metrics.PositivePrecision, 4); Assert.Equal(1, metrics.PositiveRecall); var matrix = metrics.ConfusionMatrix; @@ -234,10 +234,10 @@ private void ValidateBinaryMetrics(BinaryClassificationMetrics metrics) 
Assert.Equal(0, matrix[0, 1]); Assert.Equal(0, matrix["positive", "negative"]); - Assert.Equal(8, matrix[1, 0]); - Assert.Equal(8, matrix["negative", "positive"]); - Assert.Equal(1, matrix[1, 1]); - Assert.Equal(1, matrix["negative", "negative"]); + Assert.Equal(7, matrix[1, 0]); + Assert.Equal(7, matrix["negative", "positive"]); + Assert.Equal(2, matrix[1, 1]); + Assert.Equal(2, matrix["negative", "negative"]); } private LearningPipeline PreparePipeline() @@ -344,7 +344,7 @@ private void ValidateExamples(PredictionModel Date: Thu, 19 Jul 2018 15:35:16 -0700 Subject: [PATCH 2/6] Addressed reviewers' comments. --- .../Text/CharTokenizeTransform.cs | 62 ++++++++++++++++--- .../Text/TextTransform.cs | 54 ++++++++-------- .../TestPipelineSweeper.cs | 2 +- 3 files changed, 79 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs index 331537564c..290df8a541 100644 --- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs @@ -64,12 +64,13 @@ public sealed class Arguments : TransformInputBase public const string LoaderSignature = "CharToken"; public const string UserName = "Character Tokenizer Transform"; + public static uint CurrentModelVersion = 0x00010002; private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "CHARTOKN", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, + verWrittenCur: CurrentModelVersion, + verReadableCur: CurrentModelVersion, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature); } @@ -84,6 +85,7 @@ private static VersionInfo GetVersionInfo() private volatile string _keyValuesStr; private volatile int[] _keyValuesBoundaries; + private const ushort UnitSeparator = 0x1f; private const ushort TextStartMarker = 0x02; private const ushort TextEndMarker = 0x03; private const int TextMarkersCount = 2; @@ -119,6 +121,7 @@ 
private CharTokenizeTransform(IHost host, ModelLoadContext ctx, IDataView input) // // byte: _useMarkerChars value. _useMarkerChars = ctx.Reader.ReadBoolByte(); + CurrentModelVersion = ctx.Header.ModelVerReadable; _type = GetOutputColumnType(); SetMetadata(); @@ -397,15 +400,55 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) int cv = Infos[iinfo].TypeSrc.VectorSize; Contracts.Assert(cv >= 0); + var version = GetVersionInfo(); var getSrc = GetSrcGetter>(input, iinfo); var src = default(VBuffer); - return - (ref VBuffer dst) => + + ValueGetter> valueGetterOldVersion = (ref VBuffer dst) => + { + getSrc(ref src); + + int len = 0; + for (int i = 0; i < src.Count; i++) + { + if (src.Values[i].HasChars) + { + len += src.Values[i].Length; + if (_useMarkerChars) + len += TextMarkersCount; + } + } + + var values = dst.Values; + if (len > 0) + { + if (Utils.Size(values) < len) + values = new ushort[len]; + + int index = 0; + for (int i = 0; i < src.Count; i++) + { + if (!src.Values[i].HasChars) + continue; + if (_useMarkerChars) + values[index++] = TextStartMarker; + for (int ich = 0; ich < src.Values[i].Length; ich++) + values[index++] = src.Values[i][ich]; + if (_useMarkerChars) + values[index++] = TextEndMarker; + } + Contracts.Assert(index == len); + } + + dst = new VBuffer(len, values, dst.Indices); + }; + + ValueGetter < VBuffer > valueGetterCurrentVersion = (ref VBuffer dst) => { getSrc(ref src); int len = 0; - + for (int i = 0; i < src.Count; i++) { if (src.Values[i].HasChars) @@ -413,7 +456,7 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) len += src.Values[i].Length; if (i > 0) - len += 1; // add space character that will be added + len += 1; // add UnitSeparator character to len that will be added } } @@ -429,11 +472,11 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) int index = 0; // VBuffer can be a result of either concatenating text columns together - // or application of word tokenizer before char tokenizer. 
+ // or application of word tokenizer before char tokenizer in TextTransform. // // Considering VBuffer as a single text stream. // Therefore, prepend and append start and end markers only once i.e. at the start and at end of vector. - // Insert spaces after every piece of text in the vector. + // Insert UnitSeparator after every piece of text in the vector. if (_useMarkerChars) values[index++] = TextStartMarker; @@ -443,7 +486,7 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) continue; if (i > 0) - values[index++] = ' '; + values[index++] = UnitSeparator; for (int ich = 0; ich < src.Values[i].Length; ich++) { @@ -459,6 +502,7 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) dst = new VBuffer(len, values, dst.Indices); }; + return CurrentModelVersion < version.VerReadableCur ? valueGetterOldVersion : valueGetterCurrentVersion; } } } diff --git a/src/Microsoft.ML.Transforms/Text/TextTransform.cs b/src/Microsoft.ML.Transforms/Text/TextTransform.cs index 932bf63272..b363a9f740 100644 --- a/src/Microsoft.ML.Transforms/Text/TextTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/TextTransform.cs @@ -262,6 +262,30 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV view = new ConcatTransform(h, new ConcatTransform.Arguments() { Column = xfCols }, view); } + if (tparams.NeedsNormalizeTransform) + { + var xfCols = new TextNormalizerCol[textCols.Length]; + string[] dstCols = new string[textCols.Length]; + for (int i = 0; i < textCols.Length; i++) + { + dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer"); + tempCols.Add(dstCols[i]); + xfCols[i] = new TextNormalizerCol() { Source = textCols[i], Name = dstCols[i] }; + } + + view = new TextNormalizerTransform(h, + new TextNormalizerArgs() + { + Column = xfCols, + KeepDiacritics = tparams.KeepDiacritics, + KeepNumbers = tparams.KeepNumbers, + KeepPunctuations = tparams.KeepPunctuations, + TextCase = tparams.TextCase + }, view); + + textCols = 
dstCols; + } + if (tparams.NeedsWordTokenizationTransform) { var xfCols = new DelimitedTokenizeTransform.Column[textCols.Length]; @@ -281,34 +305,6 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV view = new DelimitedTokenizeTransform(h, new DelimitedTokenizeTransform.Arguments() { Column = xfCols }, view); } - if (tparams.NeedsNormalizeTransform) - { - string[] srcCols = wordTokCols == null ? textCols : wordTokCols; - var xfCols = new TextNormalizerCol[srcCols.Length]; - string[] dstCols = new string[srcCols.Length]; - for (int i = 0; i < srcCols.Length; i++) - { - dstCols[i] = GenerateColumnName(view.Schema, srcCols[i], "TextNormalizer"); - tempCols.Add(dstCols[i]); - xfCols[i] = new TextNormalizerCol() { Source = srcCols[i], Name = dstCols[i] }; - } - - view = new TextNormalizerTransform(h, - new TextNormalizerArgs() - { - Column = xfCols, - KeepDiacritics = tparams.KeepDiacritics, - KeepNumbers = tparams.KeepNumbers, - KeepPunctuations = tparams.KeepPunctuations, - TextCase = tparams.TextCase - }, view); - - if (wordTokCols != null) - wordTokCols = dstCols; - else - textCols = dstCols; - } - if (tparams.NeedsRemoveStopwordsTransform) { Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text."); @@ -360,7 +356,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV if (tparams.CharExtractorFactory != null) { { - var srcCols = wordTokCols ?? textCols; + var srcCols = tparams.NeedsRemoveStopwordsTransform ? 
wordTokCols : textCols; charTokCols = new string[srcCols.Length]; var xfCols = new CharTokenizeTransform.Column[srcCols.Length]; for (int i = 0; i < srcCols.Length; i++) diff --git a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs index 03fa8dfe29..d11811df6b 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs @@ -308,7 +308,7 @@ public void PipelineSweeperRoles() var trainAuc = bestPipeline.PerformanceSummary.TrainingMetricValue; var testAuc = bestPipeline.PerformanceSummary.MetricValue; Assert.True((0.94 < trainAuc) && (trainAuc < 0.95)); - Assert.True((0.83 < testAuc) && (testAuc < 0.84)); + Assert.True((0.815 < testAuc) && (testAuc < 0.825)); var results = runner.GetOutput("ResultsOut"); Assert.NotNull(results); From 2201433f095e00c82a11850eb0212581a8d13ce4 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Sat, 21 Jul 2018 00:09:31 -0700 Subject: [PATCH 3/6] Addressed reviewers' comments. 
--- .../Text/CharTokenizeTransform.cs | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs index 290df8a541..9b3dc4ccdd 100644 --- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs @@ -64,13 +64,16 @@ public sealed class Arguments : TransformInputBase public const string LoaderSignature = "CharToken"; public const string UserName = "Character Tokenizer Transform"; - public static uint CurrentModelVersion = 0x00010002; + // Keep track of the model that was saved with ver:0x00010001 + private readonly bool _isSeparatorStartEnd = false; + private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "CHARTOKN", - verWrittenCur: CurrentModelVersion, - verReadableCur: CurrentModelVersion, + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Updated to use UnitSeparator character instead of using <ETX><STX> for vector inputs. + verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature); } @@ -121,7 +124,9 @@ private CharTokenizeTransform(IHost host, ModelLoadContext ctx, IDataView input) // // byte: _useMarkerChars value. _useMarkerChars = ctx.Reader.ReadBoolByte(); - CurrentModelVersion = ctx.Header.ModelVerReadable; + + var version = GetVersionInfo(); + _isSeparatorStartEnd = ctx.Header.ModelVerReadable < version.VerReadableCur || ctx.Reader.ReadBoolByte(); _type = GetOutputColumnType(); SetMetadata(); @@ -148,6 +153,7 @@ public override void Save(ModelSaveContext ctx) // byte: _useMarkerChars value. 
SaveBase(ctx); ctx.Writer.WriteBoolByte(_useMarkerChars); + ctx.Writer.WriteBoolByte(_isSeparatorStartEnd); } protected override ColumnType GetColumnTypeCore(int iinfo) @@ -400,11 +406,10 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) int cv = Infos[iinfo].TypeSrc.VectorSize; Contracts.Assert(cv >= 0); - var version = GetVersionInfo(); var getSrc = GetSrcGetter>(input, iinfo); var src = default(VBuffer); - ValueGetter> valueGetterOldVersion = (ref VBuffer dst) => + ValueGetter> getterWithStartEndSep = (ref VBuffer dst) => { getSrc(ref src); @@ -443,7 +448,7 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) dst = new VBuffer(len, values, dst.Indices); }; - ValueGetter < VBuffer > valueGetterCurrentVersion = (ref VBuffer dst) => + ValueGetter < VBuffer > getterWithUnitSep = (ref VBuffer dst) => { getSrc(ref src); @@ -502,7 +507,7 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) dst = new VBuffer(len, values, dst.Indices); }; - return CurrentModelVersion < version.VerReadableCur ? valueGetterOldVersion : valueGetterCurrentVersion; + return _isSeparatorStartEnd ? getterWithStartEndSep : getterWithUnitSep; } } } From d1dd161aa5ebc7aa241fbf24e0b1de2b4027bf3f Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Wed, 25 Jul 2018 14:56:30 -0700 Subject: [PATCH 4/6] Addressed reviewers' comments. --- src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs index 9b3dc4ccdd..99da9b462a 100644 --- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs @@ -125,8 +125,7 @@ private CharTokenizeTransform(IHost host, ModelLoadContext ctx, IDataView input) // byte: _useMarkerChars value. 
_useMarkerChars = ctx.Reader.ReadBoolByte(); - var version = GetVersionInfo(); - _isSeparatorStartEnd = ctx.Header.ModelVerReadable < version.VerReadableCur || ctx.Reader.ReadBoolByte(); + _isSeparatorStartEnd = ctx.Header.ModelVerReadable < 0x00010002 || ctx.Reader.ReadBoolByte(); _type = GetOutputColumnType(); SetMetadata(); From 90cdf2ae94d9e551f6bf752b8fb4b4ac1fb56d3b Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Wed, 25 Jul 2018 15:22:43 -0700 Subject: [PATCH 5/6] Fixed bugs detected by code analyzer. --- src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs index 99da9b462a..aab8aa65dd 100644 --- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs @@ -65,7 +65,7 @@ public sealed class Arguments : TransformInputBase public const string UserName = "Character Tokenizer Transform"; // Keep track of the model that was saved with ver:0x00010001 - private readonly bool _isSeparatorStartEnd = false; + private readonly bool _isSeparatorStartEnd; private static VersionInfo GetVersionInfo() { @@ -107,6 +107,7 @@ public CharTokenizeTransform(IHostEnvironment env, Arguments args, IDataView inp _type = GetOutputColumnType(); SetMetadata(); + _isSeparatorStartEnd = false; } private static ColumnType GetOutputColumnType() @@ -475,7 +476,7 @@ private ValueGetter> MakeGetterVec(IRow input, int iinfo) int index = 0; - // VBuffer can be a result of either concatenating text columns together + // VBuffer can be a result of either concatenating text columns together // or application of word tokenizer before char tokenizer in TextTransform. // // Considering VBuffer as a single text stream. 
From 3ed6e872bf7b56698bfc82057a517961cdbbf171 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Wed, 25 Jul 2018 16:23:26 -0700 Subject: [PATCH 6/6] Addressed reviewers' comments. --- src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs index aab8aa65dd..e1ea2974b3 100644 --- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs @@ -107,7 +107,6 @@ public CharTokenizeTransform(IHostEnvironment env, Arguments args, IDataView inp _type = GetOutputColumnType(); SetMetadata(); - _isSeparatorStartEnd = false; } private static ColumnType GetOutputColumnType()