Skip to content

Commit 356cad4

Browse files
Ivanidzo4kaTomFinley
authored andcommitted
KeyToVector estimators (#858)
Added support for converting key values of various types to Vector<float> and VarVector<float>
1 parent 52aff02 commit 356cad4

File tree

9 files changed

+1791
-705
lines changed

9 files changed

+1791
-705
lines changed

src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs

+819-398
Large diffs are not rendered by default.

src/Microsoft.ML.Data/Transforms/TermTransform.cs

+15-47
Original file line numberDiff line numberDiff line change
@@ -241,18 +241,25 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum
241241
return columns.Select(x => (x.Input, x.Output)).ToArray();
242242
}
243243

244-
private ColInfo[] CreateInfos(ISchema schema)
244+
internal string TestIsKnownDataKind(ColumnType type)
245245
{
246-
Host.AssertValue(schema);
246+
if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive))
247+
return null;
248+
return "standard type or a vector of standard type";
249+
}
250+
251+
private ColInfo[] CreateInfos(ISchema inputSchema)
252+
{
253+
Host.AssertValue(inputSchema);
247254
var infos = new ColInfo[ColumnPairs.Length];
248255
for (int i = 0; i < ColumnPairs.Length; i++)
249256
{
250-
if (!schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc))
251-
throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input);
252-
var type = schema.GetColumnType(colSrc);
257+
if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc))
258+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input);
259+
var type = inputSchema.GetColumnType(colSrc);
253260
string reason = TestIsKnownDataKind(type);
254261
if (reason != null)
255-
throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason);
262+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString());
256263
infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type);
257264
}
258265
return infos;
@@ -271,7 +278,7 @@ private TermTransform(IHostEnvironment env, IDataView input,
271278
{
272279
using (var ch = Host.Start("Training"))
273280
{
274-
var infos = CreateInfos(Host, ColumnPairs, input.Schema, TestIsKnownDataKind);
281+
var infos = CreateInfos(input.Schema);
275282
_unboundMaps = Train(Host, ch, infos, file, termsColumn, loaderFactory, columns, input);
276283
_textMetadata = new bool[_unboundMaps.Length];
277284
for (int iinfo = 0; iinfo < columns.Length; ++iinfo)
@@ -400,32 +407,6 @@ public static IDataView Create(IHostEnvironment env,
400407
int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) =>
401408
new TermTransform(env, input, new[] { new ColumnInfo(source ?? name, name, maxNumTerms, sort) }).MakeDataTransform(input);
402409

403-
//REVIEW: This and static method below need to go to base class as it get created.
404-
private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}.";
405-
406-
private static ColInfo[] CreateInfos(IHostEnvironment env, (string source, string name)[] columns, ISchema schema, Func<ColumnType, string> testType)
407-
{
408-
env.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));
409-
env.AssertValue(schema);
410-
env.AssertValueOrNull(testType);
411-
412-
var infos = new ColInfo[columns.Length];
413-
for (int i = 0; i < columns.Length; i++)
414-
{
415-
if (!schema.TryGetColumnIndex(columns[i].source, out int colSrc))
416-
throw env.ExceptUserArg(nameof(columns), "Source column '{0}' not found", columns[i].source);
417-
var type = schema.GetColumnType(colSrc);
418-
if (testType != null)
419-
{
420-
string reason = testType(type);
421-
if (reason != null)
422-
throw env.ExceptUserArg(nameof(columns), InvalidTypeErrorFormat, columns[i].source, type, reason);
423-
}
424-
infos[i] = new ColInfo(columns[i].name, columns[i].source, type);
425-
}
426-
return infos;
427-
}
428-
429410
public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input)
430411
{
431412
return Create(env, new Arguments()
@@ -452,13 +433,6 @@ public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, Co
452433
}, input);
453434
}
454435

455-
internal static string TestIsKnownDataKind(ColumnType type)
456-
{
457-
if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive))
458-
return null;
459-
return "Expected standard type or a vector of standard type";
460-
}
461-
462436
/// <summary>
463437
/// Utility method to create the file-based <see cref="TermMap"/>.
464438
/// </summary>
@@ -701,7 +675,7 @@ public override void Save(ModelSaveContext ctx)
701675
ctx.CheckAtModel();
702676
ctx.SetVersionInfo(GetVersionInfo());
703677

704-
base.SaveColumns(ctx);
678+
SaveColumns(ctx);
705679

706680
Host.Assert(_unboundMaps.Length == _textMetadata.Length);
707681
Host.Assert(_textMetadata.Length == ColumnPairs.Length);
@@ -743,12 +717,6 @@ internal TermMap GetTermMap(int iinfo)
743717
protected override IRowMapper MakeRowMapper(ISchema schema)
744718
=> new Mapper(this, schema);
745719

746-
protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
747-
{
748-
if ((inputSchema.GetColumnType(srcCol).ItemType.RawKind == default))
749-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema.GetColumnType(srcCol).ToString());
750-
}
751-
752720
private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa
753721
{
754722
private readonly ColumnType[] _types;

src/Microsoft.ML.Transforms/CategoricalTransform.cs

+5-10
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ public static IDataTransform CreateTransformCore(
177177
using (var ch = h.Start("Create Transform Core"))
178178
{
179179
// Create the KeyToVectorTransform, if needed.
180-
List<KeyToVectorTransform.Column> cols = new List<KeyToVectorTransform.Column>();
180+
var cols = new List<KeyToVectorTransform.Column>();
181181
bool binaryEncoding = argsOutputKind == OutputKind.Bin;
182182
for (int i = 0; i < columns.Length; i++)
183183
{
@@ -220,19 +220,14 @@ public static IDataTransform CreateTransformCore(
220220
if ((catHashArgs?.InvertHash ?? 0) != 0)
221221
ch.Warning("Invert hashing is being used with binary encoding.");
222222

223-
var keyToBinaryArgs = new KeyToBinaryVectorTransform.Arguments();
224-
keyToBinaryArgs.Column = cols.ToArray();
225-
transform = new KeyToBinaryVectorTransform(h, keyToBinaryArgs, input);
223+
var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray();
224+
transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols);
226225
}
227226
else
228227
{
229-
var keyToVecArgs = new KeyToVectorTransform.Arguments
230-
{
231-
Bag = argsOutputKind == OutputKind.Bag,
232-
Column = cols.ToArray()
233-
};
228+
var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == OutputKind.Bag)).ToArray();
234229

235-
transform = new KeyToVectorTransform(h, keyToVecArgs, input);
230+
transform = KeyToVectorTransform.Create(h, input, keyToVecCols);
236231
}
237232

238233
ch.Done();

0 commit comments

Comments
 (0)