Skip to content

Commit 062d70c

Browse files
committed
Handle Ngram
Shift from Array to ReadOnlyList
1 parent e58f879 commit 062d70c

File tree

4 files changed

+49
-45
lines changed

4 files changed

+49
-45
lines changed

src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -492,15 +492,15 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
492492
/// <param name="ngramLength">Ngram length.</param>
493493
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
494494
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
495-
/// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
495+
/// <param name="maximumTermCount">Maximum number of ngrams to store in the dictionary.</param>
496496
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
497497
public static Vector<float> ToNgrams<TKey>(this VarVector<Key<TKey, string>> input,
498498
int ngramLength = 1,
499499
int skipLength = 0,
500500
bool allLengths = true,
501-
int maxNumTerms = 10000000,
501+
int maximumTermCount = 10000000,
502502
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
503-
=> new OutPipelineColumn(input, ngramLength, skipLength, allLengths, maxNumTerms, weighting);
503+
=> new OutPipelineColumn(input, ngramLength, skipLength, allLengths, maximumTermCount, weighting);
504504
}
505505

506506
/// <summary>

src/Microsoft.ML.Transforms/Text/NgramTransform.cs

+35-31
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ internal sealed class Options : TransformInputBase
9393
public int SkipLength = NgramExtractingEstimator.Defaults.SkipLength;
9494

9595
[Argument(ArgumentType.Multiple, HelpText = "Maximum number of ngrams to store in the dictionary", ShortName = "max")]
96-
public int[] MaxNumTerms = new int[] { NgramExtractingEstimator.Defaults.MaxNumTerms };
96+
public int[] MaxNumTerms = new int[] { NgramExtractingEstimator.Defaults.MaximumTermCount };
9797

9898
[Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
9999
public NgramExtractingEstimator.WeightingCriteria Weighting = NgramExtractingEstimator.Defaults.Weighting;
@@ -253,7 +253,7 @@ private static SequencePool[] Train(IHostEnvironment env, NgramExtractingEstimat
253253
// Note: GetNgramIdFinderAdd will control how many ngrams of a specific length will
254254
// be added (using lims[iinfo]), therefore we set slotLim to the maximum
255255
helpers[iinfo] = new NgramBufferBuilder(ngramLength, skipLength, Utils.ArrayMaxSize,
256-
GetNgramIdFinderAdd(env, counts[iinfo], columns[iinfo].Limits, ngramMaps[iinfo], transformInfos[iinfo].RequireIdf));
256+
GetNgramIdFinderAdd(env, counts[iinfo], columns[iinfo].MaximumTermCounts, ngramMaps[iinfo], transformInfos[iinfo].RequireIdf));
257257
}
258258

259259
int cInfoFull = 0;
@@ -293,7 +293,7 @@ private static SequencePool[] Train(IHostEnvironment env, NgramExtractingEstimat
293293
}
294294
}
295295
}
296-
AssertValid(env, counts[iinfo], columns[iinfo].Limits, ngramMaps[iinfo]);
296+
AssertValid(env, counts[iinfo], columns[iinfo].MaximumTermCounts, ngramMaps[iinfo]);
297297
}
298298
}
299299

@@ -307,7 +307,7 @@ private static SequencePool[] Train(IHostEnvironment env, NgramExtractingEstimat
307307

308308
for (int iinfo = 0; iinfo < columns.Length; iinfo++)
309309
{
310-
AssertValid(env, counts[iinfo], columns[iinfo].Limits, ngramMaps[iinfo]);
310+
AssertValid(env, counts[iinfo], columns[iinfo].MaximumTermCounts, ngramMaps[iinfo]);
311311

312312
int ngramLength = transformInfos[iinfo].NgramLength;
313313
for (int i = 0; i < ngramLength; i++)
@@ -319,11 +319,11 @@ private static SequencePool[] Train(IHostEnvironment env, NgramExtractingEstimat
319319
}
320320

321321
[Conditional("DEBUG")]
322-
private static void AssertValid(IHostEnvironment env, int[] counts, ImmutableArray<int> lims, SequencePool pool)
322+
private static void AssertValid(IHostEnvironment env, int[] counts, IReadOnlyList<int> lims, SequencePool pool)
323323
{
324324
int count = 0;
325325
int countFull = 0;
326-
for (int i = 0; i < lims.Length; i++)
326+
for (int i = 0; i < lims.Count; i++)
327327
{
328328
env.Assert(counts[i] >= 0);
329329
env.Assert(counts[i] <= lims[i]);
@@ -334,20 +334,20 @@ private static void AssertValid(IHostEnvironment env, int[] counts, ImmutableArr
334334
env.Assert(count == pool.Count);
335335
}
336336

337-
private static NgramIdFinder GetNgramIdFinderAdd(IHostEnvironment env, int[] counts, ImmutableArray<int> lims, SequencePool pool, bool requireIdf)
337+
private static NgramIdFinder GetNgramIdFinderAdd(IHostEnvironment env, int[] counts, IReadOnlyList<int> lims, SequencePool pool, bool requireIdf)
338338
{
339339
Contracts.AssertValue(env);
340-
env.Assert(lims.Length > 0);
341-
env.Assert(lims.Length == Utils.Size(counts));
340+
env.Assert(lims.Count > 0);
341+
env.Assert(lims.Count == Utils.Size(counts));
342342

343343
int numFull = lims.Count(l => l <= 0);
344-
int ngramLength = lims.Length;
344+
int ngramLength = lims.Count;
345345
return
346346
(uint[] ngram, int lim, int icol, ref bool more) =>
347347
{
348348
env.Assert(0 < lim && lim <= Utils.Size(ngram));
349349
env.Assert(lim <= Utils.Size(counts));
350-
env.Assert(lim <= lims.Length);
350+
env.Assert(lim <= lims.Count);
351351
env.Assert(icol == 0);
352352

353353
var max = lim - 1;
@@ -695,7 +695,7 @@ internal static class Defaults
695695
public const int NgramLength = 2;
696696
public const bool AllLengths = true;
697697
public const int SkipLength = 0;
698-
public const int MaxNumTerms = 10000000;
698+
public const int MaximumTermCount = 10000000;
699699
public const WeightingCriteria Weighting = WeightingCriteria.Tf;
700700
}
701701

@@ -712,16 +712,16 @@ internal static class Defaults
712712
/// <param name="ngramLength">Ngram length.</param>
713713
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
714714
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
715-
/// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
715+
/// <param name="maximumTermCount">Maximum number of ngrams to store in the dictionary.</param>
716716
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
717717
internal NgramExtractingEstimator(IHostEnvironment env,
718718
string outputColumnName, string inputColumnName = null,
719719
int ngramLength = Defaults.NgramLength,
720720
int skipLength = Defaults.SkipLength,
721721
bool allLengths = Defaults.AllLengths,
722-
int maxNumTerms = Defaults.MaxNumTerms,
722+
int maximumTermCount = Defaults.MaximumTermCount,
723723
WeightingCriteria weighting = Defaults.Weighting)
724-
: this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, ngramLength, skipLength, allLengths, maxNumTerms, weighting)
724+
: this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, ngramLength, skipLength, allLengths, maximumTermCount, weighting)
725725
{
726726
}
727727

@@ -734,16 +734,16 @@ internal NgramExtractingEstimator(IHostEnvironment env,
734734
/// <param name="ngramLength">Ngram length.</param>
735735
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
736736
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
737-
/// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
737+
/// <param name="maximumTermCount">Maximum number of ngrams to store in the dictionary.</param>
738738
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
739739
internal NgramExtractingEstimator(IHostEnvironment env,
740740
(string outputColumnName, string inputColumnName)[] columns,
741741
int ngramLength = Defaults.NgramLength,
742742
int skipLength = Defaults.SkipLength,
743743
bool allLengths = Defaults.AllLengths,
744-
int maxNumTerms = Defaults.MaxNumTerms,
744+
int maximumTermCount = Defaults.MaximumTermCount,
745745
WeightingCriteria weighting = Defaults.Weighting)
746-
: this(env, columns.Select(x => new ColumnOptions(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, weighting, maxNumTerms)).ToArray())
746+
: this(env, columns.Select(x => new ColumnOptions(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, weighting, maximumTermCount)).ToArray())
747747
{
748748
}
749749

@@ -809,10 +809,14 @@ public sealed class ColumnOptions
809809
/// <summary>The weighting criteria.</summary>
810810
public readonly WeightingCriteria Weighting;
811811
/// <summary>
812+
/// Underlying state of <see cref="MaximumTermCounts"/>.
813+
/// </summary>
814+
private readonly ImmutableArray<int> _maximumTermCounts;
815+
/// <summary>
812816
/// Contains the maximum number of ngrams to store in the dictionary, for each level of ngrams,
813817
/// from 1 (in position 0) up to ngramLength (in position ngramLength-1)
814818
/// </summary>
815-
public readonly ImmutableArray<int> Limits;
819+
public IReadOnlyList<int> MaximumTermCounts => _maximumTermCounts;
816820

817821
/// <summary>
818822
/// Describes how the transformer handles one column pair.
@@ -823,14 +827,14 @@ public sealed class ColumnOptions
823827
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
824828
/// <param name="allLengths">Whether to store all ngram lengths up to ngramLength, or only ngramLength.</param>
825829
/// <param name="weighting">The weighting criteria.</param>
826-
/// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
830+
/// <param name="maximumTermCount">Maximum number of ngrams to store in the dictionary.</param>
827831
public ColumnOptions(string name, string inputColumnName = null,
828832
int ngramLength = Defaults.NgramLength,
829833
int skipLength = Defaults.SkipLength,
830834
bool allLengths = Defaults.AllLengths,
831835
WeightingCriteria weighting = Defaults.Weighting,
832-
int maxNumTerms = Defaults.MaxNumTerms)
833-
: this(name, ngramLength, skipLength, allLengths, weighting, new int[] { maxNumTerms }, inputColumnName ?? name)
836+
int maximumTermCount = Defaults.MaximumTermCount)
837+
: this(name, ngramLength, skipLength, allLengths, weighting, new int[] { maximumTermCount }, inputColumnName ?? name)
834838
{
835839
}
836840

@@ -839,7 +843,7 @@ internal ColumnOptions(string name,
839843
int skipLength,
840844
bool allLengths,
841845
WeightingCriteria weighting,
842-
int[] maxNumTerms,
846+
int[] maximumTermCounts,
843847
string inputColumnName = null)
844848
{
845849
Name = name;
@@ -857,18 +861,18 @@ internal ColumnOptions(string name,
857861
var limits = new int[ngramLength];
858862
if (!AllLengths)
859863
{
860-
Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 ||
861-
Utils.Size(maxNumTerms) == 1 && maxNumTerms[0] > 0, nameof(maxNumTerms));
862-
limits[ngramLength - 1] = Utils.Size(maxNumTerms) == 0 ? Defaults.MaxNumTerms : maxNumTerms[0];
864+
Contracts.CheckUserArg(Utils.Size(maximumTermCounts) == 0 ||
865+
Utils.Size(maximumTermCounts) == 1 && maximumTermCounts[0] > 0, nameof(maximumTermCounts));
866+
limits[ngramLength - 1] = Utils.Size(maximumTermCounts) == 0 ? Defaults.MaximumTermCount : maximumTermCounts[0];
863867
}
864868
else
865869
{
866-
Contracts.CheckUserArg(Utils.Size(maxNumTerms) <= ngramLength, nameof(maxNumTerms));
867-
Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 || maxNumTerms.All(i => i >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0, nameof(maxNumTerms));
868-
var extend = Utils.Size(maxNumTerms) == 0 ? Defaults.MaxNumTerms : maxNumTerms[maxNumTerms.Length - 1];
869-
limits = Utils.BuildArray(ngramLength, i => i < Utils.Size(maxNumTerms) ? maxNumTerms[i] : extend);
870+
Contracts.CheckUserArg(Utils.Size(maximumTermCounts) <= ngramLength, nameof(maximumTermCounts));
871+
Contracts.CheckUserArg(Utils.Size(maximumTermCounts) == 0 || maximumTermCounts.All(i => i >= 0) && maximumTermCounts[maximumTermCounts.Length - 1] > 0, nameof(maximumTermCounts));
872+
var extend = Utils.Size(maximumTermCounts) == 0 ? Defaults.MaximumTermCount : maximumTermCounts[maximumTermCounts.Length - 1];
873+
limits = Utils.BuildArray(ngramLength, i => i < Utils.Size(maximumTermCounts) ? maximumTermCounts[i] : extend);
870874
}
871-
Limits = ImmutableArray.Create(limits);
875+
_maximumTermCounts = ImmutableArray.Create(limits);
872876
}
873877
}
874878

src/Microsoft.ML.Transforms/Text/TextCatalog.cs

+9-9
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT
194194
/// <param name="ngramLength">Ngram length.</param>
195195
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
196196
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
197-
/// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
197+
/// <param name="maximumTermCount">Maximum number of ngrams to store in the dictionary.</param>
198198
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
199199
/// <example>
200200
/// <format type="text/markdown">
@@ -209,10 +209,10 @@ public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.Text
209209
int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
210210
int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
211211
bool allLengths = NgramExtractingEstimator.Defaults.AllLengths,
212-
int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms,
212+
int maximumTermCount = NgramExtractingEstimator.Defaults.MaximumTermCount,
213213
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.Defaults.Weighting) =>
214214
new NgramExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName,
215-
ngramLength, skipLength, allLengths, maxNumTerms, weighting);
215+
ngramLength, skipLength, allLengths, maximumTermCount, weighting);
216216

217217
/// <summary>
218218
/// Produces a bag of counts of ngrams (sequences of consecutive words) in <paramref name="columns.inputs"/>
@@ -223,17 +223,17 @@ public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.Text
223223
/// <param name="ngramLength">Ngram length.</param>
224224
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
225225
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
226-
/// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
226+
/// <param name="maximumTermCount">Maximum number of ngrams to store in the dictionary.</param>
227227
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
228228
public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.TextTransforms catalog,
229229
(string outputColumnName, string inputColumnName)[] columns,
230230
int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
231231
int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
232232
bool allLengths = NgramExtractingEstimator.Defaults.AllLengths,
233-
int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms,
233+
int maximumTermCount = NgramExtractingEstimator.Defaults.MaximumTermCount,
234234
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.Defaults.Weighting)
235235
=> new NgramExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns,
236-
ngramLength, skipLength, allLengths, maxNumTerms, weighting);
236+
ngramLength, skipLength, allLengths, maximumTermCount, weighting);
237237

238238
/// <summary>
239239
/// Produces a bag of counts of ngrams (sequences of consecutive words) in <paramref name="columns.inputs"/>
@@ -339,7 +339,7 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
339339
int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
340340
int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
341341
bool allLengths = NgramExtractingEstimator.Defaults.AllLengths,
342-
int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms,
342+
int maxNumTerms = NgramExtractingEstimator.Defaults.MaximumTermCount,
343343
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
344344
=> new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
345345
outputColumnName, inputColumnName, ngramLength, skipLength, allLengths, maxNumTerms);
@@ -362,7 +362,7 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
362362
int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
363363
int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
364364
bool allLengths = NgramExtractingEstimator.Defaults.AllLengths,
365-
int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms,
365+
int maxNumTerms = NgramExtractingEstimator.Defaults.MaximumTermCount,
366366
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
367367
=> new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
368368
outputColumnName, inputColumnNames, ngramLength, skipLength, allLengths, maxNumTerms, weighting);
@@ -383,7 +383,7 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
383383
int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
384384
int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
385385
bool allLengths = NgramExtractingEstimator.Defaults.AllLengths,
386-
int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms,
386+
int maxNumTerms = NgramExtractingEstimator.Defaults.MaximumTermCount,
387387
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
388388
=> new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns, ngramLength, skipLength, allLengths, maxNumTerms, weighting);
389389

0 commit comments

Comments
 (0)