Skip to content

Commit 0e0f702

Browse files
authored
Fixed the TextTransform bug where chargrams were being computed differently when using with/without word tokenizer. (dotnet#548)
1 parent 7fea0af commit 0e0f702

File tree

5 files changed

+168
-103
lines changed

5 files changed

+168
-103
lines changed

src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,16 @@ public sealed class Arguments : TransformInputBase
6464
public const string LoaderSignature = "CharToken";
6565
public const string UserName = "Character Tokenizer Transform";
6666

67+
// Keep track of the model that was saved with ver:0x00010001
68+
private readonly bool _isSeparatorStartEnd;
69+
6770
private static VersionInfo GetVersionInfo()
6871
{
6972
return new VersionInfo(
7073
modelSignature: "CHARTOKN",
71-
verWrittenCur: 0x00010001, // Initial
72-
verReadableCur: 0x00010001,
74+
//verWrittenCur: 0x00010001, // Initial
75+
verWrittenCur: 0x00010002, // Updated to use UnitSeparator <US> character instead of using <ETX><STX> for vector inputs.
76+
verReadableCur: 0x00010002,
7377
verWeCanReadBack: 0x00010001,
7478
loaderSignature: LoaderSignature);
7579
}
@@ -84,6 +88,7 @@ private static VersionInfo GetVersionInfo()
8488
private volatile string _keyValuesStr;
8589
private volatile int[] _keyValuesBoundaries;
8690

91+
private const ushort UnitSeparator = 0x1f;
8792
private const ushort TextStartMarker = 0x02;
8893
private const ushort TextEndMarker = 0x03;
8994
private const int TextMarkersCount = 2;
@@ -120,6 +125,8 @@ private CharTokenizeTransform(IHost host, ModelLoadContext ctx, IDataView input)
120125
// byte: _useMarkerChars value.
121126
_useMarkerChars = ctx.Reader.ReadBoolByte();
122127

128+
_isSeparatorStartEnd = ctx.Header.ModelVerReadable < 0x00010002 || ctx.Reader.ReadBoolByte();
129+
123130
_type = GetOutputColumnType();
124131
SetMetadata();
125132
}
@@ -145,6 +152,7 @@ public override void Save(ModelSaveContext ctx)
145152
// byte: _useMarkerChars value.
146153
SaveBase(ctx);
147154
ctx.Writer.WriteBoolByte(_useMarkerChars);
155+
ctx.Writer.WriteBoolByte(_isSeparatorStartEnd);
148156
}
149157

150158
protected override ColumnType GetColumnTypeCore(int iinfo)
@@ -399,8 +407,8 @@ private ValueGetter<VBuffer<ushort>> MakeGetterVec(IRow input, int iinfo)
399407

400408
var getSrc = GetSrcGetter<VBuffer<DvText>>(input, iinfo);
401409
var src = default(VBuffer<DvText>);
402-
return
403-
(ref VBuffer<ushort> dst) =>
410+
411+
ValueGetter<VBuffer<ushort>> getterWithStartEndSep = (ref VBuffer<ushort> dst) =>
404412
{
405413
getSrc(ref src);
406414

@@ -438,6 +446,67 @@ private ValueGetter<VBuffer<ushort>> MakeGetterVec(IRow input, int iinfo)
438446

439447
dst = new VBuffer<ushort>(len, values, dst.Indices);
440448
};
449+
450+
ValueGetter < VBuffer<ushort> > getterWithUnitSep = (ref VBuffer<ushort> dst) =>
451+
{
452+
getSrc(ref src);
453+
454+
int len = 0;
455+
456+
for (int i = 0; i < src.Count; i++)
457+
{
458+
if (src.Values[i].HasChars)
459+
{
460+
len += src.Values[i].Length;
461+
462+
if (i > 0)
463+
len += 1; // add UnitSeparator character to len that will be added
464+
}
465+
}
466+
467+
if (_useMarkerChars)
468+
len += TextMarkersCount;
469+
470+
var values = dst.Values;
471+
if (len > 0)
472+
{
473+
if (Utils.Size(values) < len)
474+
values = new ushort[len];
475+
476+
int index = 0;
477+
478+
// VBuffer<DvText> can be a result of either concatenating text columns together
479+
// or application of word tokenizer before char tokenizer in TextTransform.
480+
//
481+
// Considering VBuffer<DvText> as a single text stream.
482+
// Therefore, prepend and append start and end markers only once i.e. at the start and at end of vector.
483+
// Insert UnitSeparator after every piece of text in the vector.
484+
if (_useMarkerChars)
485+
values[index++] = TextStartMarker;
486+
487+
for (int i = 0; i < src.Count; i++)
488+
{
489+
if (!src.Values[i].HasChars)
490+
continue;
491+
492+
if (i > 0)
493+
values[index++] = UnitSeparator;
494+
495+
for (int ich = 0; ich < src.Values[i].Length; ich++)
496+
{
497+
values[index++] = src.Values[i][ich];
498+
}
499+
}
500+
501+
if (_useMarkerChars)
502+
values[index++] = TextEndMarker;
503+
504+
Contracts.Assert(index == len);
505+
}
506+
507+
dst = new VBuffer<ushort>(len, values, dst.Indices);
508+
};
509+
return _isSeparatorStartEnd ? getterWithStartEndSep : getterWithUnitSep;
441510
}
442511
}
443512
}

src/Microsoft.ML.Transforms/Text/TextTransform.cs

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,30 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
262262
view = new ConcatTransform(h, new ConcatTransform.Arguments() { Column = xfCols }, view);
263263
}
264264

265+
if (tparams.NeedsNormalizeTransform)
266+
{
267+
var xfCols = new TextNormalizerCol[textCols.Length];
268+
string[] dstCols = new string[textCols.Length];
269+
for (int i = 0; i < textCols.Length; i++)
270+
{
271+
dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer");
272+
tempCols.Add(dstCols[i]);
273+
xfCols[i] = new TextNormalizerCol() { Source = textCols[i], Name = dstCols[i] };
274+
}
275+
276+
view = new TextNormalizerTransform(h,
277+
new TextNormalizerArgs()
278+
{
279+
Column = xfCols,
280+
KeepDiacritics = tparams.KeepDiacritics,
281+
KeepNumbers = tparams.KeepNumbers,
282+
KeepPunctuations = tparams.KeepPunctuations,
283+
TextCase = tparams.TextCase
284+
}, view);
285+
286+
textCols = dstCols;
287+
}
288+
265289
if (tparams.NeedsWordTokenizationTransform)
266290
{
267291
var xfCols = new DelimitedTokenizeTransform.Column[textCols.Length];
@@ -281,34 +305,6 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
281305
view = new DelimitedTokenizeTransform(h, new DelimitedTokenizeTransform.Arguments() { Column = xfCols }, view);
282306
}
283307

284-
if (tparams.NeedsNormalizeTransform)
285-
{
286-
string[] srcCols = wordTokCols == null ? textCols : wordTokCols;
287-
var xfCols = new TextNormalizerCol[srcCols.Length];
288-
string[] dstCols = new string[srcCols.Length];
289-
for (int i = 0; i < srcCols.Length; i++)
290-
{
291-
dstCols[i] = GenerateColumnName(view.Schema, srcCols[i], "TextNormalizer");
292-
tempCols.Add(dstCols[i]);
293-
xfCols[i] = new TextNormalizerCol() { Source = srcCols[i], Name = dstCols[i] };
294-
}
295-
296-
view = new TextNormalizerTransform(h,
297-
new TextNormalizerArgs()
298-
{
299-
Column = xfCols,
300-
KeepDiacritics = tparams.KeepDiacritics,
301-
KeepNumbers = tparams.KeepNumbers,
302-
KeepPunctuations = tparams.KeepPunctuations,
303-
TextCase = tparams.TextCase
304-
}, view);
305-
306-
if (wordTokCols != null)
307-
wordTokCols = dstCols;
308-
else
309-
textCols = dstCols;
310-
}
311-
312308
if (tparams.NeedsRemoveStopwordsTransform)
313309
{
314310
Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
@@ -360,7 +356,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
360356
if (tparams.CharExtractorFactory != null)
361357
{
362358
{
363-
var srcCols = wordTokCols ?? textCols;
359+
var srcCols = tparams.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
364360
charTokCols = new string[srcCols.Length];
365361
var xfCols = new CharTokenizeTransform.Column[srcCols.Length];
366362
for (int i = 0; i < srcCols.Length; i++)

test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ public void PipelineSweeperRoles()
308308
var trainAuc = bestPipeline.PerformanceSummary.TrainingMetricValue;
309309
var testAuc = bestPipeline.PerformanceSummary.MetricValue;
310310
Assert.True((0.94 < trainAuc) && (trainAuc < 0.95));
311-
Assert.True((0.83 < testAuc) && (testAuc < 0.84));
311+
Assert.True((0.815 < testAuc) && (testAuc < 0.825));
312312

313313
var results = runner.GetOutput<IDataView>("ResultsOut");
314314
Assert.NotNull(results);

0 commit comments

Comments
 (0)