Skip to content

Commit 5e2ed11

Browse files
authored
Categorical estimator (#899)
1 parent 160b0df commit 5e2ed11

File tree

14 files changed

+743
-160
lines changed

14 files changed

+743
-160
lines changed

src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
6-
using System.Collections.Generic;
7-
using System.Linq;
8-
using System.Text;
95
using Microsoft.ML.Core.Data;
106
using Microsoft.ML.Data.StaticPipe.Runtime;
117
using Microsoft.ML.Runtime;
@@ -16,11 +12,15 @@
1612
using Microsoft.ML.Runtime.Model.Onnx;
1713
using Microsoft.ML.Runtime.Model.Pfa;
1814
using Newtonsoft.Json.Linq;
15+
using System;
16+
using System.Collections.Generic;
17+
using System.Linq;
18+
using System.Text;
1919

2020
[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToVectorTransform), typeof(KeyToVectorTransform.Arguments), typeof(SignatureDataTransform),
2121
"Key To Vector Transform", KeyToVectorTransform.UserName, "KeyToVector", "ToVector", DocName = "transform/KeyToVectorTransform.md")]
2222

23-
[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataView), typeof(KeyToVectorTransform), null, typeof(SignatureLoadDataTransform),
23+
[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToVectorTransform), null, typeof(SignatureLoadDataTransform),
2424
"Key To Vector Transform", KeyToVectorTransform.LoaderSignature)]
2525

2626
[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(KeyToVectorTransform), null, typeof(SignatureLoadModel),
@@ -733,7 +733,7 @@ public KeyToVectorEstimator(IHostEnvironment env, string name, string source = n
733733
{
734734
}
735735

736-
public KeyToVectorEstimator(IHostEnvironment env, KeyToVectorTransform transformer)
736+
private KeyToVectorEstimator(IHostEnvironment env, KeyToVectorTransform transformer)
737737
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToVectorEstimator)), transformer)
738738
{
739739
}

src/Microsoft.ML.Data/Transforms/TermEstimator.cs

+22-7
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,31 @@
66
using Microsoft.ML.Data.StaticPipe.Runtime;
77
using System;
88
using System.Collections.Generic;
9-
using System.Collections.Immutable;
109
using System.Linq;
1110

1211
namespace Microsoft.ML.Runtime.Data
1312
{
1413
public sealed class TermEstimator : IEstimator<TermTransform>
1514
{
15+
public static class Defaults
16+
{
17+
public const int MaxNumTerms = 1000000;
18+
public const TermTransform.SortOrder Sort = TermTransform.SortOrder.Occurrence;
19+
}
20+
1621
private readonly IHost _host;
1722
private readonly TermTransform.ColumnInfo[] _columns;
18-
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = TermTransform.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermTransform.Defaults.Sort) :
23+
24+
/// <summary>
25+
/// Convenience constructor for public facing API.
26+
/// </summary>
27+
/// <param name="env">Host Environment.</param>
28+
/// <param name="name">Name of the output column.</param>
29+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
30+
/// <param name="maxNumTerms">Maximum number of terms to keep per column when auto-training.</param>
31+
/// <param name="sort">How items should be ordered when vectorized. By default, they will be in the order encountered.
32+
/// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').</param>
33+
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
1934
this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort))
2035
{
2136
}
@@ -47,7 +62,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
4762
if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
4863
{
4964
kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
50-
col.ItemType, col.IsKey);
65+
colInfo.TextKeyValues ? TextType.Instance : col.ItemType, col.IsKey);
5166
}
5267
Contracts.AssertValue(kv);
5368

@@ -90,7 +105,7 @@ public sealed class ToKeyFitResult<T>
90105
// At the moment this is empty. Once PR #863 clears, we can change this class to hold the output
91106
// key-values metadata.
92107

93-
internal ToKeyFitResult(TermTransform.TermMap map)
108+
public ToKeyFitResult(TermTransform.TermMap map)
94109
{
95110
}
96111
}
@@ -101,8 +116,8 @@ public static partial class TermStaticExtensions
101116
// Raw generics would allow illegal possible inputs, e.g., Scalar<Bitmap>. So, this is a partial
102117
// class, and all the public facing extension methods for each possible type are in a T4 generated result.
103118

104-
private const KeyValueOrder DefSort = (KeyValueOrder)TermTransform.Defaults.Sort;
105-
private const int DefMax = TermTransform.Defaults.MaxNumTerms;
119+
private const KeyValueOrder DefSort = (KeyValueOrder)TermEstimator.Defaults.Sort;
120+
private const int DefMax = TermEstimator.Defaults.MaxNumTerms;
106121

107122
private struct Config
108123
{
@@ -176,7 +191,7 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, Pipelin
176191
{
177192
var infos = new TermTransform.ColumnInfo[toOutput.Length];
178193
Action<TermTransform> onFit = null;
179-
for (int i=0; i<toOutput.Length; ++i)
194+
for (int i = 0; i < toOutput.Length; ++i)
180195
{
181196
var tcol = (ITermCol)toOutput[i];
182197
infos[i] = new TermTransform.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]],

src/Microsoft.ML.Data/Transforms/TermTransform.cs

+7-13
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
typeof(TermTransform.Arguments), typeof(SignatureDataTransform),
2424
TermTransform.UserName, "Term", "AutoLabel", "TermTransform", "AutoLabelTransform", DocName = "transform/TermTransform.md")]
2525

26-
[assembly: LoadableClass(TermTransform.Summary, typeof(IDataView), typeof(TermTransform), null, typeof(SignatureLoadDataTransform),
26+
[assembly: LoadableClass(TermTransform.Summary, typeof(IDataTransform), typeof(TermTransform), null, typeof(SignatureLoadDataTransform),
2727
TermTransform.UserName, TermTransform.LoaderSignature)]
2828

2929
[assembly: LoadableClass(TermTransform.Summary, typeof(TermTransform), null, typeof(SignatureLoadModel),
@@ -101,16 +101,10 @@ public enum SortOrder : byte
101101
// other things, like case insensitive (where appropriate), culturally aware, etc.?
102102
}
103103

104-
internal static class Defaults
105-
{
106-
public const int MaxNumTerms = 1000000;
107-
public const SortOrder Sort = SortOrder.Occurrence;
108-
}
109-
110104
public abstract class ArgumentsBase : TransformInputBase
111105
{
112106
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of terms to keep per column when auto-training", ShortName = "max", SortOrder = 5)]
113-
public int MaxNumTerms = Defaults.MaxNumTerms;
107+
public int MaxNumTerms = TermEstimator.Defaults.MaxNumTerms;
114108

115109
[Argument(ArgumentType.AtMostOnce, HelpText = "Comma separated list of terms", SortOrder = 105, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
116110
public string Terms;
@@ -134,7 +128,7 @@ public abstract class ArgumentsBase : TransformInputBase
134128
// REVIEW: Should we always sort? Opinions are mixed. See work item 7797429.
135129
[Argument(ArgumentType.AtMostOnce, HelpText = "How items should be ordered when vectorized. By default, they will be in the order encountered. " +
136130
"If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').", SortOrder = 113)]
137-
public SortOrder Sort = Defaults.Sort;
131+
public SortOrder Sort = TermEstimator.Defaults.Sort;
138132

139133
// REVIEW: Should we do this here, or correct the various pieces of code here and in MRS etc. that
140134
// assume key-values will be string? Once we correct these things perhaps we can see about removing it.
@@ -164,7 +158,7 @@ public ColInfo(string name, string source, ColumnType type)
164158

165159
public class ColumnInfo
166160
{
167-
public ColumnInfo(string input, string output, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort, string[] term = null, bool textKeyValues = false)
161+
public ColumnInfo(string input, string output, int maxNumTerms = TermEstimator.Defaults.MaxNumTerms, SortOrder sort = TermEstimator.Defaults.Sort, string[] term = null, bool textKeyValues = false)
168162
{
169163
Input = input;
170164
Output = output;
@@ -181,7 +175,7 @@ public ColumnInfo(string input, string output, int maxNumTerms = Defaults.MaxNum
181175
public readonly string[] Term;
182176
public readonly bool TextKeyValues;
183177

184-
internal string Terms { get; set; }
178+
protected internal string Terms { get; set; }
185179
}
186180

187181
public const string Summary = "Converts input values (words, numbers, etc.) to index in a dictionary.";
@@ -406,7 +400,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISc
406400
/// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').</param>
407401
public static IDataView Create(IHostEnvironment env,
408402
IDataView input, string name, string source = null,
409-
int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) =>
403+
int maxNumTerms = TermEstimator.Defaults.MaxNumTerms, SortOrder sort = TermEstimator.Defaults.Sort) =>
410404
new TermTransform(env, input, new[] { new ColumnInfo(source ?? name, name, maxNumTerms, sort) }).MakeDataTransform(input);
411405

412406
public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input)
@@ -710,7 +704,7 @@ public override void Save(ModelSaveContext ctx)
710704
});
711705
}
712706

713-
internal TermMap GetTermMap(int iinfo)
707+
public TermMap GetTermMap(int iinfo)
714708
{
715709
Contracts.Assert(0 <= iinfo && iinfo < _unboundMaps.Length);
716710
return _unboundMaps[iinfo];

src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ private static BoundTermMap Bind(IHostEnvironment env, ISchema schema, TermMap u
470470
/// These are the immutable and serializable analogs to the <see cref="Builder"/> used in
471471
/// training.
472472
/// </summary>
473-
internal abstract class TermMap
473+
public abstract class TermMap
474474
{
475475
/// <summary>
476476
/// The item type of the input type, that is, either the input type or,
@@ -501,9 +501,9 @@ protected TermMap(PrimitiveType type, int count)
501501
OutputType = new KeyType(DataKind.U4, 0, Count == 0 ? 1 : Count);
502502
}
503503

504-
public abstract void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory);
504+
internal abstract void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory);
505505

506-
public static TermMap Load(ModelLoadContext ctx, IHostEnvironment ectx, CodecFactory codecFactory)
506+
internal static TermMap Load(ModelLoadContext ctx, IHostEnvironment ectx, CodecFactory codecFactory)
507507
{
508508
// *** Binary format ***
509509
// byte: map type code
@@ -610,7 +610,7 @@ public static TextImpl Create(ModelLoadContext ctx, IExceptionContext ectx)
610610
return new TextImpl(pool);
611611
}
612612

613-
public override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
613+
internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
614614
{
615615
// *** Binary format ***
616616
// byte: map type code, in this case 'Text' (0)
@@ -685,7 +685,7 @@ public HashArrayImpl(PrimitiveType itemType, HashArray<T> values)
685685
_values = values;
686686
}
687687

688-
public override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
688+
internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
689689
{
690690
// *** Binary format ***
691691
// byte: map type code, in this case 'Codec'
@@ -757,7 +757,7 @@ public override void WriteTextTerms(TextWriter writer)
757757
}
758758
}
759759

760-
internal abstract class TermMap<T> : TermMap
760+
public abstract class TermMap<T> : TermMap
761761
{
762762
protected TermMap(PrimitiveType type, int count)
763763
: base(type, count)

src/Microsoft.ML.Transforms/CategoricalHashTransform.cs

+72-11
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,14 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
7-
using System;
5+
using System.Collections.Generic;
86
using System.Linq;
97
using System.Text;
108
using Microsoft.ML.Runtime;
119
using Microsoft.ML.Runtime.CommandLine;
1210
using Microsoft.ML.Runtime.Data;
1311
using Microsoft.ML.Runtime.EntryPoints;
1412
using Microsoft.ML.Runtime.Internal.Utilities;
15-
using Microsoft.ML.Runtime.Internal.Internallearn;
1613

1714
[assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform),
1815
CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")]
@@ -62,14 +59,11 @@ protected override bool TryParse(string str)
6259

6360
// We accept N:B:S where N is the new column name, B is the number of bits,
6461
// and S is source column names.
65-
string extra;
66-
if (!base.TryParse(str, out extra))
62+
if (!TryParse(str, out string extra))
6763
return false;
6864
if (extra == null)
6965
return true;
70-
71-
int bits;
72-
if (!int.TryParse(extra, out bits))
66+
if (!int.TryParse(extra, out int bits))
7367
return false;
7468
HashBits = bits;
7569
return true;
@@ -201,14 +195,81 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
201195
};
202196
}
203197

204-
return CategoricalTransform.CreateTransformCore(
198+
return CreateTransformCore(
205199
args.OutputKind, args.Column,
206200
args.Column.Select(col => col.OutputKind).ToList(),
207201
new HashTransform(h, hashArgs, input),
208202
h,
209-
env,
210203
args);
211204
}
212205
}
206+
207+
private static IDataTransform CreateTransformCore(CategoricalTransform.OutputKind argsOutputKind, OneToOneColumn[] columns,
208+
List<CategoricalTransform.OutputKind?> columnOutputKinds, IDataTransform input, IHost h, Arguments catHashArgs = null)
209+
{
210+
Contracts.CheckValue(columns, nameof(columns));
211+
Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds));
212+
Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns));
213+
214+
using (var ch = h.Start("Create Transform Core"))
215+
{
216+
// Create the KeyToVectorTransform, if needed.
217+
var cols = new List<KeyToVectorTransform.Column>();
218+
bool binaryEncoding = argsOutputKind == CategoricalTransform.OutputKind.Bin;
219+
for (int i = 0; i < columns.Length; i++)
220+
{
221+
var column = columns[i];
222+
if (!column.TrySanitize())
223+
throw h.ExceptUserArg(nameof(Column.Name));
224+
225+
bool? bag;
226+
CategoricalTransform.OutputKind kind = columnOutputKinds[i] ?? argsOutputKind;
227+
switch (kind)
228+
{
229+
default:
230+
throw ch.ExceptUserArg(nameof(Column.OutputKind));
231+
case CategoricalTransform.OutputKind.Key:
232+
continue;
233+
case CategoricalTransform.OutputKind.Bin:
234+
binaryEncoding = true;
235+
bag = false;
236+
break;
237+
case CategoricalTransform.OutputKind.Ind:
238+
bag = false;
239+
break;
240+
case CategoricalTransform.OutputKind.Bag:
241+
bag = true;
242+
break;
243+
}
244+
var col = new KeyToVectorTransform.Column();
245+
col.Name = column.Name;
246+
col.Source = column.Name;
247+
col.Bag = bag;
248+
cols.Add(col);
249+
}
250+
251+
if (cols.Count == 0)
252+
return input;
253+
254+
IDataTransform transform;
255+
if (binaryEncoding)
256+
{
257+
if ((catHashArgs?.InvertHash ?? 0) != 0)
258+
ch.Warning("Invert hashing is being used with binary encoding.");
259+
260+
var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray();
261+
transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols);
262+
}
263+
else
264+
{
265+
var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == CategoricalTransform.OutputKind.Bag)).ToArray();
266+
267+
transform = KeyToVectorTransform.Create(h, input, keyToVecCols);
268+
}
269+
270+
ch.Done();
271+
return transform;
272+
}
273+
}
213274
}
214275
}

0 commit comments

Comments
 (0)