From b1f3d69c131728eea540cfa8318117bd0e672834 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 7 Sep 2018 10:51:16 -0700 Subject: [PATCH 01/17] still need metadata tests --- .../Transforms/KeyToVectorTransform.cs | 1043 ++++++++++------- .../Transforms/TermTransform.cs | 27 +- .../CategoricalTransform.cs | 4 +- .../Runtime/EntryPoints/FeatureCombiner.cs | 10 +- .../Transformers/KeyToVectorEstimatorTests.cs | 122 ++ 5 files changed, 749 insertions(+), 457 deletions(-) create mode 100644 test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 0f4b616a49..3a20ebcd04 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -17,15 +16,22 @@ using Microsoft.ML.Runtime.Model.Pfa; using Newtonsoft.Json.Linq; -[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(KeyToVectorTransform), typeof(KeyToVectorTransform.Arguments), typeof(SignatureDataTransform), - "Key To Vector Transform", "KeyToVectorTransform", "KeyToVector", "ToVector", DocName = "transform/KeyToVectorTransform.md")] +[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToVectorTransform), typeof(KeyToVectorTransform.Arguments), typeof(SignatureDataTransform), + "Key To Vector Transform", KeyToVectorTransform.UserName, "KeyToVector", "ToVector", DocName = "transform/KeyToVectorTransform.md")] -[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(KeyToVectorTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataView), typeof(KeyToVectorTransform), null, typeof(SignatureLoadDataTransform), "Key To Vector Transform", KeyToVectorTransform.LoaderSignature)] +[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(KeyToVectorTransform), null, typeof(SignatureLoadModel), + KeyToVectorTransform.UserName, KeyToVectorTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(KeyToVectorTransform), null, typeof(SignatureLoadRowMapper), + KeyToVectorTransform.UserName, KeyToVectorTransform.LoaderSignature)] + namespace Microsoft.ML.Runtime.Data { - public sealed class KeyToVectorTransform : OneToOneTransformBase + + public sealed class KeyToVectorTransform : OneToOneTransformerBase { public abstract class ColumnBase : OneToOneColumn { @@ -69,12 +75,6 @@ public bool TryUnparse(StringBuilder sb) return TryUnparseCore(sb); } } - - private static class Defaults - { - public const bool Bag = false; - } - public sealed class Arguments { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] @@ -82,121 +82,101 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.")] - public bool Bag = Defaults.Bag; + public bool Bag = KeyToVectorEstimator.Defaults.Bag; } - internal const string Summary = "Converts a key column to an indicator vector."; - - public const string LoaderSignature = "KeyToVectorTransform"; - private static VersionInfo GetVersionInfo() + public class ColumnInfo { - return new VersionInfo( - modelSignature: "KEY2VECT", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + public readonly string Input; + public readonly string Output; + public readonly bool Bag; + public ColumnInfo(string input, string output, bool bag = KeyToVectorEstimator.Defaults.Bag) + { + Input = input; + Output = output; + Bag = bag; + } + } + internal sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } } private const string RegistrationName = "KeyToVector"; + private readonly bool[] _bags; + private readonly int[] _valueCounts; + private readonly int[] _sizes; - // These arrays are parallel to Infos. - // * _bag indicates whether vector inputs should have their output indicator vectors added - // (instead of concatenated). This is faithful to what the user specified in the Arguments - // and is persisted. - // * _concat is whether, given the current input, there are multiple output instance vectors - // to concatenate. If _bag[i] is true, then _concat[i] will be false. If _bag[i] is false, - // _concat[i] will be true iff the input is a vector with either unknown length or length - // bigger than one. In the other cases (non-vector input and vector of length one), there - // is only one resulting indicator vector so no need to concatenate anything. - // * _types contains the output column types. - // * _slotNamesTypes contains the metadata types for slot name metadata. _slotNamesTypes[i] will - // be null if slot names are not available for the given column (eg, in the variable size case, - // or when the source doesn't have key names). - private readonly bool[] _bag; - private readonly bool[] _concat; - private readonly VectorType[] _types; - - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the input column. If this is null '' will be used. - /// Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector. - public KeyToVectorTransform(IHostEnvironment env, - IDataView input, - string name, - string source = null, - bool bag = Defaults.Bag) - : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, Bag = bag }, input) + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); } - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public KeyToVectorTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, - input, TestIsKey) - { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + //REVIEW: This and static method below need to go to base class as it get created. + private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; - _bag = new bool[Infos.Length]; - _concat = new bool[Infos.Length]; - _types = new VectorType[Infos.Length]; - for (int i = 0; i < Infos.Length; i++) + private ColInfo[] CreateInfos(ISchema schema) + { + Host.AssertValue(schema); + var infos = new ColInfo[ColumnPairs.Length]; + for (int i = 0; i < ColumnPairs.Length; i++) { - var item = args.Column[i]; - _bag[i] = item.Bag ?? args.Bag; - ComputeType(this, Source.Schema, i, Infos[i], _bag[i], Metadata, - out _types[i], out _concat[i]); + if (!schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input); + var type = schema.GetColumnType(colSrc); + string reason = TestIsKey(type); + if (reason != null) + throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason); + infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); } - Metadata.Seal(); + return infos; } - private KeyToVectorTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestIsKey) + private string TestIsKey(ColumnType type) { - Host.AssertValue(ctx); + if (type.ItemType.KeyCount > 0) + return null; + return "Expected Key type of known cardinality"; + } - // *** Binary format *** - // - // - // for each added column - // byte: bag as 0/1 - Host.AssertNonEmpty(Infos); - int size = Infos.Length; - _bag = new bool[size]; - _concat = new bool[Infos.Length]; - _types = new VectorType[size]; - for (int i = 0; i < size; i++) + public KeyToVectorTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : + base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) + { + var infos = CreateInfos(input.Schema); + _bags = new bool[infos.Length]; + _valueCounts = new int[infos.Length]; + _sizes = new int[infos.Length]; + + for (int i = 0; i < infos.Length; i++) { - _bag[i] = ctx.Reader.ReadBoolByte(); - ComputeType(this, Source.Schema, i, Infos[i], _bag[i], Metadata, - out _types[i], out _concat[i]); + _bags[i] = columns[i].Bag; + _sizes[i] = infos[i].TypeSrc.ItemType.KeyCount; + _valueCounts[i] = infos[i].TypeSrc.ValueCount; } - Metadata.Seal(); } + public const string LoaderSignature = "KeyToVectorTransform"; + public const string UserName = "KeyToVectorTransform"; + internal const string Summary = "Converts a key column to an indicator vector."; - public static KeyToVectorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + private static VersionInfo GetVersionInfo() { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - return h.Apply("Loading Model", - ch => - { - // *** Binary format *** - // int: sizeof(Float) - // - int cbFloat = ctx.Reader.ReadInt32(); - ch.CheckDecode(cbFloat == sizeof(Float)); - return new KeyToVectorTransform(h, ctx, input); - }); + return new VersionInfo( + modelSignature: "KEY2VECT", + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Convert to Estimators + verReadableCur: 0x00010002, + verWeCanReadBack: 0x00010002, + loaderSignature: LoaderSignature); } public override void Save(ModelSaveContext ctx) @@ -206,393 +186,614 @@ public override void Save(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) // // for each added column // byte: bag as 0/1 - ctx.Writer.Write(sizeof(Float)); - SaveBase(ctx); + // for each added column + // int: keyCount + // int: valueCount + SaveColumns(ctx); + + Host.Assert(_bags.Length == ColumnPairs.Length); + for (int i = 0; i < _bags.Length; i++) + ctx.Writer.WriteBoolByte(_bags[i]); + Host.Assert(_valueCounts.Length == ColumnPairs.Length); + Host.Assert(_sizes.Length == ColumnPairs.Length); - Host.Assert(_bag.Length == Infos.Length); - for (int i = 0; i < _bag.Length; i++) - ctx.Writer.WriteBoolByte(_bag[i]); + for (int i = 0; i < ColumnPairs.Length; i++) + { + ctx.Writer.Write(_sizes[i]); + ctx.Writer.Write(_valueCounts[i]); + } } - public override bool CanSavePfa => true; - public override bool CanSaveOnnx => true; + // Factory method for SignatureLoadModel. + public static KeyToVectorTransform Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); + + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new KeyToVectorTransform(host, ctx); + } - protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToken srcToken) + private KeyToVectorTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { - Contracts.AssertValue(ctx); - Contracts.Assert(0 <= iinfo && iinfo < Infos.Length); - Contracts.Assert(Infos[iinfo] == info); - Contracts.AssertValue(srcToken); - Contracts.Assert(CanSavePfa); - - int keyCount = info.TypeSrc.ItemType.KeyCount; - Host.Assert(keyCount > 0); - // If the input type is scalar, we can just use the fanout function. - if (!info.TypeSrc.IsVector) - return PfaUtils.Call("cast.fanoutDouble", srcToken, 0, keyCount, false); - - JToken arrType = PfaUtils.Type.Array(PfaUtils.Type.Double); - if (_concat[iinfo]) + var columnsLength = ColumnPairs.Length; + // *** Binary format *** + // + // for each added column + // byte: bag as 0/1 + // for each added column + // int: keyCount + // int: valueCount + _bags = new bool[columnsLength]; + _sizes = new int[columnsLength]; + _valueCounts = new int[columnsLength]; + + _bags = ctx.Reader.ReadBoolArray(columnsLength); + for (int i = 0; i < columnsLength; i++) { - // The concatenation case. We can still use fanout, but we just append them all together. - return PfaUtils.Call("a.flatMap", srcToken, - PfaContext.CreateFuncBlock(new JArray() { PfaUtils.Param("k", PfaUtils.Type.Int) }, - arrType, PfaUtils.Call("cast.fanoutDouble", "k", 0, keyCount, false))); + _sizes[i] = ctx.Reader.ReadInt32(); + _valueCounts[i] = ctx.Reader.ReadInt32(); } + } + + public static IDataView Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => + new KeyToVectorTransform(env, input, columns).MakeDataTransform(input); - // The bag case, while the most useful, is the most elaborate and difficult: we create - // an all-zero array and then add items to it. - const string funcName = "keyToVecUpdate"; - if (!ctx.Pfa.ContainsFunc(funcName)) + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + using (var ch = env.Start("ValidateArgs")) { - var toFunc = PfaContext.CreateFuncBlock( - new JArray() { PfaUtils.Param("v", PfaUtils.Type.Double) }, PfaUtils.Type.Double, - PfaUtils.Call("+", "v", 1)); - - ctx.Pfa.AddFunc(funcName, - new JArray(PfaUtils.Param("a", arrType), PfaUtils.Param("i", PfaUtils.Type.Int)), - arrType, PfaUtils.If(PfaUtils.Call(">=", "i", 0), - PfaUtils.Index("a", "i").AddReturn("to", toFunc), "a")); - } - return PfaUtils.Call("a.fold", srcToken, - PfaUtils.Call("cast.fanoutDouble", -1, 0, keyCount, false), PfaUtils.FuncRef("u." + funcName)); - } + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; - protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) - { - string opType = "OneHotEncoder"; - var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - node.AddAttribute("cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x)); - node.AddAttribute("zeros", true); - return true; + cols[i] = new ColumnInfo(item.Source, + item.Name, + item.Bag ?? args.Bag); + }; + } + return new KeyToVectorTransform(env, input, cols).MakeDataTransform(input); } - // Computes the column type and whether multiple indicator vectors need to be concatenated. - // Also populates the metadata. - private static void ComputeType(KeyToVectorTransform trans, ISchema input, int iinfo, ColInfo info, bool bag, - MetadataDispatcher md, out VectorType type, out bool concat) + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); + + private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa { - Contracts.AssertValue(trans); - Contracts.AssertValue(input); - Contracts.AssertValue(info); - Contracts.Assert(info.TypeSrc.ItemType.IsKey); - Contracts.AssertValue(md); - - int size = info.TypeSrc.ItemType.KeyCount; - Contracts.Assert(size > 0); - - // See if the source has key names. - var typeNames = input.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, info.Source); - if (typeNames == null || !typeNames.IsKnownSizeVector || !typeNames.ItemType.IsText || - typeNames.VectorSize != size) + private readonly KeyToVectorTransform _parent; + private readonly ColInfo[] _infos; + private readonly ColumnType[] _types; + + public Mapper(KeyToVectorTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { - typeNames = null; + _parent = parent; + _infos = _parent.CreateInfos(inputSchema); + _types = new ColumnType[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + ColumnType type; + if (_parent._valueCounts[i] == 1) + type = new VectorType(NumberType.Float, _parent._sizes[i]); + else + type = new VectorType(NumberType.Float, _parent._valueCounts[i], _parent._sizes[i]); + _types[i] = type; + } } - // Don't pass through any source column metadata. - using (var bldr = md.BuildMetadata(iinfo)) + public override RowMapperColumnInfo[] GetOutputColumns() { - if (bag || info.TypeSrc.ValueCount == 1) + var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + Host.Assert(colIndex >= 0); + var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); + AddMetadata(i, colMetaInfo); + + ColumnType type; + if (_parent._valueCounts[i] == 1) + type = new VectorType(NumberType.Float, _parent._sizes[i]); + else + type = new VectorType(NumberType.Float, _parent._valueCounts[i], _parent._sizes[i]); + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], colMetaInfo); + } + return result; + } + + private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) + { + InputSchema.TryGetColumnIndex(_infos[i].Source, out int srcCol); + //IVAN: Simplify + var srcType = _infos[i].TypeSrc; + var typeNames = InputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); + if (typeNames == null || !typeNames.IsKnownSizeVector || !typeNames.ItemType.IsText || + typeNames.VectorSize != _parent._sizes[i]) + { + typeNames = null; + } + if (_parent._bags[i] || _parent._valueCounts[i] == 1) { - // Output is a single vector computed as the sum of the output indicator vectors. - concat = false; - type = new VectorType(NumberType.Float, size); if (typeNames != null) - bldr.AddGetter>(MetadataUtils.Kinds.SlotNames, typeNames, trans.GetKeyNames); + { + MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => + { + Host.Assert(0 <= col && col < _infos.Length); + InputSchema.TryGetColumnIndex(_infos[col].Source, out int sourceColumn); + InputSchema.GetMetadata>(MetadataUtils.Kinds.KeyValues, sourceColumn, ref dst); + }; + var info = new MetadataInfo>(typeNames, getter); + colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); + } } else { - // Output is the concatenation of the multiple output indicator vectors. - concat = true; - type = new VectorType(NumberType.Float, info.TypeSrc.ValueCount, size); + //IVAN:simplify it + var type = new VectorType(NumberType.Float, _parent._valueCounts[i], _parent._sizes[i]); if (typeNames != null && type.VectorSize > 0) { - bldr.AddGetter>(MetadataUtils.Kinds.SlotNames, - new VectorType(TextType.Instance, type), trans.GetSlotNames); + MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => + { + GetSlotNames(col, ref dst); + }; + var info = new MetadataInfo>(new VectorType(TextType.Instance, type), getter); + colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); } } + if (!_parent._bags[i] && _parent._valueCounts[i] > 0) + { + MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => + { + GetCategoricalSlotRanges(col, ref dst); + }; + var info = new MetadataInfo>(MetadataUtils.GetCategoricalType(_parent._valueCounts[i]), getter); + colMetaInfo.Add(MetadataUtils.Kinds.CategoricalSlotRanges, info); + } + if (_parent._bags[i] || _parent._valueCounts[i] == 1) + { + MetadataUtils.MetadataGetter getter = (int col, ref DvBool dst) => + { + dst = true; + }; + var info = new MetadataInfo(BoolType.Instance, getter); + colMetaInfo.Add(MetadataUtils.Kinds.IsNormalized, info); + } + } - if (!bag && info.TypeSrc.ValueCount > 0) + // Combines source key names and slot names to produce final slot names. + private void GetSlotNames(int iinfo, ref VBuffer dst) + { + Host.Assert(0 <= iinfo && iinfo < _infos.Length); + var type = new VectorType(NumberType.Float, _parent._valueCounts[iinfo], _parent._sizes[iinfo]); + Host.Assert(type.IsKnownSizeVector); + + // Size one should have been treated the same as Bag (by the caller). + // Variable size should have thrown (by the caller). + var typeSrc = _infos[iinfo].TypeSrc; + Host.Assert(typeSrc.VectorSize > 1); + + // Get the source slot names, defaulting to empty text. + var namesSlotSrc = default(VBuffer); + InputSchema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + Host.Assert(srcCol >= 0); + var typeSlotSrc = InputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, srcCol); + if (typeSlotSrc != null && typeSlotSrc.VectorSize == typeSrc.VectorSize && typeSlotSrc.ItemType.IsText) { - bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, - MetadataUtils.GetCategoricalType(info.TypeSrc.ValueCount), trans.GetCategoricalSlotRanges); + InputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, srcCol, ref namesSlotSrc); + Host.Check(namesSlotSrc.Length == typeSrc.VectorSize); } + else + namesSlotSrc = VBufferUtils.CreateEmpty(typeSrc.VectorSize); + + int keyCount = typeSrc.ItemType.KeyCount; + int slotLim = type.VectorSize; + Host.Assert(slotLim == (long)typeSrc.VectorSize * keyCount); + + // Get the source key names, in an array (since we will use them multiple times). + var namesKeySrc = default(VBuffer); + InputSchema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref namesKeySrc); + Host.Check(namesKeySrc.Length == keyCount); + var keys = new DvText[keyCount]; + namesKeySrc.CopyTo(keys); + + var values = dst.Values; + if (Utils.Size(values) < slotLim) + values = new DvText[slotLim]; + + var sb = new StringBuilder(); + int slot = 0; + foreach (var kvpSlot in namesSlotSrc.Items(all: true)) + { + Contracts.Assert(slot == (long)kvpSlot.Key * keyCount); + sb.Clear(); + if (kvpSlot.Value.HasChars) + kvpSlot.Value.AddToStringBuilder(sb); + else + sb.Append('[').Append(kvpSlot.Key).Append(']'); + sb.Append('.'); - if (!bag || info.TypeSrc.ValueCount == 1) - bldr.AddPrimitive(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, DvBool.True); + int len = sb.Length; + foreach (var key in keys) + { + sb.Length = len; + key.AddToStringBuilder(sb); + values[slot++] = new DvText(sb.ToString()); + } + } + Host.Assert(slot == slotLim); + + dst = new VBuffer(slotLim, values, dst.Indices); } - } - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < _types.Length); - return _types[iinfo]; - } + private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) + { + Host.Assert(0 <= iinfo && iinfo < _infos.Length); - private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + var info = _infos[iinfo]; - var info = Infos[iinfo]; + Host.Assert(info.TypeSrc.ValueCount > 0); - Host.Assert(info.TypeSrc.ValueCount > 0); + DvInt4[] ranges = new DvInt4[info.TypeSrc.ValueCount * 2]; + int size = info.TypeSrc.ItemType.KeyCount; - DvInt4[] ranges = new DvInt4[info.TypeSrc.ValueCount * 2]; - int size = info.TypeSrc.ItemType.KeyCount; + ranges[0] = 0; + ranges[1] = size - 1; + for (int i = 2; i < ranges.Length; i += 2) + { + ranges[i] = ranges[i - 1] + 1; + ranges[i + 1] = ranges[i] + size - 1; + } + + dst = new VBuffer(ranges.Length, ranges); + } - ranges[0] = 0; - ranges[1] = size - 1; - for (int i = 2; i < ranges.Length; i += 2) + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { - ranges[i] = ranges[i - 1] + 1; - ranges[i + 1] = ranges[i] + size - 1; + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _infos.Length); + disposer = null; + + var info = _infos[iinfo]; + if (!info.TypeSrc.IsVector) + return MakeGetterOne(input, iinfo); + if (_parent._bags[iinfo]) + return MakeGetterBag(input, iinfo); + return MakeGetterInd(input, iinfo); } - dst = new VBuffer(ranges.Length, ranges); - } + /// + /// This is for the singleton case. This should be equivalent to both Bag and Ord over + /// a vector of size one. + /// + private ValueGetter> MakeGetterOne(IRow input, int iinfo) + { + Host.AssertValue(input); + Host.Assert(_infos[iinfo].TypeSrc.IsKey); + Host.Assert(_infos[iinfo].TypeSrc.KeyCount == _parent._valueCounts[iinfo] * _parent._sizes[iinfo]); + + int size = _infos[iinfo].TypeSrc.KeyCount; + Host.Assert(size > 0); + input.Schema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + var getSrc = RowCursorUtils.GetGetterAs(NumberType.U4, input, srcCol); + var src = default(uint); + return + (ref VBuffer dst) => + { + getSrc(ref src); + if (src == 0 || src > size) + { + dst = new VBuffer(size, 0, dst.Values, dst.Indices); + return; + } - // Used for slot names when appropriate. - private void GetKeyNames(int iinfo, ref VBuffer dst) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - Host.Assert(!_concat[iinfo]); + var values = dst.Values; + var indices = dst.Indices; + if (Utils.Size(values) < 1) + values = new float[1]; + if (Utils.Size(indices) < 1) + indices = new int[1]; + values[0] = 1; + indices[0] = (int)src - 1; + + dst = new VBuffer(size, 1, values, indices); + }; + } - // Slot names are just the key value names. - Source.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, Infos[iinfo].Source, ref dst); - } + /// + /// This is for the bagging case - vector input and outputs should be added. + /// + private ValueGetter> MakeGetterBag(IRow input, int iinfo) + { + Host.AssertValue(input); + Host.Assert(_infos[iinfo].TypeSrc.IsVector); + Host.Assert(_infos[iinfo].TypeSrc.ItemType.IsKey); + Host.Assert(_parent._bags[iinfo]); + Host.Assert(_infos[iinfo].TypeSrc.ItemType.KeyCount == _parent._valueCounts[iinfo] * _parent._sizes[iinfo]); + + var info = _infos[iinfo]; + int size = info.TypeSrc.ItemType.KeyCount; + Host.Assert(size > 0); + + int cv = info.TypeSrc.VectorSize; + Host.Assert(cv >= 0); + input.Schema.TryGetColumnIndex(info.Source, out int srcCol); + var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, srcCol); + var src = default(VBuffer); + var bldr = BufferBuilder.CreateDefault(); + return + (ref VBuffer dst) => + { + bldr.Reset(size, false); - // Combines source key names and slot names to produce final slot names. - private void GetSlotNames(int iinfo, ref VBuffer dst) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - Host.Assert(_concat[iinfo]); - Host.Assert(_types[iinfo].IsKnownSizeVector); - - // Size one should have been treated the same as Bag (by the caller). - // Variable size should have thrown (by the caller). - var typeSrc = Infos[iinfo].TypeSrc; - Host.Assert(typeSrc.VectorSize > 1); - - // Get the source slot names, defaulting to empty text. - var namesSlotSrc = default(VBuffer); - var typeSlotSrc = Source.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, Infos[iinfo].Source); - if (typeSlotSrc != null && typeSlotSrc.VectorSize == typeSrc.VectorSize && typeSlotSrc.ItemType.IsText) + getSrc(ref src); + Host.Check(cv == 0 || src.Length == cv); + + // The indices are irrelevant in the bagging case. + var values = src.Values; + int count = src.Count; + for (int slot = 0; slot < count; slot++) + { + uint key = values[slot] - 1; + if (key < size) + bldr.AddFeature((int)key, 1); + } + + bldr.GetResult(ref dst); + }; + } + + /// + /// This is for the indicator (non-bagging) case - vector input and outputs should be concatenated. + /// + private ValueGetter> MakeGetterInd(IRow input, int iinfo) { - Source.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, Infos[iinfo].Source, ref namesSlotSrc); - Host.Check(namesSlotSrc.Length == typeSrc.VectorSize); + Host.AssertValue(input); + Host.Assert(_infos[iinfo].TypeSrc.IsVector); + Host.Assert(_infos[iinfo].TypeSrc.ItemType.IsKey); + Host.Assert(!_parent._bags[iinfo]); + + var info = _infos[iinfo]; + int size = info.TypeSrc.ItemType.KeyCount; + Host.Assert(size > 0); + + int cv = info.TypeSrc.VectorSize; + Host.Assert(cv >= 0); + Host.Assert(_parent._valueCounts[iinfo] * _parent._sizes[iinfo] == size * cv); + input.Schema.TryGetColumnIndex(info.Source, out int srcCol); + var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, srcCol); + var src = default(VBuffer); + return + (ref VBuffer dst) => + { + getSrc(ref src); + int lenSrc = src.Length; + Host.Check(lenSrc == cv || cv == 0); + + // Since we generate values in order, no need for a builder. + var valuesDst = dst.Values; + var indicesDst = dst.Indices; + + int lenDst = checked(size * lenSrc); + int cntSrc = src.Count; + if (Utils.Size(valuesDst) < cntSrc) + valuesDst = new float[cntSrc]; + if (Utils.Size(indicesDst) < cntSrc) + indicesDst = new int[cntSrc]; + + var values = src.Values; + int count = 0; + if (src.IsDense) + { + Host.Assert(lenSrc == cntSrc); + for (int slot = 0; slot < cntSrc; slot++) + { + Host.Assert(count < cntSrc); + uint key = values[slot] - 1; + if (key >= (uint)size) + continue; + valuesDst[count] = 1; + indicesDst[count++] = slot * size + (int)key; + } + } + else + { + var indices = src.Indices; + for (int islot = 0; islot < cntSrc; islot++) + { + Host.Assert(count < cntSrc); + uint key = values[islot] - 1; + if (key >= (uint)size) + continue; + valuesDst[count] = 1; + indicesDst[count++] = indices[islot] * size + (int)key; + } + } + dst = new VBuffer(lenDst, count, valuesDst, indicesDst); + }; } - else - namesSlotSrc = VBufferUtils.CreateEmpty(typeSrc.VectorSize); - - int keyCount = typeSrc.ItemType.KeyCount; - int slotLim = _types[iinfo].VectorSize; - Host.Assert(slotLim == (long)typeSrc.VectorSize * keyCount); - - // Get the source key names, in an array (since we will use them multiple times). - var namesKeySrc = default(VBuffer); - Source.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, Infos[iinfo].Source, ref namesKeySrc); - Host.Check(namesKeySrc.Length == keyCount); - var keys = new DvText[keyCount]; - namesKeySrc.CopyTo(keys); - - var values = dst.Values; - if (Utils.Size(values) < slotLim) - values = new DvText[slotLim]; - - var sb = new StringBuilder(); - int slot = 0; - foreach (var kvpSlot in namesSlotSrc.Items(all: true)) + + public bool CanSaveOnnx => true; + + public bool CanSavePfa => true; + + public void SaveAsOnnx(OnnxContext ctx) { - Contracts.Assert(slot == (long)kvpSlot.Key * keyCount); - sb.Clear(); - if (kvpSlot.Value.HasChars) - kvpSlot.Value.AddToStringBuilder(sb); - else - sb.Append('[').Append(kvpSlot.Key).Append(']'); - sb.Append('.'); + Host.CheckValue(ctx, nameof(ctx)); - int len = sb.Length; - foreach (var key in keys) + for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { - sb.Length = len; - key.AddToStringBuilder(sb); - values[slot++] = new DvText(sb.ToString()); + ColInfo info = _infos[iinfo]; + string sourceColumnName = info.Source; + if (!ctx.ContainsColumn(sourceColumnName)) + { + ctx.RemoveColumn(info.Name, false); + continue; + } + + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), + ctx.AddIntermediateVariable(_types[iinfo], info.Name))) + { + ctx.RemoveColumn(info.Name, true); + } } } - Host.Assert(slot == slotLim); - - dst = new VBuffer(slotLim, values, dst.Indices); - } - - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) - { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; - - var info = Infos[iinfo]; - if (!info.TypeSrc.IsVector) - return MakeGetterOne(input, iinfo); - if (_bag[iinfo]) - return MakeGetterBag(input, iinfo); - return MakeGetterInd(input, iinfo); - } - /// - /// This is for the singleton case. This should be equivalent to both Bag and Ord over - /// a vector of size one. - /// - private ValueGetter> MakeGetterOne(IRow input, int iinfo) - { - Host.AssertValue(input); - Host.Assert(Infos[iinfo].TypeSrc.IsKey); - Host.Assert(Infos[iinfo].TypeSrc.KeyCount == _types[iinfo].VectorSize); + public void SaveAsPfa(BoundPfaContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); - int size = Infos[iinfo].TypeSrc.KeyCount; - Host.Assert(size > 0); + var toHide = new List(); + var toDeclare = new List>(); - var getSrc = RowCursorUtils.GetGetterAs(NumberType.U4, input, Infos[iinfo].Source); - var src = default(uint); - return - (ref VBuffer dst) => + for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { - getSrc(ref src); - if (src == 0 || src > size) + var info = _infos[iinfo]; + var srcName = info.Source; + string srcToken = ctx.TokenOrNullForName(srcName); + if (srcToken == null) { - dst = new VBuffer(size, 0, dst.Values, dst.Indices); - return; + toHide.Add(info.Name); + continue; } + var result = SaveAsPfaCore(ctx, iinfo, info, srcToken); + if (result == null) + { + toHide.Add(info.Name); + continue; + } + toDeclare.Add(new KeyValuePair(info.Name, result)); + } + ctx.Hide(toHide.ToArray()); + ctx.DeclareVar(toDeclare.ToArray()); + } - var values = dst.Values; - var indices = dst.Indices; - if (Utils.Size(values) < 1) - values = new Float[1]; - if (Utils.Size(indices) < 1) - indices = new int[1]; - values[0] = 1; - indices[0] = (int)src - 1; + private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToken srcToken) + { + Contracts.AssertValue(ctx); + Contracts.Assert(0 <= iinfo && iinfo < _infos.Length); + Contracts.Assert(_infos[iinfo] == info); + Contracts.AssertValue(srcToken); + Contracts.Assert(CanSavePfa); + + int keyCount = info.TypeSrc.ItemType.KeyCount; + Host.Assert(keyCount > 0); + // If the input type is scalar, we can just use the fanout function. + if (!info.TypeSrc.IsVector) + return PfaUtils.Call("cast.fanoutDouble", srcToken, 0, keyCount, false); + + JToken arrType = PfaUtils.Type.Array(PfaUtils.Type.Double); + if (_parent._bags[iinfo] || info.TypeSrc.ValueCount == 1) + { + // The concatenation case. We can still use fanout, but we just append them all together. + return PfaUtils.Call("a.flatMap", srcToken, + PfaContext.CreateFuncBlock(new JArray() { PfaUtils.Param("k", PfaUtils.Type.Int) }, + arrType, PfaUtils.Call("cast.fanoutDouble", "k", 0, keyCount, false))); + } + + // The bag case, while the most useful, is the most elaborate and difficult: we create + // an all-zero array and then add items to it. + const string funcName = "keyToVecUpdate"; + if (!ctx.Pfa.ContainsFunc(funcName)) + { + var toFunc = PfaContext.CreateFuncBlock( + new JArray() { PfaUtils.Param("v", PfaUtils.Type.Double) }, PfaUtils.Type.Double, + PfaUtils.Call("+", "v", 1)); + + ctx.Pfa.AddFunc(funcName, + new JArray(PfaUtils.Param("a", arrType), PfaUtils.Param("i", PfaUtils.Type.Int)), + arrType, PfaUtils.If(PfaUtils.Call(">=", "i", 0), + PfaUtils.Index("a", "i").AddReturn("to", toFunc), "a")); + } + + return PfaUtils.Call("a.fold", srcToken, + PfaUtils.Call("cast.fanoutDouble", -1, 0, keyCount, false), PfaUtils.FuncRef("u." + funcName)); + } + + private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + { + string opType = "OneHotEncoder"; + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x)); + node.AddAttribute("zeros", true); + return true; + } - dst = new VBuffer(size, 1, values, indices); - }; } + } - /// - /// This is for the bagging case - vector input and outputs should be added. - /// - private ValueGetter> MakeGetterBag(IRow input, int iinfo) + public sealed class KeyToVectorEstimator : IEstimator + { + private readonly IHost _host; + private readonly KeyToVectorTransform.ColumnInfo[] _columns; + public static class Defaults { - Host.AssertValue(input); - Host.Assert(Infos[iinfo].TypeSrc.IsVector); - Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey); - Host.Assert(_bag[iinfo]); - Host.Assert(Infos[iinfo].TypeSrc.ItemType.KeyCount == _types[iinfo].VectorSize); - - var info = Infos[iinfo]; - int size = info.TypeSrc.ItemType.KeyCount; - Host.Assert(size > 0); - - int cv = info.TypeSrc.VectorSize; - Host.Assert(cv >= 0); - - var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, info.Source); - var src = default(VBuffer); - var bldr = BufferBuilder.CreateDefault(); - return - (ref VBuffer dst) => - { - bldr.Reset(size, false); + public const bool Bag = false; + } - getSrc(ref src); - Host.Check(cv == 0 || src.Length == cv); + public KeyToVectorEstimator(IHostEnvironment env, params KeyToVectorTransform.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(KeyToVectorEstimator)); + _columns = columns; + } - // The indices are irrelevant in the bagging case. - var values = src.Values; - int count = src.Count; - for (int slot = 0; slot < count; slot++) - { - uint key = values[slot] - 1; - if (key < size) - bldr.AddFeature((int)key, 1); - } + public KeyToVectorEstimator(IHostEnvironment env, string name, string source = null, bool bag = Defaults.Bag) : + this(env, new KeyToVectorTransform.ColumnInfo(source ?? name, name, bag)) + { - bldr.GetResult(ref dst); - }; } - /// - /// This is for the indicator (non-bagging) case - vector input and outputs should be concatenated. - /// - private ValueGetter> MakeGetterInd(IRow input, int iinfo) + public SchemaShape GetOutputSchema(SchemaShape inputSchema) { - Host.AssertValue(input); - Host.Assert(Infos[iinfo].TypeSrc.IsVector); - Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey); - Host.Assert(!_bag[iinfo]); - - var info = Infos[iinfo]; - int size = info.TypeSrc.ItemType.KeyCount; - Host.Assert(size > 0); - - int cv = info.TypeSrc.VectorSize; - Host.Assert(cv >= 0); - Host.Assert(_types[iinfo].VectorSize == size * cv); - - var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, info.Source); - var src = default(VBuffer); - return - (ref VBuffer dst) => - { - getSrc(ref src); - int lenSrc = src.Length; - Host.Check(lenSrc == cv || cv == 0); - - // Since we generate values in order, no need for a builder. - var valuesDst = dst.Values; - var indicesDst = dst.Indices; - - int lenDst = checked(size * lenSrc); - int cntSrc = src.Count; - if (Utils.Size(valuesDst) < cntSrc) - valuesDst = new Float[cntSrc]; - if (Utils.Size(indicesDst) < cntSrc) - indicesDst = new int[cntSrc]; - - var values = src.Values; - int count = 0; - if (src.IsDense) - { - Host.Assert(lenSrc == cntSrc); - for (int slot = 0; slot < cntSrc; slot++) - { - Host.Assert(count < cntSrc); - uint key = values[slot] - 1; - if (key >= (uint)size) - continue; - valuesDst[count] = 1; - indicesDst[count++] = slot * size + (int)key; - } - } - else - { - var indices = src.Indices; - for (int islot = 0; islot < cntSrc; islot++) - { - Host.Assert(count < cntSrc); - uint key = values[islot] - 1; - if (key >= (uint)size) - continue; - valuesDst[count] = 1; - indicesDst[count++] = indices[islot] * size + (int)key; - } - } - dst = new VBuffer(lenDst, count, valuesDst, indicesDst); - }; + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + var col = inputSchema.FindColumn(colInfo.Input); + + if (col == null) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + + if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + List metadata = new List(); + + if (col.MetadataKinds.Contains(MetadataUtils.Kinds.KeyValues)) + if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.ItemType.IsText) + metadata.Add(MetadataUtils.Kinds.SlotNames); + if (!colInfo.Bag && (col.Kind == SchemaShape.Column.VectorKind.Scalar || col.Kind == SchemaShape.Column.VectorKind.Vector)) + metadata.Add(MetadataUtils.Kinds.CategoricalSlotRanges); + if (!colInfo.Bag || (col.Kind == SchemaShape.Column.VectorKind.Scalar)) + metadata.Add(MetadataUtils.Kinds.IsNormalized); + + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, metadata.ToArray()); + } + + return new SchemaShape(result.Values); } + + public KeyToVectorTransform Fit(IDataView input) => new KeyToVectorTransform(_host, input, _columns); } + } diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index d124433797..a96aa056a3 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -271,7 +271,7 @@ private TermTransform(IHostEnvironment env, IDataView input, { using (var ch = Host.Start("Training")) { - var infos = CreateInfos(Host, ColumnPairs, input.Schema, TestIsKnownDataKind); + var infos = CreateInfos(input.Schema); _unboundMaps = Train(Host, ch, infos, file, termsColumn, loaderFactory, columns, input); _textMetadata = new bool[_unboundMaps.Length]; for (int iinfo = 0; iinfo < columns.Length; ++iinfo) @@ -403,29 +403,6 @@ public static IDataView Create(IHostEnvironment env, //REVIEW: This and static method below need to go to base class as it get created. private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; - private static ColInfo[] CreateInfos(IHostEnvironment env, (string source, string name)[] columns, ISchema schema, Func testType) - { - env.CheckUserArg(Utils.Size(columns) > 0, nameof(columns)); - env.AssertValue(schema); - env.AssertValueOrNull(testType); - - var infos = new ColInfo[columns.Length]; - for (int i = 0; i < columns.Length; i++) - { - if (!schema.TryGetColumnIndex(columns[i].source, out int colSrc)) - throw env.ExceptUserArg(nameof(columns), "Source column '{0}' not found", columns[i].source); - var type = schema.GetColumnType(colSrc); - if (testType != null) - { - string reason = testType(type); - if (reason != null) - throw env.ExceptUserArg(nameof(columns), InvalidTypeErrorFormat, columns[i].source, type, reason); - } - infos[i] = new ColInfo(columns[i].name, columns[i].source, type); - } - return infos; - } - public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input) { return Create(env, new Arguments() @@ -701,7 +678,7 @@ public override void Save(ModelSaveContext ctx) ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - base.SaveColumns(ctx); + SaveColumns(ctx); Host.Assert(_unboundMaps.Length == _textMetadata.Length); Host.Assert(_textMetadata.Length == ColumnPairs.Length); diff --git a/src/Microsoft.ML.Transforms/CategoricalTransform.cs b/src/Microsoft.ML.Transforms/CategoricalTransform.cs index b19ebc7bfc..95da1e3932 100644 --- a/src/Microsoft.ML.Transforms/CategoricalTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalTransform.cs @@ -177,7 +177,7 @@ public static IDataTransform CreateTransformCore( using (var ch = h.Start("Create Transform Core")) { // Create the KeyToVectorTransform, if needed. - List cols = new List(); + var cols = new List(); bool binaryEncoding = argsOutputKind == OutputKind.Bin; for (int i = 0; i < columns.Length; i++) { @@ -232,7 +232,7 @@ public static IDataTransform CreateTransformCore( Column = cols.ToArray() }; - transform = new KeyToVectorTransform(h, keyToVecArgs, input); + transform =KeyToVectorTransform.Create(h, keyToVecArgs, input); } ch.Done(); diff --git a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs index fb91b03ebb..5ee5353962 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs @@ -110,15 +110,7 @@ private static IDataView ApplyKeyToVec(List ktv, ID TextKeyValues = true }, viewTrain); - viewTrain = new KeyToVectorTransform(host, - new KeyToVectorTransform.Arguments() - { - Column = ktv - .Select(c => new KeyToVectorTransform.Column() { Name = c.Name, Source = c.Name }) - .ToArray(), - Bag = false - }, - viewTrain); + viewTrain = KeyToVectorTransform.Create(host, viewTrain, ktv.Select(c => new KeyToVectorTransform.ColumnInfo(c.Name, c.Name)).ToArray()); } return viewTrain; } diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs new file mode 100644 index 0000000000..af45856e18 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -0,0 +1,122 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class KeyToVectorEstimatorTest : TestDataPipeBase + { + public KeyToVectorEstimatorTest(ITestOutputHelper output) : base(output) + { + } + class TestClass + { + public int A; + public int B; + public int C; + } + class TestMeta + { + [VectorType(2)] + public string[] A; + public string B; + [VectorType(2)] + public int[] C; + public int D; + [VectorType(2)] + public float[] E; + public float F; + } + + [Fact] + public void KeyToVectorWorkout() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + dataView = new TermEstimator(Env, new[]{ + new TermTransform.ColumnInfo("A", "TermA"), + new TermTransform.ColumnInfo("B", "TermB"), + new TermTransform.ColumnInfo("C", "TermC", textKeyValues:true) + }).Fit(dataView).Transform(dataView); + + var pipe = new KeyToVectorEstimator(Env, new KeyToVectorTransform.ColumnInfo("TermA", "CatA", false), + new KeyToVectorTransform.ColumnInfo("TermB", "CatB", true)); + TestEstimatorCore(pipe, dataView); + } + + [Fact] + void TestMetadataCopy() + { + var data = new[] { new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E=new float[2] { 2.0f,4.0f}, F = 1.0f }, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 5,3}, D= 1, E=new float[2] { 4.0f,2.0f}, F = -1.0f }, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E=new float[2] { 2.0f,4.0f}, F = 1.0f } }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var termEst = new TermEstimator(Env, + new TermTransform.ColumnInfo("A", "TA"), + new TermTransform.ColumnInfo("B", "TB"), + new TermTransform.ColumnInfo("C", "TC"), + new TermTransform.ColumnInfo("D", "TD"), + new TermTransform.ColumnInfo("E", "TE"), + new TermTransform.ColumnInfo("F", "TF")); + var termTransformer = termEst.Fit(dataView); + dataView = termTransformer.Transform(dataView); + + var pipe = new KeyToVectorEstimator(Env, + new KeyToVectorTransform.ColumnInfo("TA", "CatA", false), + new KeyToVectorTransform.ColumnInfo("TB", "CatB", false), + new KeyToVectorTransform.ColumnInfo("TC", "CatC", true), + new KeyToVectorTransform.ColumnInfo("TD", "CatD", false), + new KeyToVectorTransform.ColumnInfo("TE", "CatE", false), + new KeyToVectorTransform.ColumnInfo("TF", "CatF", true) + ); + + var result = pipe.Fit(dataView).Transform(dataView); + } + + + [Fact] + void TestCommandLine() + { + using (var env = new TlcEnvironment()) + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B} in=f:\2.txt" }), (int)0); + } + } + + [Fact] + void TestOldSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new TermEstimator(Env, new[]{ + new TermTransform.ColumnInfo("A", "TermA"), + new TermTransform.ColumnInfo("B", "TermB"), + new TermTransform.ColumnInfo("C", "TermC") + }); + var transformer = est.Fit(dataView); + dataView = transformer.Transform(dataView); + var pipe = new KeyToVectorEstimator(Env, new KeyToVectorTransform.ColumnInfo("TermA", "CatA", false), + new KeyToVectorTransform.ColumnInfo("TermB", "CatB", true)); + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + + } + } + } +} From b8ce664f75d23381730c7e5daa06c180095253e3 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 7 Sep 2018 11:22:33 -0700 Subject: [PATCH 02/17] merge with master --- .../Transforms/KeyToVectorTransform.cs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 3a20ebcd04..440b3fc780 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -341,7 +341,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) { Host.Assert(0 <= col && col < _infos.Length); InputSchema.TryGetColumnIndex(_infos[col].Source, out int sourceColumn); - InputSchema.GetMetadata>(MetadataUtils.Kinds.KeyValues, sourceColumn, ref dst); + InputSchema.GetMetadata(MetadataUtils.Kinds.KeyValues, sourceColumn, ref dst); }; var info = new MetadataInfo>(typeNames, getter); colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); @@ -770,24 +770,23 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.Columns.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - var col = inputSchema.FindColumn(colInfo.Input); - - if (col == null) + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - List metadata = new List(); + var metadata = new List(); - if (col.MetadataKinds.Contains(MetadataUtils.Kinds.KeyValues)) + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.ItemType.IsText) - metadata.Add(MetadataUtils.Kinds.SlotNames); + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false)); if (!colInfo.Bag && (col.Kind == SchemaShape.Column.VectorKind.Scalar || col.Kind == SchemaShape.Column.VectorKind.Vector)) - metadata.Add(MetadataUtils.Kinds.CategoricalSlotRanges); + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Scalar, MetadataUtils.GetCategoricalType(1), false)); if (!colInfo.Bag || (col.Kind == SchemaShape.Column.VectorKind.Scalar)) - metadata.Add(MetadataUtils.Kinds.IsNormalized); + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, + BoolType.Instance, false)); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, metadata.ToArray()); + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); } return new SchemaShape(result.Values); From 8eb2691dbf3c9bb2d1107a0d0453a57282d9665e Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 7 Sep 2018 15:14:23 -0700 Subject: [PATCH 03/17] Improve tests --- .../Transforms/KeyToVectorTransform.cs | 18 ++- .../Transforms/TermTransformImpl.cs | 2 - .../Transformers/KeyToVectorEstimatorTests.cs | 112 ++++++++++++++++-- 3 files changed, 106 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 440b3fc780..95d2b003ab 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -339,9 +339,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) { MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => { - Host.Assert(0 <= col && col < _infos.Length); - InputSchema.TryGetColumnIndex(_infos[col].Source, out int sourceColumn); - InputSchema.GetMetadata(MetadataUtils.Kinds.KeyValues, sourceColumn, ref dst); + InputSchema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref dst); }; var info = new MetadataInfo>(typeNames, getter); colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); @@ -355,22 +353,22 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) { MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => { - GetSlotNames(col, ref dst); + GetSlotNames(i, ref dst); }; var info = new MetadataInfo>(new VectorType(TextType.Instance, type), getter); colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); } } - if (!_parent._bags[i] && _parent._valueCounts[i] > 0) + if (!_parent._bags[i] && srcType.ValueCount > 0) { MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => { - GetCategoricalSlotRanges(col, ref dst); + GetCategoricalSlotRanges(i, ref dst); }; var info = new MetadataInfo>(MetadataUtils.GetCategoricalType(_parent._valueCounts[i]), getter); colMetaInfo.Add(MetadataUtils.Kinds.CategoricalSlotRanges, info); } - if (_parent._bags[i] || _parent._valueCounts[i] == 1) + if (!_parent._bags[i] || srcType.ValueCount == 1) { MetadataUtils.MetadataGetter getter = (int col, ref DvBool dst) => { @@ -772,19 +770,17 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - var metadata = new List(); + var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.ItemType.IsText) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false)); if (!colInfo.Bag && (col.Kind == SchemaShape.Column.VectorKind.Scalar || col.Kind == SchemaShape.Column.VectorKind.Vector)) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Scalar, MetadataUtils.GetCategoricalType(1), false)); if (!colInfo.Bag || (col.Kind == SchemaShape.Column.VectorKind.Scalar)) - metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, - BoolType.Instance, false)); + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); } diff --git a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs index d428d06489..2096a41047 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs @@ -1051,7 +1051,6 @@ public override void AddMetadata(ColumnMetadataInfo colMetaInfo) MetadataUtils.MetadataGetter> getter = (int iinfo, ref VBuffer dst) => { - _host.Assert(iinfo == _iinfo); // No buffer sharing convenient here. VBuffer dstT = default(VBuffer); TypedMap.GetTerms(ref dstT); @@ -1066,7 +1065,6 @@ public override void AddMetadata(ColumnMetadataInfo colMetaInfo) MetadataUtils.MetadataGetter> getter = (int iinfo, ref VBuffer dst) => { - _host.Assert(iinfo == _iinfo); TypedMap.GetTerms(ref dst); }; var columnType = new VectorType(TypedMap.ItemType, TypedMap.OutputType.KeyCount); diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index af45856e18..d4d8b8656f 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -8,6 +8,7 @@ using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Tools; using System.IO; +using System.Linq; using Xunit; using Xunit.Abstractions; @@ -35,6 +36,9 @@ class TestMeta [VectorType(2)] public float[] E; public float F; + [VectorType(2)] + public string[] G; + public string H; } [Fact] @@ -55,36 +59,118 @@ public void KeyToVectorWorkout() } [Fact] - void TestMetadataCopy() + void TestMetadataPropagation() { - var data = new[] { new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E=new float[2] { 2.0f,4.0f}, F = 1.0f }, - new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 5,3}, D= 1, E=new float[2] { 4.0f,2.0f}, F = -1.0f }, - new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E=new float[2] { 2.0f,4.0f}, F = 1.0f } }; + var data = new[] { + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E= new float[2] { 1.0f,2.0f}, F = 1.0f , G= new string[2]{ "A","D"}, H="D"}, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 5,3}, D= 1, E=new float[2] { 3.0f,4.0f}, F = -1.0f ,G= new string[2]{"E", "A"}, H="E"}, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E=new float[2] { 5.0f,6.0f}, F = 1.0f ,G= new string[2]{ "D", "E"}, H="D"} }; + var dataView = ComponentCreation.CreateDataView(Env, data); var termEst = new TermEstimator(Env, - new TermTransform.ColumnInfo("A", "TA"), + new TermTransform.ColumnInfo("A", "TA", textKeyValues: true), new TermTransform.ColumnInfo("B", "TB"), - new TermTransform.ColumnInfo("C", "TC"), - new TermTransform.ColumnInfo("D", "TD"), + new TermTransform.ColumnInfo("C", "TC", textKeyValues: true), + new TermTransform.ColumnInfo("D", "TD", textKeyValues: true), new TermTransform.ColumnInfo("E", "TE"), - new TermTransform.ColumnInfo("F", "TF")); + new TermTransform.ColumnInfo("F", "TF"), + new TermTransform.ColumnInfo("G", "TG"), + new TermTransform.ColumnInfo("H", "TH", textKeyValues: true)); var termTransformer = termEst.Fit(dataView); dataView = termTransformer.Transform(dataView); var pipe = new KeyToVectorEstimator(Env, - new KeyToVectorTransform.ColumnInfo("TA", "CatA", false), + new KeyToVectorTransform.ColumnInfo("TA", "CatA", true), new KeyToVectorTransform.ColumnInfo("TB", "CatB", false), - new KeyToVectorTransform.ColumnInfo("TC", "CatC", true), - new KeyToVectorTransform.ColumnInfo("TD", "CatD", false), + new KeyToVectorTransform.ColumnInfo("TC", "CatC", false), + new KeyToVectorTransform.ColumnInfo("TD", "CatD", true), new KeyToVectorTransform.ColumnInfo("TE", "CatE", false), - new KeyToVectorTransform.ColumnInfo("TF", "CatF", true) + new KeyToVectorTransform.ColumnInfo("TF", "CatF", true), + new KeyToVectorTransform.ColumnInfo("TG", "CatG", true), + new KeyToVectorTransform.ColumnInfo("TH", "CatH", false) ); var result = pipe.Fit(dataView).Transform(dataView); + ValidateMetadata(result); } + void ValidateMetadata(IDataView result) + { + Assert.True(result.Schema.TryGetColumnIndex("CatA", out int colA)); + Assert.True(result.Schema.TryGetColumnIndex("CatB", out int colB)); + Assert.True(result.Schema.TryGetColumnIndex("CatC", out int colC)); + Assert.True(result.Schema.TryGetColumnIndex("CatD", out int colD)); + Assert.True(result.Schema.TryGetColumnIndex("CatE", out int colE)); + Assert.True(result.Schema.TryGetColumnIndex("CatF", out int colF)); + Assert.True(result.Schema.TryGetColumnIndex("CatE", out int colG)); + Assert.True(result.Schema.TryGetColumnIndex("CatF", out int colH)); + var types = result.Schema.GetMetadataTypes(colA); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames }); + VBuffer slots = default; + VBuffer slotRanges = default; + DvBool normalized = default; + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colA, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Values.Select(x => x.ToString()), new string[2] { "A", "B" }); + + types = result.Schema.GetMetadataTypes(colB); + Assert.Equal(types.Select(x => x.Key), new string[3] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colB, ref slots); + Assert.True(slots.Length == 1); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[1] { "C" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colB, ref slotRanges); + Assert.True(slotRanges.Length == 2); + Assert.Equal(slotRanges.Items().Select(x => x.Value.RawValue), new int[2] { 0, 0 }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colB, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colC); + Assert.Equal(types.Select(x => x.Key), new string[3] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colC, ref slots); + Assert.True(slots.Length == 4); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[4] { "[0].3", "[0].5", "[1].3", "[1].5" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colC, ref slotRanges); + Assert.True(slotRanges.Length == 4); + Assert.Equal(slotRanges.Items().Select(x => x.Value.RawValue), new int[4] { 0, 1, 2, 3 }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colC, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colD); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colD, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "6", "1" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colD, ref normalized); + Assert.True(normalized.IsTrue); + types = result.Schema.GetMetadataTypes(colE); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colE, ref slotRanges); + Assert.True(slotRanges.Length == 4); + Assert.Equal(slotRanges.Items().Select(x => x.Value.RawValue), new int[4] { 0, 5, 6, 11 }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colE, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colF); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colF, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colG); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colG, ref slotRanges); + Assert.True(slotRanges.Length == 4); + Assert.Equal(slotRanges.Items().Select(x => x.Value.RawValue), new int[4] { 0, 5, 6, 11 }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colF, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colH); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colF, ref normalized); + Assert.True(normalized.IsTrue); + } + [Fact] void TestCommandLine() { @@ -107,7 +193,7 @@ void TestOldSavingAndLoading() var transformer = est.Fit(dataView); dataView = transformer.Transform(dataView); var pipe = new KeyToVectorEstimator(Env, new KeyToVectorTransform.ColumnInfo("TermA", "CatA", false), - new KeyToVectorTransform.ColumnInfo("TermB", "CatB", true)); + new KeyToVectorTransform.ColumnInfo("TermB", "CatB", true)); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); using (var ms = new MemoryStream()) From b17d003777227ae64d83d6b6391b60d09709cbeb Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 7 Sep 2018 17:22:48 -0700 Subject: [PATCH 04/17] cleanup code a bit --- .../Transforms/KeyToVectorTransform.cs | 13 +++++-------- .../CategoricalTransform.cs | 8 ++------ .../Runtime/EntryPoints/FeatureCombiner.cs | 14 +++++++------- .../Transformers/KeyToVectorEstimatorTests.cs | 11 +++++++---- 4 files changed, 21 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 95d2b003ab..3d1b51e320 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -97,6 +97,7 @@ public ColumnInfo(string input, string output, bool bag = KeyToVectorEstimator.D Bag = bag; } } + internal sealed class ColInfo { public readonly string Name; @@ -242,7 +243,7 @@ private KeyToVectorTransform(IHost host, ModelLoadContext ctx) } } - public static IDataView Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => + public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => new KeyToVectorTransform(env, input, columns).MakeDataTransform(input); // Factory method for SignatureDataTransform. @@ -256,7 +257,6 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV var cols = new ColumnInfo[args.Column.Length]; using (var ch = env.Start("ValidateArgs")) { - for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; @@ -325,7 +325,6 @@ public override RowMapperColumnInfo[] GetOutputColumns() private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) { InputSchema.TryGetColumnIndex(_infos[i].Source, out int srcCol); - //IVAN: Simplify var srcType = _infos[i].TypeSrc; var typeNames = InputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); if (typeNames == null || !typeNames.IsKnownSizeVector || !typeNames.ItemType.IsText || @@ -347,7 +346,6 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) } else { - //IVAN:simplify it var type = new VectorType(NumberType.Float, _parent._valueCounts[i], _parent._sizes[i]); if (typeNames != null && type.VectorSize > 0) { @@ -359,6 +357,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); } } + if (!_parent._bags[i] && srcType.ValueCount > 0) { MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => @@ -368,6 +367,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) var info = new MetadataInfo>(MetadataUtils.GetCategoricalType(_parent._valueCounts[i]), getter); colMetaInfo.Add(MetadataUtils.Kinds.CategoricalSlotRanges, info); } + if (!_parent._bags[i] || srcType.ValueCount == 1) { MetadataUtils.MetadataGetter getter = (int col, ref DvBool dst) => @@ -736,7 +736,6 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src node.AddAttribute("zeros", true); return true; } - } } @@ -759,7 +758,6 @@ public KeyToVectorEstimator(IHostEnvironment env, params KeyToVectorTransform.Co public KeyToVectorEstimator(IHostEnvironment env, string name, string source = null, bool bag = Defaults.Bag) : this(env, new KeyToVectorTransform.ColumnInfo(source ?? name, name, bag)) { - } public SchemaShape GetOutputSchema(SchemaShape inputSchema) @@ -778,7 +776,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.ItemType.IsText) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false)); if (!colInfo.Bag && (col.Kind == SchemaShape.Column.VectorKind.Scalar || col.Kind == SchemaShape.Column.VectorKind.Vector)) - metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Scalar, MetadataUtils.GetCategoricalType(1), false)); + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Vector, NumberType.I4, false)); if (!colInfo.Bag || (col.Kind == SchemaShape.Column.VectorKind.Scalar)) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); @@ -790,5 +788,4 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) public KeyToVectorTransform Fit(IDataView input) => new KeyToVectorTransform(_host, input, _columns); } - } diff --git a/src/Microsoft.ML.Transforms/CategoricalTransform.cs b/src/Microsoft.ML.Transforms/CategoricalTransform.cs index 95da1e3932..4fa6b08ab3 100644 --- a/src/Microsoft.ML.Transforms/CategoricalTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalTransform.cs @@ -226,13 +226,9 @@ public static IDataTransform CreateTransformCore( } else { - var keyToVecArgs = new KeyToVectorTransform.Arguments - { - Bag = argsOutputKind == OutputKind.Bag, - Column = cols.ToArray() - }; + var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == OutputKind.Bag)).ToArray(); - transform =KeyToVectorTransform.Create(h, keyToVecArgs, input); + transform = KeyToVectorTransform.Create(h, input, keyToVecCols); } ch.Done(); diff --git a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs index 5ee5353962..dc7089f5df 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs @@ -82,7 +82,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env } } - private static IDataView ApplyKeyToVec(List ktv, IDataView viewTrain, IHost host) + private static IDataView ApplyKeyToVec(List ktv, IDataView viewTrain, IHost host) { Contracts.AssertValueOrNull(ktv); Contracts.AssertValue(viewTrain); @@ -97,7 +97,7 @@ private static IDataView ApplyKeyToVec(List ktv, ID new KeyToValueTransform.Arguments() { Column = ktv - .Select(c => new KeyToValueTransform.Column() { Name = c.Name, Source = c.Source }) + .Select(c => new KeyToValueTransform.Column() { Name = c.Output, Source = c.Input }) .ToArray() }, viewTrain); @@ -105,12 +105,12 @@ private static IDataView ApplyKeyToVec(List ktv, ID new TermTransform.Arguments() { Column = ktv - .Select(c => new TermTransform.Column() { Name = c.Name, Source = c.Name, Terms = GetTerms(viewTrain, c.Source) }) + .Select(c => new TermTransform.Column() { Name = c.Output, Source = c.Output, Terms = GetTerms(viewTrain, c.Input) }) .ToArray(), TextKeyValues = true }, viewTrain); - viewTrain = KeyToVectorTransform.Create(host, viewTrain, ktv.Select(c => new KeyToVectorTransform.ColumnInfo(c.Name, c.Name)).ToArray()); + viewTrain = KeyToVectorTransform.Create(host, viewTrain, ktv.Select(c => new KeyToVectorTransform.ColumnInfo(c.Output, c.Output)).ToArray()); } return viewTrain; } @@ -154,14 +154,14 @@ private static IDataView ApplyConvert(List cvt, IDataVi return viewTrain; } - private static List ConvertFeatures(ColumnInfo[] feats, HashSet featNames, List> concatNames, IChannel ch, + private static List ConvertFeatures(ColumnInfo[] feats, HashSet featNames, List> concatNames, IChannel ch, out List cvt, out int errCount) { Contracts.AssertValue(feats); Contracts.AssertValue(featNames); Contracts.AssertValue(concatNames); Contracts.AssertValue(ch); - List ktv = null; + List ktv = null; cvt = null; errCount = 0; foreach (var col in feats) @@ -179,7 +179,7 @@ private static IDataView ApplyConvert(List cvt, IDataVi { var colName = GetUniqueName(); concatNames.Add(new KeyValuePair(col.Name, colName)); - Utils.Add(ref ktv, new KeyToVectorTransform.Column() { Name = colName, Source = col.Name }); + Utils.Add(ref ktv, new KeyToVectorTransform.ColumnInfo(col.Name, colName)); continue; } } diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index d4d8b8656f..35dc90db5f 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -19,12 +19,14 @@ public class KeyToVectorEstimatorTest : TestDataPipeBase public KeyToVectorEstimatorTest(ITestOutputHelper output) : base(output) { } + class TestClass { public int A; public int B; public int C; } + class TestMeta { [VectorType(2)] @@ -189,11 +191,13 @@ void TestOldSavingAndLoading() new TermTransform.ColumnInfo("A", "TermA"), new TermTransform.ColumnInfo("B", "TermB"), new TermTransform.ColumnInfo("C", "TermC") - }); + }); var transformer = est.Fit(dataView); dataView = transformer.Transform(dataView); - var pipe = new KeyToVectorEstimator(Env, new KeyToVectorTransform.ColumnInfo("TermA", "CatA", false), - new KeyToVectorTransform.ColumnInfo("TermB", "CatB", true)); + var pipe = new KeyToVectorEstimator(Env, + new KeyToVectorTransform.ColumnInfo("TermA", "CatA", false), + new KeyToVectorTransform.ColumnInfo("TermB", "CatB", true) + ); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); using (var ms = new MemoryStream()) @@ -201,7 +205,6 @@ void TestOldSavingAndLoading() TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); ms.Position = 0; var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); - } } } From d541c6483654b6eead430cf622a86370489b72ca Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 10 Sep 2018 10:27:25 -0700 Subject: [PATCH 05/17] address comments --- .../Transforms/KeyToVectorTransform.cs | 82 ++++++++----------- 1 file changed, 35 insertions(+), 47 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 3d1b51e320..75a949b85e 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -113,9 +113,13 @@ public ColInfo(string name, string source, ColumnType type) } private const string RegistrationName = "KeyToVector"; + + /// + /// _bags indicates whether vector inputs should have their output indicator vectors added + /// (instead of concatenated). This is faithful to what the user specified in the Arguments + /// and is persisted. + /// private readonly bool[] _bags; - private readonly int[] _valueCounts; - private readonly int[] _sizes; private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { @@ -155,16 +159,10 @@ public KeyToVectorTransform(IHostEnvironment env, IDataView input, ColumnInfo[] { var infos = CreateInfos(input.Schema); _bags = new bool[infos.Length]; - _valueCounts = new int[infos.Length]; - _sizes = new int[infos.Length]; - for (int i = 0; i < infos.Length; i++) - { _bags[i] = columns[i].Bag; - _sizes[i] = infos[i].TypeSrc.ItemType.KeyCount; - _valueCounts[i] = infos[i].TypeSrc.ValueCount; - } } + public const string LoaderSignature = "KeyToVectorTransform"; public const string UserName = "KeyToVectorTransform"; internal const string Summary = "Converts a key column to an indicator vector."; @@ -173,10 +171,9 @@ private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "KEY2VECT", - //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Convert to Estimators - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010002, + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature); } @@ -187,25 +184,19 @@ public override void Save(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** + // int: sizeof(Float) // // for each added column // byte: bag as 0/1 // for each added column // int: keyCount // int: valueCount + ctx.Writer.Write(sizeof(float)); SaveColumns(ctx); Host.Assert(_bags.Length == ColumnPairs.Length); for (int i = 0; i < _bags.Length; i++) ctx.Writer.WriteBoolByte(_bags[i]); - Host.Assert(_valueCounts.Length == ColumnPairs.Length); - Host.Assert(_sizes.Length == ColumnPairs.Length); - - for (int i = 0; i < ColumnPairs.Length; i++) - { - ctx.Writer.Write(_sizes[i]); - ctx.Writer.Write(_valueCounts[i]); - } } // Factory method for SignatureLoadModel. @@ -219,9 +210,14 @@ public static KeyToVectorTransform Create(IHostEnvironment env, ModelLoadContext return new KeyToVectorTransform(host, ctx); } - + private static ModelLoadContext ReadFloatFromCtx(IHostEnvironment env, ModelLoadContext ctx) + { + int cbFloat = ctx.Reader.ReadInt32(); + env.CheckDecode(cbFloat == sizeof(float)); + return ctx; + } private KeyToVectorTransform(IHost host, ModelLoadContext ctx) - : base(host, ctx) + : base(host, ReadFloatFromCtx(host, ctx)) { var columnsLength = ColumnPairs.Length; // *** Binary format *** @@ -232,15 +228,7 @@ private KeyToVectorTransform(IHost host, ModelLoadContext ctx) // int: keyCount // int: valueCount _bags = new bool[columnsLength]; - _sizes = new int[columnsLength]; - _valueCounts = new int[columnsLength]; - _bags = ctx.Reader.ReadBoolArray(columnsLength); - for (int i = 0; i < columnsLength; i++) - { - _sizes[i] = ctx.Reader.ReadInt32(); - _valueCounts[i] = ctx.Reader.ReadInt32(); - } } public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => @@ -294,10 +282,10 @@ public Mapper(KeyToVectorTransform parent, ISchema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { ColumnType type; - if (_parent._valueCounts[i] == 1) - type = new VectorType(NumberType.Float, _parent._sizes[i]); + if (_infos[i].TypeSrc.ValueCount == 1) + type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ItemType.KeyCount); else - type = new VectorType(NumberType.Float, _parent._valueCounts[i], _parent._sizes[i]); + type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); _types[i] = type; } } @@ -313,10 +301,10 @@ public override RowMapperColumnInfo[] GetOutputColumns() AddMetadata(i, colMetaInfo); ColumnType type; - if (_parent._valueCounts[i] == 1) - type = new VectorType(NumberType.Float, _parent._sizes[i]); + if (_infos[i].TypeSrc.ValueCount == 1) + type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ItemType.KeyCount); else - type = new VectorType(NumberType.Float, _parent._valueCounts[i], _parent._sizes[i]); + type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], colMetaInfo); } return result; @@ -328,11 +316,11 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) var srcType = _infos[i].TypeSrc; var typeNames = InputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); if (typeNames == null || !typeNames.IsKnownSizeVector || !typeNames.ItemType.IsText || - typeNames.VectorSize != _parent._sizes[i]) + typeNames.VectorSize != _infos[i].TypeSrc.ItemType.KeyCount) { typeNames = null; } - if (_parent._bags[i] || _parent._valueCounts[i] == 1) + if (_parent._bags[i] || _infos[i].TypeSrc.ValueCount == 1) { if (typeNames != null) { @@ -346,8 +334,8 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) } else { - var type = new VectorType(NumberType.Float, _parent._valueCounts[i], _parent._sizes[i]); - if (typeNames != null && type.VectorSize > 0) + var type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); + if (typeNames != null && type.IsKnownSizeVector) { MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => { @@ -364,7 +352,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) { GetCategoricalSlotRanges(i, ref dst); }; - var info = new MetadataInfo>(MetadataUtils.GetCategoricalType(_parent._valueCounts[i]), getter); + var info = new MetadataInfo>(MetadataUtils.GetCategoricalType(_infos[i].TypeSrc.ValueCount), getter); colMetaInfo.Add(MetadataUtils.Kinds.CategoricalSlotRanges, info); } @@ -383,7 +371,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) private void GetSlotNames(int iinfo, ref VBuffer dst) { Host.Assert(0 <= iinfo && iinfo < _infos.Length); - var type = new VectorType(NumberType.Float, _parent._valueCounts[iinfo], _parent._sizes[iinfo]); + var type = new VectorType(NumberType.Float, _infos[iinfo].TypeSrc.ValueCount, _infos[iinfo].TypeSrc.ItemType.KeyCount); Host.Assert(type.IsKnownSizeVector); // Size one should have been treated the same as Bag (by the caller). @@ -404,7 +392,7 @@ private void GetSlotNames(int iinfo, ref VBuffer dst) else namesSlotSrc = VBufferUtils.CreateEmpty(typeSrc.VectorSize); - int keyCount = typeSrc.ItemType.KeyCount; + int keyCount = typeSrc.ItemType.ItemType.KeyCount; int slotLim = type.VectorSize; Host.Assert(slotLim == (long)typeSrc.VectorSize * keyCount); @@ -488,7 +476,7 @@ private ValueGetter> MakeGetterOne(IRow input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsKey); - Host.Assert(_infos[iinfo].TypeSrc.KeyCount == _parent._valueCounts[iinfo] * _parent._sizes[iinfo]); + Host.Assert(_infos[iinfo].TypeSrc.KeyCount == _types[iinfo].VectorSize); int size = _infos[iinfo].TypeSrc.KeyCount; Host.Assert(size > 0); @@ -527,7 +515,7 @@ private ValueGetter> MakeGetterBag(IRow input, int iinfo) Host.Assert(_infos[iinfo].TypeSrc.IsVector); Host.Assert(_infos[iinfo].TypeSrc.ItemType.IsKey); Host.Assert(_parent._bags[iinfo]); - Host.Assert(_infos[iinfo].TypeSrc.ItemType.KeyCount == _parent._valueCounts[iinfo] * _parent._sizes[iinfo]); + Host.Assert(_infos[iinfo].TypeSrc.ItemType.KeyCount == _types[iinfo].VectorSize); var info = _infos[iinfo]; int size = info.TypeSrc.ItemType.KeyCount; @@ -577,7 +565,7 @@ private ValueGetter> MakeGetterInd(IRow input, int iinfo) int cv = info.TypeSrc.VectorSize; Host.Assert(cv >= 0); - Host.Assert(_parent._valueCounts[iinfo] * _parent._sizes[iinfo] == size * cv); + Host.Assert(_types[iinfo].VectorSize == size * cv); input.Schema.TryGetColumnIndex(info.Source, out int srcCol); var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, srcCol); var src = default(VBuffer); From 28c8d13c0d61cc5f81388890ee26209f0e1a928c Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 10 Sep 2018 13:13:58 -0700 Subject: [PATCH 06/17] KeyToBinary estimator --- .../Transforms/KeyToVectorTransform.cs | 56 +- .../CategoricalTransform.cs | 5 +- .../KeyToBinaryVectorTransform.cs | 602 +++++++++++------- .../KeyToBinaryVectorEstimatorTest.cs | 152 +++++ .../Transformers/KeyToVectorEstimatorTests.cs | 7 +- 5 files changed, 544 insertions(+), 278 deletions(-) create mode 100644 test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 75a949b85e..2317887d8d 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -90,6 +90,7 @@ public class ColumnInfo public readonly string Input; public readonly string Output; public readonly bool Bag; + public ColumnInfo(string input, string output, bool bag = KeyToVectorEstimator.Defaults.Bag) { Input = input; @@ -127,9 +128,16 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - //REVIEW: This and static method below need to go to base class as it get created. + //REVIEW: This and method below need to go to base class as it get created. private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; + private string TestIsKey(ColumnType type) + { + if (type.ItemType.KeyCount > 0) + return null; + return "Expected Key type of known cardinality"; + } + private ColInfo[] CreateInfos(ISchema schema) { Host.AssertValue(schema); @@ -147,19 +155,13 @@ private ColInfo[] CreateInfos(ISchema schema) return infos; } - private string TestIsKey(ColumnType type) - { - if (type.ItemType.KeyCount > 0) - return null; - return "Expected Key type of known cardinality"; - } - public KeyToVectorTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { - var infos = CreateInfos(input.Schema); - _bags = new bool[infos.Length]; - for (int i = 0; i < infos.Length; i++) + // Validate input schema + CreateInfos(input.Schema); + _bags = new bool[ColumnPairs.Length]; + for (int i = 0; i < ColumnPairs.Length; i++) _bags[i] = columns[i].Bag; } @@ -210,12 +212,14 @@ public static KeyToVectorTransform Create(IHostEnvironment env, ModelLoadContext return new KeyToVectorTransform(host, ctx); } + private static ModelLoadContext ReadFloatFromCtx(IHostEnvironment env, ModelLoadContext ctx) { int cbFloat = ctx.Reader.ReadInt32(); env.CheckDecode(cbFloat == sizeof(float)); return ctx; } + private KeyToVectorTransform(IHost host, ModelLoadContext ctx) : base(host, ReadFloatFromCtx(host, ctx)) { @@ -271,22 +275,20 @@ private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa { private readonly KeyToVectorTransform _parent; private readonly ColInfo[] _infos; - private readonly ColumnType[] _types; + private readonly VectorType[] _types; public Mapper(KeyToVectorTransform parent, ISchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; _infos = _parent.CreateInfos(inputSchema); - _types = new ColumnType[_parent.ColumnPairs.Length]; + _types = new VectorType[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - ColumnType type; if (_infos[i].TypeSrc.ValueCount == 1) - type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ItemType.KeyCount); + _types[i] = new VectorType(NumberType.Float, _infos[i].TypeSrc.ItemType.KeyCount); else - type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); - _types[i] = type; + _types[i] = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); } } @@ -299,12 +301,6 @@ public override RowMapperColumnInfo[] GetOutputColumns() Host.Assert(colIndex >= 0); var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); AddMetadata(i, colMetaInfo); - - ColumnType type; - if (_infos[i].TypeSrc.ValueCount == 1) - type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ItemType.KeyCount); - else - type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], colMetaInfo); } return result; @@ -334,14 +330,13 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) } else { - var type = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); - if (typeNames != null && type.IsKnownSizeVector) + if (typeNames != null && _types[i].IsKnownSizeVector) { MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => { GetSlotNames(i, ref dst); }; - var info = new MetadataInfo>(new VectorType(TextType.Instance, type), getter); + var info = new MetadataInfo>(new VectorType(TextType.Instance, _types[i]), getter); colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); } } @@ -371,8 +366,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) private void GetSlotNames(int iinfo, ref VBuffer dst) { Host.Assert(0 <= iinfo && iinfo < _infos.Length); - var type = new VectorType(NumberType.Float, _infos[iinfo].TypeSrc.ValueCount, _infos[iinfo].TypeSrc.ItemType.KeyCount); - Host.Assert(type.IsKnownSizeVector); + Host.Assert(_types[iinfo].IsKnownSizeVector); // Size one should have been treated the same as Bag (by the caller). // Variable size should have thrown (by the caller). @@ -393,7 +387,7 @@ private void GetSlotNames(int iinfo, ref VBuffer dst) namesSlotSrc = VBufferUtils.CreateEmpty(typeSrc.VectorSize); int keyCount = typeSrc.ItemType.ItemType.KeyCount; - int slotLim = type.VectorSize; + int slotLim = _types[iinfo].VectorSize; Host.Assert(slotLim == (long)typeSrc.VectorSize * keyCount); // Get the source key names, in an array (since we will use them multiple times). @@ -481,6 +475,7 @@ private ValueGetter> MakeGetterOne(IRow input, int iinfo) int size = _infos[iinfo].TypeSrc.KeyCount; Host.Assert(size > 0); input.Schema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + Host.Assert(srcCol >= 0); var getSrc = RowCursorUtils.GetGetterAs(NumberType.U4, input, srcCol); var src = default(uint); return @@ -524,6 +519,7 @@ private ValueGetter> MakeGetterBag(IRow input, int iinfo) int cv = info.TypeSrc.VectorSize; Host.Assert(cv >= 0); input.Schema.TryGetColumnIndex(info.Source, out int srcCol); + Host.Assert(srcCol >= 0); var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, srcCol); var src = default(VBuffer); var bldr = BufferBuilder.CreateDefault(); @@ -761,7 +757,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) - if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.ItemType.IsText) + if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && keyMeta.ItemType.IsText) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false)); if (!colInfo.Bag && (col.Kind == SchemaShape.Column.VectorKind.Scalar || col.Kind == SchemaShape.Column.VectorKind.Vector)) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Vector, NumberType.I4, false)); diff --git a/src/Microsoft.ML.Transforms/CategoricalTransform.cs b/src/Microsoft.ML.Transforms/CategoricalTransform.cs index 4fa6b08ab3..a4bec583c4 100644 --- a/src/Microsoft.ML.Transforms/CategoricalTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalTransform.cs @@ -220,9 +220,8 @@ public static IDataTransform CreateTransformCore( if ((catHashArgs?.InvertHash ?? 0) != 0) ch.Warning("Invert hashing is being used with binary encoding."); - var keyToBinaryArgs = new KeyToBinaryVectorTransform.Arguments(); - keyToBinaryArgs.Column = cols.ToArray(); - transform = new KeyToBinaryVectorTransform(h, keyToBinaryArgs, input); + var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray(); + transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols); } else { diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index 88a4228941..f35ef84e30 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -3,24 +3,31 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -[assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(KeyToBinaryVectorTransform), - typeof(KeyToBinaryVectorTransform.Arguments), typeof(SignatureDataTransform), - "Key To Binary Vector Transform", "KeyToBinaryVectorTransform", "KeyToBinary", - DocName = "transform/KeyToBinaryVectorTransform.md")] +[assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToBinaryVectorTransform), typeof(KeyToBinaryVectorTransform.Arguments), typeof(SignatureDataTransform), + "Key To Binary Vector Transform", KeyToBinaryVectorTransform.UserName, "KeyToBinary", "ToVector", DocName = "transform/KeyToBinaryVectorTransform.md")] -[assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(KeyToBinaryVectorTransform), - null, typeof(SignatureLoadDataTransform), "Key To Binary Vector Transform", KeyToBinaryVectorTransform.LoaderSignature)] +[assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(IDataView), typeof(KeyToBinaryVectorTransform), null, typeof(SignatureLoadDataTransform), + "Key To Binary Vector Transform", KeyToBinaryVectorTransform.LoaderSignature)] + +[assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(KeyToBinaryVectorTransform), null, typeof(SignatureLoadModel), + KeyToBinaryVectorTransform.UserName, KeyToBinaryVectorTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(KeyToBinaryVectorTransform), null, typeof(SignatureLoadRowMapper), + KeyToBinaryVectorTransform.UserName, KeyToBinaryVectorTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.Data { - public sealed class KeyToBinaryVectorTransform : OneToOneTransformBase + public sealed class KeyToBinaryVectorTransform : OneToOneTransformerBase { public sealed class Arguments { @@ -28,9 +35,34 @@ public sealed class Arguments ShortName = "col", SortOrder = 1)] public KeyToVectorTransform.Column[] Column; } + public class ColumnInfo + { + public readonly string Input; + public readonly string Output; - internal const string Summary = "Converts a key column to a binary encoded vector."; + public ColumnInfo(string input, string output) + { + Input = input; + Output = output; + } + } + internal sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } + + internal const string Summary = "Converts a key column to a binary encoded vector."; + public const string UserName = "KeyToBinaryVectorTransform"; public const string LoaderSignature = "KeyToBinaryTransform"; private static VersionInfo GetVersionInfo() { @@ -44,311 +76,395 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "KeyToBinary"; - // These arrays are parallel to Infos. - // * _concat is whether, given the current input, there are multiple output instance vectors - // to concatenate. - // * _types contains the output column types. - private readonly bool[] _concat; - - private readonly int[] _bitsPerKey; - - private readonly VectorType[] _types; - - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. - public KeyToBinaryVectorTransform(IHostEnvironment env, IDataView input, string name, string source = null) - : this(env, new Arguments() { Column = new[] { new KeyToVectorTransform.Column() { Source = source ?? name, Name = name } } }, input) + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); } - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public KeyToBinaryVectorTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, - input, TestIsKey) + //REVIEW: This and method below need to go to base class as it get created. + private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; + + private string TestIsKey(ColumnType type) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + if (type.ItemType.KeyCount > 0) + return null; + return "Expected Key type of known cardinality"; + } - Init(out _concat, out _types, out _bitsPerKey); + private ColInfo[] CreateInfos(ISchema schema) + { + Host.AssertValue(schema); + var infos = new ColInfo[ColumnPairs.Length]; + for (int i = 0; i < ColumnPairs.Length; i++) + { + if (!schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input); + var type = schema.GetColumnType(colSrc); + string reason = TestIsKey(type); + if (reason != null) + throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason); + infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); + } + return infos; + } + + public KeyToBinaryVectorTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) + { + // Validate input schema. + CreateInfos(input.Schema); } - private KeyToBinaryVectorTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestIsKey) + public override void Save(ModelSaveContext ctx) { - Host.AssertValue(ctx); + Host.CheckValue(ctx, nameof(ctx)); // *** Binary format *** // // - Host.AssertNonEmpty(Infos); - - Init(out _concat, out _types, out _bitsPerKey); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + SaveColumns(ctx); } - private void Init(out bool[] concat, out VectorType[] types, out int[] bitsPerKey) + // Factory method for SignatureLoadModel. + public static KeyToBinaryVectorTransform Create(IHostEnvironment env, ModelLoadContext ctx) { - concat = new bool[Infos.Length]; - types = new VectorType[Infos.Length]; - bitsPerKey = new int[Infos.Length]; + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); - for (int i = 0; i < Infos.Length; i++) - ComputeType(this, Source.Schema, i, Infos[i], Metadata, - out _types[i], out _concat[i], out _bitsPerKey[i]); + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); - Metadata.Seal(); + return new KeyToBinaryVectorTransform(host, ctx); } - public static KeyToBinaryVectorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + private KeyToBinaryVectorTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - return h.Apply("Loading Model", ch => new KeyToBinaryVectorTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => + new KeyToBinaryVectorTransform(env, input, columns).MakeDataTransform(input); + + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); - // *** Binary format *** - // - // - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - SaveBase(ctx); + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + using (var ch = env.Start("ValidateArgs")) + { + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; + cols[i] = new ColumnInfo(item.Source, item.Name); + }; + } + return new KeyToBinaryVectorTransform(env, input, cols).MakeDataTransform(input); } - /// - /// Computes the column type and whether multiple indicator vectors need to be concatenated. - /// Also populates the metadata. - /// - private static void ComputeType(KeyToBinaryVectorTransform trans, ISchema input, int iinfo, - ColInfo info, MetadataDispatcher md, out VectorType type, out bool concat, out int bitsPerColumn) - { - Contracts.AssertValue(trans); - Contracts.AssertValue(input); - Contracts.AssertValue(info); - Contracts.Assert(info.TypeSrc.ItemType.IsKey); - Contracts.Assert(info.TypeSrc.ItemType.KeyCount > 0); + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); - //Add an additional bit for all 1s to represent missing values. - bitsPerColumn = Utils.IbitHigh((uint)info.TypeSrc.ItemType.KeyCount) + 2; + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - Contracts.Assert(bitsPerColumn > 0); + protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); - // See if the source has key names. - var typeNames = input.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, info.Source); - if (typeNames == null || !typeNames.IsKnownSizeVector || !typeNames.ItemType.IsText || - typeNames.VectorSize != info.TypeSrc.ItemType.KeyCount) + private sealed class Mapper : MapperBase + { + private readonly KeyToBinaryVectorTransform _parent; + private readonly ColInfo[] _infos; + private readonly VectorType[] _types; + private readonly int[] _bitsPerKey; + + public Mapper(KeyToBinaryVectorTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _infos = _parent.CreateInfos(inputSchema); + _types = new VectorType[_parent.ColumnPairs.Length]; + _bitsPerKey = new int[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + //Add an additional bit for all 1s to represent missing values. + _bitsPerKey[i] = Utils.IbitHigh((uint)_infos[i].TypeSrc.ItemType.KeyCount) + 2; + Host.Assert(_bitsPerKey[i] > 0); + if (_infos[i].TypeSrc.ValueCount == 1) + // Output is a single vector computed as the sum of the output indicator vectors. + _types[i] = new VectorType(NumberType.Float, _bitsPerKey[i]); + else + // Output is the concatenation of the multiple output indicator vectors. + _types[i] = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _bitsPerKey[i]); + } + } + + public override RowMapperColumnInfo[] GetOutputColumns() { - typeNames = null; + var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + Host.Assert(colIndex >= 0); + var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); + AddMetadata(i, colMetaInfo); + + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], colMetaInfo); + } + return result; } - // Don't pass through any source column metadata. - using (var bldr = md.BuildMetadata(iinfo)) + private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) { - if (info.TypeSrc.ValueCount == 1) + InputSchema.TryGetColumnIndex(_infos[i].Source, out int srcCol); + var srcType = _infos[i].TypeSrc; + // See if the source has key names. + var typeNames = InputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); + if (typeNames == null || !typeNames.IsKnownSizeVector || !typeNames.ItemType.IsText || + typeNames.VectorSize != _infos[i].TypeSrc.ItemType.KeyCount) + { + typeNames = null; + } + + if (_infos[i].TypeSrc.ValueCount == 1) { - // Output is a single vector computed as the sum of the output indicator vectors. - concat = false; - type = new VectorType(NumberType.Float, bitsPerColumn); if (typeNames != null) { - bldr.AddGetter>(MetadataUtils.Kinds.SlotNames, - new VectorType(TextType.Instance, type), trans.GetKeyNames); + MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => + { + GenerateBitSlotName(i, ref dst); + }; + var info = new MetadataInfo>(new VectorType(TextType.Instance, _types[i]), getter); + colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); } - - bldr.AddPrimitive(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, DvBool.True); + MetadataUtils.MetadataGetter normalizeGetter = (int col, ref DvBool dst) => + { + dst = true; + }; + var normalizeInfo = new MetadataInfo(BoolType.Instance, normalizeGetter); + colMetaInfo.Add(MetadataUtils.Kinds.IsNormalized, normalizeInfo); } else { - // Output is the concatenation of the multiple output indicator vectors. - concat = true; - type = new VectorType(NumberType.Float, info.TypeSrc.ValueCount, bitsPerColumn); - if (typeNames != null && type.VectorSize > 0) + if (typeNames != null && _types[i].IsKnownSizeVector) { - bldr.AddGetter>(MetadataUtils.Kinds.SlotNames, - new VectorType(TextType.Instance, type), trans.GetSlotNames); + MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => + { + GetSlotNames(i, ref dst); + }; + var info = new MetadataInfo>(new VectorType(TextType.Instance, _types[i]), getter); + colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info); } } } - } - - private void GenerateBitSlotName(int iinfo, ref VBuffer dst) - { - const string slotNamePrefix = "Bit"; - var bldr = new BufferBuilder(TextCombiner.Instance); - bldr.Reset(_bitsPerKey[iinfo], true); - for (int i = 0; i < _bitsPerKey[iinfo]; i++) - bldr.AddFeature(i, new DvText(slotNamePrefix + (_bitsPerKey[iinfo] - i - 1))); - - bldr.GetResult(ref dst); - } - - private void GetKeyNames(int iinfo, ref VBuffer dst) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - Host.Assert(!_concat[iinfo]); - - GenerateBitSlotName(iinfo, ref dst); - } - private void GetSlotNames(int iinfo, ref VBuffer dst) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - Host.Assert(_concat[iinfo]); - Host.Assert(_types[iinfo].IsKnownSizeVector); - - // Variable size should have thrown (by the caller). - var typeSrc = Infos[iinfo].TypeSrc; - Host.Assert(typeSrc.VectorSize > 1); - - // Get the source slot names, defaulting to empty text. - var namesSlotSrc = default(VBuffer); - var typeSlotSrc = Source.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, Infos[iinfo].Source); - if (typeSlotSrc != null && typeSlotSrc.VectorSize == typeSrc.VectorSize && typeSlotSrc.ItemType.IsText) + private void GenerateBitSlotName(int iinfo, ref VBuffer dst) { - Source.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, Infos[iinfo].Source, ref namesSlotSrc); - Host.Check(namesSlotSrc.Length == typeSrc.VectorSize); - } - else - namesSlotSrc = VBufferUtils.CreateEmpty(typeSrc.VectorSize); - - int slotLim = _types[iinfo].VectorSize; - Host.Assert(slotLim == (long)typeSrc.VectorSize * _bitsPerKey[iinfo]); + const string slotNamePrefix = "Bit"; + var bldr = new BufferBuilder(TextCombiner.Instance); + bldr.Reset(_bitsPerKey[iinfo], true); + for (int i = 0; i < _bitsPerKey[iinfo]; i++) + bldr.AddFeature(i, new DvText(slotNamePrefix + (_bitsPerKey[iinfo] - i - 1))); - var values = dst.Values; - if (Utils.Size(values) < slotLim) - values = new DvText[slotLim]; + bldr.GetResult(ref dst); + } - var sb = new StringBuilder(); - int slot = 0; - VBuffer bits = default(VBuffer); - GenerateBitSlotName(iinfo, ref bits); - foreach (var kvpSlot in namesSlotSrc.Items(all: true)) + private void GetSlotNames(int iinfo, ref VBuffer dst) { - Contracts.Assert(slot == (long)kvpSlot.Key * _bitsPerKey[iinfo]); - sb.Clear(); - if (kvpSlot.Value.HasChars) - kvpSlot.Value.AddToStringBuilder(sb); + Host.Assert(0 <= iinfo && iinfo < _infos.Length); + Host.Assert(_types[iinfo].IsKnownSizeVector); + + // Variable size should have thrown (by the caller). + var typeSrc = _infos[iinfo].TypeSrc; + Host.Assert(typeSrc.VectorSize > 1); + + // Get the source slot names, defaulting to empty text. + var namesSlotSrc = default(VBuffer); + InputSchema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + Host.Assert(srcCol >= 0); + var typeSlotSrc = InputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, srcCol); + if (typeSlotSrc != null && typeSlotSrc.VectorSize == typeSrc.VectorSize && typeSlotSrc.ItemType.IsText) + { + InputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, srcCol, ref namesSlotSrc); + Host.Check(namesSlotSrc.Length == typeSrc.VectorSize); + } else - sb.Append('[').Append(kvpSlot.Key).Append(']'); - sb.Append('.'); + namesSlotSrc = VBufferUtils.CreateEmpty(typeSrc.VectorSize); + + int slotLim = _types[iinfo].VectorSize; + Host.Assert(slotLim == (long)typeSrc.VectorSize * _bitsPerKey[iinfo]); + + var values = dst.Values; + if (Utils.Size(values) < slotLim) + values = new DvText[slotLim]; - int len = sb.Length; - foreach (var key in bits.Values) + var sb = new StringBuilder(); + int slot = 0; + VBuffer bits = default; + GenerateBitSlotName(iinfo, ref bits); + foreach (var kvpSlot in namesSlotSrc.Items(all: true)) { - sb.Length = len; - key.AddToStringBuilder(sb); - values[slot++] = new DvText(sb.ToString()); + Contracts.Assert(slot == (long)kvpSlot.Key * _bitsPerKey[iinfo]); + sb.Clear(); + if (kvpSlot.Value.HasChars) + kvpSlot.Value.AddToStringBuilder(sb); + else + sb.Append('[').Append(kvpSlot.Key).Append(']'); + sb.Append('.'); + + int len = sb.Length; + foreach (var key in bits.Values) + { + sb.Length = len; + key.AddToStringBuilder(sb); + values[slot++] = new DvText(sb.ToString()); + } } + Host.Assert(slot == slotLim); + + dst = new VBuffer(slotLim, values, dst.Indices); } - Host.Assert(slot == slotLim); - dst = new VBuffer(slotLim, values, dst.Indices); - } + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _infos.Length); + disposer = null; + + var info = _infos[iinfo]; + if (!info.TypeSrc.IsVector) + return MakeGetterOne(input, iinfo); + return MakeGetterInd(input, iinfo); + } - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < _types.Length); - return _types[iinfo]; - } + /// + /// This is for the scalar case. + /// + private ValueGetter> MakeGetterOne(IRow input, int iinfo) + { + Host.AssertValue(input); + Host.Assert(_infos[iinfo].TypeSrc.IsKey); + + int bitsPerKey = _bitsPerKey[iinfo]; + Host.Assert(bitsPerKey == _types[iinfo].VectorSize); + + int dstLength = _types[iinfo].VectorSize; + Host.Assert(dstLength > 0); + input.Schema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + Host.Assert(srcCol >= 0); + var getSrc = RowCursorUtils.GetGetterAs(NumberType.U4, input, srcCol); + var src = default(uint); + var bldr = new BufferBuilder(R4Adder.Instance); + return + (ref VBuffer dst) => + { + getSrc(ref src); + bldr.Reset(bitsPerKey, false); + EncodeValueToBinary(bldr, src, bitsPerKey, 0); + bldr.GetResult(ref dst); - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) - { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; - - var info = Infos[iinfo]; - if (!info.TypeSrc.IsVector) - return MakeGetterOne(input, iinfo); - return MakeGetterInd(input, iinfo); - } + Contracts.Assert(dst.Length == bitsPerKey); + }; + } - /// - /// This is for the scalar case. - /// - private ValueGetter> MakeGetterOne(IRow input, int iinfo) - { - Host.AssertValue(input); - Host.Assert(Infos[iinfo].TypeSrc.IsKey); + /// + /// This is for the indicator case - vector input and outputs should be concatenated. + /// + private ValueGetter> MakeGetterInd(IRow input, int iinfo) + { + Host.AssertValue(input); + Host.Assert(_infos[iinfo].TypeSrc.IsVector); + Host.Assert(_infos[iinfo].TypeSrc.ItemType.IsKey); + + int cv = _infos[iinfo].TypeSrc.VectorSize; + Host.Assert(cv >= 0); + input.Schema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + Host.Assert(srcCol >= 0); + var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, srcCol); + var src = default(VBuffer); + var bldr = new BufferBuilder(R4Adder.Instance); + int bitsPerKey = _bitsPerKey[iinfo]; + return + (ref VBuffer dst) => + { + getSrc(ref src); + Host.Check(src.Length == cv || cv == 0); + bldr.Reset(src.Length * bitsPerKey, false); - int bitsPerKey = _bitsPerKey[iinfo]; - Host.Assert(bitsPerKey == _types[iinfo].VectorSize); + int index = 0; + foreach (uint value in src.DenseValues()) + { + EncodeValueToBinary(bldr, value, bitsPerKey, index * bitsPerKey); + index++; + } - int dstLength = _types[iinfo].VectorSize; - Host.Assert(dstLength > 0); + bldr.GetResult(ref dst); - var getSrc = RowCursorUtils.GetGetterAs(NumberType.U4, input, Infos[iinfo].Source); - var src = default(uint); - var bldr = new BufferBuilder(R4Adder.Instance); - return - (ref VBuffer dst) => - { - getSrc(ref src); - bldr.Reset(bitsPerKey, false); - EncodeValueToBinary(bldr, src, bitsPerKey, 0); - bldr.GetResult(ref dst); + Contracts.Assert(dst.Length == src.Length * bitsPerKey); + }; + } - Contracts.Assert(dst.Length == bitsPerKey); - }; - } + private void EncodeValueToBinary(BufferBuilder bldr, uint value, int bitsToConsider, int startIndex) + { + Contracts.Assert(0 < bitsToConsider && bitsToConsider <= sizeof(uint) * 8); + Contracts.Assert(startIndex >= 0); - /// - /// This is for the indicator case - vector input and outputs should be concatenated. - /// - private ValueGetter> MakeGetterInd(IRow input, int iinfo) - { - Host.AssertValue(input); - Host.Assert(Infos[iinfo].TypeSrc.IsVector); - Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey); - - int cv = Infos[iinfo].TypeSrc.VectorSize; - Host.Assert(cv >= 0); - - var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, Infos[iinfo].Source); - var src = default(VBuffer); - var bldr = new BufferBuilder(R4Adder.Instance); - int bitsPerKey = _bitsPerKey[iinfo]; - return - (ref VBuffer dst) => - { - getSrc(ref src); - Host.Check(src.Length == cv || cv == 0); - bldr.Reset(src.Length * bitsPerKey, false); + //Treat missing values, zero, as a special value of all 1s. + value--; + while (bitsToConsider > 0) + bldr.AddFeature(startIndex++, (value >> --bitsToConsider) & 1U); + } + } + } - int index = 0; - foreach (uint value in src.DenseValues()) - { - EncodeValueToBinary(bldr, value, bitsPerKey, index * bitsPerKey); - index++; - } + public sealed class KeyToBinaryVectorEstimator : IEstimator + { + private readonly IHost _host; + private readonly KeyToBinaryVectorTransform.ColumnInfo[] _columns; - bldr.GetResult(ref dst); + public KeyToBinaryVectorEstimator(IHostEnvironment env, params KeyToBinaryVectorTransform.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(KeyToBinaryVectorEstimator)); + _columns = columns; + } - Contracts.Assert(dst.Length == src.Length * bitsPerKey); - }; + public KeyToBinaryVectorEstimator(IHostEnvironment env, string name, string source = null) : + this(env, new KeyToBinaryVectorTransform.ColumnInfo(source ?? name, name)) + { } - private void EncodeValueToBinary(BufferBuilder bldr, uint value, int bitsToConsider, int startIndex) + public SchemaShape GetOutputSchema(SchemaShape inputSchema) { - Contracts.Assert(0 < bitsToConsider && bitsToConsider <= sizeof(uint) * 8); - Contracts.Assert(startIndex >= 0); + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + + var metadata = new List(); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) + if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && keyMeta.ItemType.IsText) + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false)); + if (col.Kind == SchemaShape.Column.VectorKind.Scalar) + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); + } - //Treat missing values, zero, as a special value of all 1s. - value--; - while (bitsToConsider > 0) - bldr.AddFeature(startIndex++, (value >> --bitsToConsider) & 1U); + return new SchemaShape(result.Values); } + + public KeyToBinaryVectorTransform Fit(IDataView input) => new KeyToBinaryVectorTransform(_host, input, _columns); } + } diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs new file mode 100644 index 0000000000..ebbd19b204 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -0,0 +1,152 @@ +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System.IO; +using System.Linq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class KeyToBinaryVectorEstimatorTest : TestDataPipeBase + { + public KeyToBinaryVectorEstimatorTest(ITestOutputHelper output) : base(output) + { + } + class TestClass + { + public int A; + public int B; + public int C; + } + class TestMeta + { + [VectorType(2)] + public string[] A; + public string B; + [VectorType(2)] + public int[] C; + public int D; + [VectorType(2)] + public float[] E; + public float F; + [VectorType(2)] + public string[] G; + public string H; + } + + [Fact] + public void KeyToVectorWorkout() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + dataView = new TermEstimator(Env, new[]{ + new TermTransform.ColumnInfo("A", "TermA"), + new TermTransform.ColumnInfo("B", "TermB"), + new TermTransform.ColumnInfo("C", "TermC", textKeyValues:true) + }).Fit(dataView).Transform(dataView); + + var pipe = new KeyToBinaryVectorEstimator(Env, new KeyToBinaryVectorTransform.ColumnInfo("TermA", "CatA"), + new KeyToBinaryVectorTransform.ColumnInfo("TermC", "CatC")); + TestEstimatorCore(pipe, dataView); + } + + [Fact] + void TestMetadataPropagation() + { + var data = new[] { + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E= new float[2] { 1.0f,2.0f}, F = 1.0f , G= new string[2]{ "A","D"}, H="D"}, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 5,3}, D= 1, E=new float[2] { 3.0f,4.0f}, F = -1.0f ,G= new string[2]{"E", "A"}, H="E"}, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E=new float[2] { 5.0f,6.0f}, F = 1.0f ,G= new string[2]{ "D", "E"}, H="D"} }; + + + var dataView = ComponentCreation.CreateDataView(Env, data); + var termEst = new TermEstimator(Env, + new TermTransform.ColumnInfo("A", "TA", textKeyValues: true), + new TermTransform.ColumnInfo("B", "TB", textKeyValues: true), + new TermTransform.ColumnInfo("C", "TC"), + new TermTransform.ColumnInfo("D", "TD")); + var termTransformer = termEst.Fit(dataView); + dataView = termTransformer.Transform(dataView); + + var pipe = new KeyToBinaryVectorEstimator(Env, + new KeyToBinaryVectorTransform.ColumnInfo("TA", "CatA"), + new KeyToBinaryVectorTransform.ColumnInfo("TB", "CatB"), + new KeyToBinaryVectorTransform.ColumnInfo("TC", "CatC"), + new KeyToBinaryVectorTransform.ColumnInfo("TD", "CatD")); + + var result = pipe.Fit(dataView).Transform(dataView); + ValidateMetadata(result); + } + + void ValidateMetadata(IDataView result) + { + Assert.True(result.Schema.TryGetColumnIndex("CatA", out int colA)); + Assert.True(result.Schema.TryGetColumnIndex("CatB", out int colB)); + Assert.True(result.Schema.TryGetColumnIndex("CatC", out int colC)); + Assert.True(result.Schema.TryGetColumnIndex("CatD", out int colD)); + var types = result.Schema.GetMetadataTypes(colA); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames }); + VBuffer slots = default; + DvBool normalized = default; + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colA, ref slots); + Assert.True(slots.Length == 6); + Assert.Equal(slots.Values.Select(x => x.ToString()), new string[6] { "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit2", "[1].Bit1", "[1].Bit0" }); + + types = result.Schema.GetMetadataTypes(colB); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colB, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "Bit1", "Bit0" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colB, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colC); + Assert.Equal(types.Select(x => x.Key), new string[0]); + + types = result.Schema.GetMetadataTypes(colD); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colD, ref normalized); + Assert.True(normalized.IsTrue); + } + + [Fact] + void TestCommandLine() + { + using (var env = new TlcEnvironment()) + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" }), (int)0); + } + } + + [Fact] + void TestOldSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new TermEstimator(Env, new[]{ + new TermTransform.ColumnInfo("A", "TermA"), + new TermTransform.ColumnInfo("B", "TermB", textKeyValues:true), + new TermTransform.ColumnInfo("C", "TermC") + }); + var transformer = est.Fit(dataView); + dataView = transformer.Transform(dataView); + var pipe = new KeyToBinaryVectorEstimator(Env, + new KeyToBinaryVectorTransform.ColumnInfo("TermA", "CatA"), + new KeyToBinaryVectorTransform.ColumnInfo("TermB", "CatB"), + new KeyToBinaryVectorTransform.ColumnInfo("TermC", "CatC") + ); + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index 35dc90db5f..b458ee645b 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -56,7 +56,9 @@ public void KeyToVectorWorkout() }).Fit(dataView).Transform(dataView); var pipe = new KeyToVectorEstimator(Env, new KeyToVectorTransform.ColumnInfo("TermA", "CatA", false), - new KeyToVectorTransform.ColumnInfo("TermB", "CatB", true)); + new KeyToVectorTransform.ColumnInfo("TermB", "CatB", true), + new KeyToVectorTransform.ColumnInfo("TermC", "CatC", true), + new KeyToVectorTransform.ColumnInfo("TermC", "CatCNonBag", false)); TestEstimatorCore(pipe, dataView); } @@ -96,6 +98,7 @@ void TestMetadataPropagation() var result = pipe.Fit(dataView).Transform(dataView); ValidateMetadata(result); } + void ValidateMetadata(IDataView result) { Assert.True(result.Schema.TryGetColumnIndex("CatA", out int colA)); @@ -178,7 +181,7 @@ void TestCommandLine() { using (var env = new TlcEnvironment()) { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B} in=f:\2.txt" }), (int)0); + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" }), (int)0); } } From 29ec05590c7a3a7743c49ee6a4fce34d671b8d01 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 10 Sep 2018 13:29:03 -0700 Subject: [PATCH 07/17] address some comments --- .../Transforms/KeyToVectorTransform.cs | 13 ++------- .../Transforms/TermTransform.cs | 16 +++++------ .../KeyToBinaryVectorTransform.cs | 15 +++++------ .../KeyToBinaryVectorEstimatorTest.cs | 27 +++++++++---------- .../Transformers/KeyToVectorEstimatorTests.cs | 13 ++++----- 5 files changed, 35 insertions(+), 49 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 2317887d8d..2f23e525e0 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -128,14 +128,11 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - //REVIEW: This and method below need to go to base class as it get created. - private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; - private string TestIsKey(ColumnType type) { if (type.ItemType.KeyCount > 0) return null; - return "Expected Key type of known cardinality"; + return "key type of known cardinality"; } private ColInfo[] CreateInfos(ISchema schema) @@ -149,7 +146,7 @@ private ColInfo[] CreateInfos(ISchema schema) var type = schema.GetColumnType(colSrc); string reason = TestIsKey(type); if (reason != null) - throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason); + throw Host.ExceptSchemaMismatch(nameof(ColumnPairs), "input", ColumnPairs[i].input, reason, type.ToString()); infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); } return infos; @@ -190,9 +187,6 @@ public override void Save(ModelSaveContext ctx) // // for each added column // byte: bag as 0/1 - // for each added column - // int: keyCount - // int: valueCount ctx.Writer.Write(sizeof(float)); SaveColumns(ctx); @@ -228,9 +222,6 @@ private KeyToVectorTransform(IHost host, ModelLoadContext ctx) // // for each added column // byte: bag as 0/1 - // for each added column - // int: keyCount - // int: valueCount _bags = new bool[columnsLength]; _bags = ctx.Reader.ReadBoolArray(columnsLength); } diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index 503afb738d..b800bb3317 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -241,6 +241,13 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } + internal static string TestIsKnownDataKind(ColumnType type) + { + if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive)) + return null; + return "standard type or a vector of standard type"; + } + private ColInfo[] CreateInfos(ISchema schema) { Host.AssertValue(schema); @@ -252,7 +259,7 @@ private ColInfo[] CreateInfos(ISchema schema) var type = schema.GetColumnType(colSrc); string reason = TestIsKnownDataKind(type); if (reason != null) - throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason); + throw Host.ExceptSchemaMismatch(nameof(ColumnPairs), "input", ColumnPairs[i].input, reason, type.ToString()); infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); } return infos; @@ -429,13 +436,6 @@ public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, Co }, input); } - internal static string TestIsKnownDataKind(ColumnType type) - { - if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive)) - return null; - return "Expected standard type or a vector of standard type"; - } - /// /// Utility method to create the file-based . /// diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index f35ef84e30..21d63e76d3 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -82,28 +82,25 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - //REVIEW: This and method below need to go to base class as it get created. - private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; - private string TestIsKey(ColumnType type) { if (type.ItemType.KeyCount > 0) return null; - return "Expected Key type of known cardinality"; + return "key type of known cardinality"; } - private ColInfo[] CreateInfos(ISchema schema) + private ColInfo[] CreateInfos(ISchema inputSchema) { - Host.AssertValue(schema); + Host.AssertValue(inputSchema); var infos = new ColInfo[ColumnPairs.Length]; for (int i = 0; i < ColumnPairs.Length; i++) { - if (!schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) + if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input); - var type = schema.GetColumnType(colSrc); + var type = inputSchema.GetColumnType(colSrc); string reason = TestIsKey(type); if (reason != null) - throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString()); infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); } return infos; diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs index ebbd19b204..35338cf263 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -15,13 +15,15 @@ public class KeyToBinaryVectorEstimatorTest : TestDataPipeBase public KeyToBinaryVectorEstimatorTest(ITestOutputHelper output) : base(output) { } - class TestClass + + private class TestClass { public int A; public int B; public int C; } - class TestMeta + + private class TestMeta { [VectorType(2)] public string[] A; @@ -29,12 +31,6 @@ class TestMeta [VectorType(2)] public int[] C; public int D; - [VectorType(2)] - public float[] E; - public float F; - [VectorType(2)] - public string[] G; - public string H; } [Fact] @@ -52,15 +48,16 @@ public void KeyToVectorWorkout() var pipe = new KeyToBinaryVectorEstimator(Env, new KeyToBinaryVectorTransform.ColumnInfo("TermA", "CatA"), new KeyToBinaryVectorTransform.ColumnInfo("TermC", "CatC")); TestEstimatorCore(pipe, dataView); + Done(); } [Fact] - void TestMetadataPropagation() + public void TestMetadataPropagation() { var data = new[] { - new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E= new float[2] { 1.0f,2.0f}, F = 1.0f , G= new string[2]{ "A","D"}, H="D"}, - new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 5,3}, D= 1, E=new float[2] { 3.0f,4.0f}, F = -1.0f ,G= new string[2]{"E", "A"}, H="E"}, - new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E=new float[2] { 5.0f,6.0f}, F = 1.0f ,G= new string[2]{ "D", "E"}, H="D"} }; + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6}, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 5,3}, D= 1}, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6} }; var dataView = ComponentCreation.CreateDataView(Env, data); @@ -82,7 +79,7 @@ void TestMetadataPropagation() ValidateMetadata(result); } - void ValidateMetadata(IDataView result) + private void ValidateMetadata(IDataView result) { Assert.True(result.Schema.TryGetColumnIndex("CatA", out int colA)); Assert.True(result.Schema.TryGetColumnIndex("CatB", out int colB)); @@ -114,7 +111,7 @@ void ValidateMetadata(IDataView result) } [Fact] - void TestCommandLine() + public void TestCommandLine() { using (var env = new TlcEnvironment()) { @@ -123,7 +120,7 @@ void TestCommandLine() } [Fact] - void TestOldSavingAndLoading() + public void TestOldSavingAndLoading() { var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ComponentCreation.CreateDataView(Env, data); diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index b458ee645b..5f8f77ae73 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -20,14 +20,14 @@ public KeyToVectorEstimatorTest(ITestOutputHelper output) : base(output) { } - class TestClass + private class TestClass { public int A; public int B; public int C; } - class TestMeta + private class TestMeta { [VectorType(2)] public string[] A; @@ -60,10 +60,11 @@ public void KeyToVectorWorkout() new KeyToVectorTransform.ColumnInfo("TermC", "CatC", true), new KeyToVectorTransform.ColumnInfo("TermC", "CatCNonBag", false)); TestEstimatorCore(pipe, dataView); + Done(); } [Fact] - void TestMetadataPropagation() + public void TestMetadataPropagation() { var data = new[] { new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E= new float[2] { 1.0f,2.0f}, F = 1.0f , G= new string[2]{ "A","D"}, H="D"}, @@ -99,7 +100,7 @@ void TestMetadataPropagation() ValidateMetadata(result); } - void ValidateMetadata(IDataView result) + private void ValidateMetadata(IDataView result) { Assert.True(result.Schema.TryGetColumnIndex("CatA", out int colA)); Assert.True(result.Schema.TryGetColumnIndex("CatB", out int colB)); @@ -177,7 +178,7 @@ void ValidateMetadata(IDataView result) } [Fact] - void TestCommandLine() + public void TestCommandLine() { using (var env = new TlcEnvironment()) { @@ -186,7 +187,7 @@ void TestCommandLine() } [Fact] - void TestOldSavingAndLoading() + public void TestOldSavingAndLoading() { var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ComponentCreation.CreateDataView(Env, data); From 847bdc4e03d38188e8dad0532bd1bf84d7f08d05 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 10 Sep 2018 13:53:50 -0700 Subject: [PATCH 08/17] more cleanup --- .../Transforms/KeyToVectorTransform.cs | 25 ++++++++++--------- .../Transforms/TermTransform.cs | 23 ++++++----------- .../KeyToBinaryVectorTransform.cs | 2 +- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 2f23e525e0..8425453a9d 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -135,18 +135,18 @@ private string TestIsKey(ColumnType type) return "key type of known cardinality"; } - private ColInfo[] CreateInfos(ISchema schema) + private ColInfo[] CreateInfos(ISchema inputSchema) { - Host.AssertValue(schema); + Host.AssertValue(inputSchema); var infos = new ColInfo[ColumnPairs.Length]; for (int i = 0; i < ColumnPairs.Length; i++) { - if (!schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input); - var type = schema.GetColumnType(colSrc); + if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input); + var type = inputSchema.GetColumnType(colSrc); string reason = TestIsKey(type); if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(ColumnPairs), "input", ColumnPairs[i].input, reason, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString()); infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); } return infos; @@ -170,7 +170,8 @@ private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "KEY2VECT", - verWrittenCur: 0x00010001, // Initial + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Get rid of writing float size in model context verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature); @@ -181,13 +182,10 @@ public override void Save(ModelSaveContext ctx) Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - // *** Binary format *** - // int: sizeof(Float) // // for each added column // byte: bag as 0/1 - ctx.Writer.Write(sizeof(float)); SaveColumns(ctx); Host.Assert(_bags.Length == ColumnPairs.Length); @@ -209,8 +207,11 @@ public static KeyToVectorTransform Create(IHostEnvironment env, ModelLoadContext private static ModelLoadContext ReadFloatFromCtx(IHostEnvironment env, ModelLoadContext ctx) { - int cbFloat = ctx.Reader.ReadInt32(); - env.CheckDecode(cbFloat == sizeof(float)); + if (ctx.Header.ModelVerWritten == 0x00010001) + { + int cbFloat = ctx.Reader.ReadInt32(); + env.CheckDecode(cbFloat == sizeof(float)); + } return ctx; } diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index b800bb3317..942f34df01 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -241,25 +241,25 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - internal static string TestIsKnownDataKind(ColumnType type) + internal string TestIsKnownDataKind(ColumnType type) { if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive)) return null; return "standard type or a vector of standard type"; } - private ColInfo[] CreateInfos(ISchema schema) + private ColInfo[] CreateInfos(ISchema inputSchema) { - Host.AssertValue(schema); + Host.AssertValue(inputSchema); var infos = new ColInfo[ColumnPairs.Length]; for (int i = 0; i < ColumnPairs.Length; i++) { - if (!schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input); - var type = schema.GetColumnType(colSrc); + if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input); + var type = inputSchema.GetColumnType(colSrc); string reason = TestIsKnownDataKind(type); if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(ColumnPairs), "input", ColumnPairs[i].input, reason, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString()); infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); } return infos; @@ -407,9 +407,6 @@ public static IDataView Create(IHostEnvironment env, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) => new TermTransform(env, input, new[] { new ColumnInfo(source ?? name, name, maxNumTerms, sort) }).MakeDataTransform(input); - //REVIEW: This and static method below need to go to base class as it get created. - private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; - public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input) { return Create(env, new Arguments() @@ -714,12 +711,6 @@ public override void Save(ModelSaveContext ctx) protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) - { - if ((inputSchema.GetColumnType(srcCol).ItemType.RawKind == default)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema.GetColumnType(srcCol).ToString()); - } - private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa { private readonly ColumnType[] _types; diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index 21d63e76d3..395151faeb 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -96,7 +96,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema) for (int i = 0; i < ColumnPairs.Length; i++) { if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input); var type = inputSchema.GetColumnType(colSrc); string reason = TestIsKey(type); if (reason != null) From 91e60ed1adec923327816acc17e1306d99158dcc Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 10 Sep 2018 14:38:33 -0700 Subject: [PATCH 09/17] move to trivial estimator --- .../Transforms/KeyToVectorTransform.cs | 140 +++++++++--------- .../KeyToBinaryVectorTransform.cs | 108 ++++++++------ 2 files changed, 129 insertions(+), 119 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 8425453a9d..704b7dafe6 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -99,28 +99,10 @@ public ColumnInfo(string input, string output, bool bag = KeyToVectorEstimator.D } } - internal sealed class ColInfo - { - public readonly string Name; - public readonly string Source; - public readonly ColumnType TypeSrc; - - public ColInfo(string name, string source, ColumnType type) - { - Name = name; - Source = source; - TypeSrc = type; - } - } - private const string RegistrationName = "KeyToVector"; - /// - /// _bags indicates whether vector inputs should have their output indicator vectors added - /// (instead of concatenated). This is faithful to what the user specified in the Arguments - /// and is persisted. - /// - private readonly bool[] _bags; + public IReadOnlyCollection Columns => _columns.AsReadOnly(); + private readonly ColumnInfo[] _columns; private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { @@ -135,31 +117,18 @@ private string TestIsKey(ColumnType type) return "key type of known cardinality"; } - private ColInfo[] CreateInfos(ISchema inputSchema) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Host.AssertValue(inputSchema); - var infos = new ColInfo[ColumnPairs.Length]; - for (int i = 0; i < ColumnPairs.Length; i++) - { - if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input); - var type = inputSchema.GetColumnType(colSrc); - string reason = TestIsKey(type); - if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString()); - infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); - } - return infos; + var type = inputSchema.GetColumnType(srcCol); + string reason = TestIsKey(type); + if (reason != null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, reason, type.ToString()); } - public KeyToVectorTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : + public KeyToVectorTransform(IHostEnvironment env, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { - // Validate input schema - CreateInfos(input.Schema); - _bags = new bool[ColumnPairs.Length]; - for (int i = 0; i < ColumnPairs.Length; i++) - _bags[i] = columns[i].Bag; + _columns = columns.ToArray(); } public const string LoaderSignature = "KeyToVectorTransform"; @@ -188,9 +157,9 @@ public override void Save(ModelSaveContext ctx) // byte: bag as 0/1 SaveColumns(ctx); - Host.Assert(_bags.Length == ColumnPairs.Length); - for (int i = 0; i < _bags.Length; i++) - ctx.Writer.WriteBoolByte(_bags[i]); + Host.Assert(_columns.Length == ColumnPairs.Length); + for (int i = 0; i < _columns.Length; i++) + ctx.Writer.WriteBoolByte(_columns[i].Bag); } // Factory method for SignatureLoadModel. @@ -223,12 +192,16 @@ private KeyToVectorTransform(IHost host, ModelLoadContext ctx) // // for each added column // byte: bag as 0/1 - _bags = new bool[columnsLength]; - _bags = ctx.Reader.ReadBoolArray(columnsLength); + var bags = new bool[columnsLength]; + bags = ctx.Reader.ReadBoolArray(columnsLength); + + _columns = new ColumnInfo[columnsLength]; + for (int i = 0; i < columnsLength; i++) + _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, bags[i]); } public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => - new KeyToVectorTransform(env, input, columns).MakeDataTransform(input); + new KeyToVectorTransform(env, columns).MakeDataTransform(input); // Factory method for SignatureDataTransform. public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) @@ -250,7 +223,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV item.Bag ?? args.Bag); }; } - return new KeyToVectorTransform(env, input, cols).MakeDataTransform(input); + return new KeyToVectorTransform(env, cols).MakeDataTransform(input); } // Factory method for SignatureLoadDataTransform. @@ -265,6 +238,20 @@ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISch private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa { + private sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } + private readonly KeyToVectorTransform _parent; private readonly ColInfo[] _infos; private readonly VectorType[] _types; @@ -273,7 +260,7 @@ public Mapper(KeyToVectorTransform parent, ISchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; - _infos = _parent.CreateInfos(inputSchema); + _infos = CreateInfos(inputSchema); _types = new VectorType[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { @@ -284,6 +271,20 @@ public Mapper(KeyToVectorTransform parent, ISchema inputSchema) } } + private ColInfo[] CreateInfos(ISchema inputSchema) + { + Host.AssertValue(inputSchema); + var infos = new ColInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + var type = inputSchema.GetColumnType(colSrc); + infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + } + return infos; + } + public override RowMapperColumnInfo[] GetOutputColumns() { var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; @@ -308,7 +309,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) { typeNames = null; } - if (_parent._bags[i] || _infos[i].TypeSrc.ValueCount == 1) + if (_parent._columns[i].Bag || _infos[i].TypeSrc.ValueCount == 1) { if (typeNames != null) { @@ -333,7 +334,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) } } - if (!_parent._bags[i] && srcType.ValueCount > 0) + if (!_parent._columns[i].Bag && srcType.ValueCount > 0) { MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => { @@ -343,7 +344,7 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo) colMetaInfo.Add(MetadataUtils.Kinds.CategoricalSlotRanges, info); } - if (!_parent._bags[i] || srcType.ValueCount == 1) + if (!_parent._columns[i].Bag || srcType.ValueCount == 1) { MetadataUtils.MetadataGetter getter = (int col, ref DvBool dst) => { @@ -449,7 +450,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, out Action dispose var info = _infos[iinfo]; if (!info.TypeSrc.IsVector) return MakeGetterOne(input, iinfo); - if (_parent._bags[iinfo]) + if (_parent._columns[iinfo].Bag) return MakeGetterBag(input, iinfo); return MakeGetterInd(input, iinfo); } @@ -501,7 +502,7 @@ private ValueGetter> MakeGetterBag(IRow input, int iinfo) Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsVector); Host.Assert(_infos[iinfo].TypeSrc.ItemType.IsKey); - Host.Assert(_parent._bags[iinfo]); + Host.Assert(_parent._columns[iinfo].Bag); Host.Assert(_infos[iinfo].TypeSrc.ItemType.KeyCount == _types[iinfo].VectorSize); var info = _infos[iinfo]; @@ -545,7 +546,7 @@ private ValueGetter> MakeGetterInd(IRow input, int iinfo) Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsVector); Host.Assert(_infos[iinfo].TypeSrc.ItemType.IsKey); - Host.Assert(!_parent._bags[iinfo]); + Host.Assert(!_parent._columns[iinfo].Bag); var info = _infos[iinfo]; int size = info.TypeSrc.ItemType.KeyCount; @@ -677,7 +678,7 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke return PfaUtils.Call("cast.fanoutDouble", srcToken, 0, keyCount, false); JToken arrType = PfaUtils.Type.Array(PfaUtils.Type.Double); - if (_parent._bags[iinfo] || info.TypeSrc.ValueCount == 1) + if (_parent._columns[iinfo].Bag || info.TypeSrc.ValueCount == 1) { // The concatenation case. We can still use fanout, but we just append them all together. return PfaUtils.Call("a.flatMap", srcToken, @@ -715,37 +716,38 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src } } - public sealed class KeyToVectorEstimator : IEstimator + public sealed class KeyToVectorEstimator : TrivialEstimator { - private readonly IHost _host; - private readonly KeyToVectorTransform.ColumnInfo[] _columns; public static class Defaults { public const bool Bag = false; } public KeyToVectorEstimator(IHostEnvironment env, params KeyToVectorTransform.ColumnInfo[] columns) + : this(env, new KeyToVectorTransform(env, columns)) + { + } + + public KeyToVectorEstimator(IHostEnvironment env, string name, string source = null, bool bag = Defaults.Bag) + : this(env, new KeyToVectorTransform(env, new KeyToVectorTransform.ColumnInfo(source ?? name, name, bag))) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(KeyToVectorEstimator)); - _columns = columns; } - public KeyToVectorEstimator(IHostEnvironment env, string name, string source = null, bool bag = Defaults.Bag) : - this(env, new KeyToVectorTransform.ColumnInfo(source ?? name, name, bag)) + public KeyToVectorEstimator(IHostEnvironment env, KeyToVectorTransform transformer) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToVectorEstimator)), transformer) { } - public SchemaShape GetOutputSchema(SchemaShape inputSchema) + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { - _host.CheckValue(inputSchema, nameof(inputSchema)); + Host.CheckValue(inputSchema, nameof(inputSchema)); var result = inputSchema.Columns.ToDictionary(x => x.Name); - foreach (var colInfo in _columns) + foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) @@ -761,7 +763,5 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } - - public KeyToVectorTransform Fit(IDataView input) => new KeyToVectorTransform(_host, input, _columns); } } diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index 395151faeb..5cbd4a9385 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -47,23 +47,10 @@ public ColumnInfo(string input, string output) } } - internal sealed class ColInfo - { - public readonly string Name; - public readonly string Source; - public readonly ColumnType TypeSrc; - - public ColInfo(string name, string source, ColumnType type) - { - Name = name; - Source = source; - TypeSrc = type; - } - } - internal const string Summary = "Converts a key column to a binary encoded vector."; public const string UserName = "KeyToBinaryVectorTransform"; public const string LoaderSignature = "KeyToBinaryTransform"; + private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -81,6 +68,8 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Input, x.Output)).ToArray(); } + public IReadOnlyCollection Columns => _columns.AsReadOnly(); + private readonly ColumnInfo[] _columns; private string TestIsKey(ColumnType type) { @@ -89,28 +78,19 @@ private string TestIsKey(ColumnType type) return "key type of known cardinality"; } - private ColInfo[] CreateInfos(ISchema inputSchema) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Host.AssertValue(inputSchema); - var infos = new ColInfo[ColumnPairs.Length]; - for (int i = 0; i < ColumnPairs.Length; i++) - { - if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input); - var type = inputSchema.GetColumnType(colSrc); - string reason = TestIsKey(type); - if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString()); - infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); - } - return infos; + var type = inputSchema.GetColumnType(srcCol); + string reason = TestIsKey(type); + if (reason != null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, reason, type.ToString()); } - public KeyToBinaryVectorTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) + public KeyToBinaryVectorTransform(IHostEnvironment env, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { - // Validate input schema. - CreateInfos(input.Schema); + _columns = columns.ToArray(); + } public override void Save(ModelSaveContext ctx) @@ -140,10 +120,13 @@ public static KeyToBinaryVectorTransform Create(IHostEnvironment env, ModelLoadC private KeyToBinaryVectorTransform(IHost host, ModelLoadContext ctx) : base(host, ctx) { + _columns = new ColumnInfo[ColumnPairs.Length]; + for (int i = 0; i < ColumnPairs.Length; i++) + _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output); } public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => - new KeyToBinaryVectorTransform(env, input, columns).MakeDataTransform(input); + new KeyToBinaryVectorTransform(env, columns).MakeDataTransform(input); // Factory method for SignatureDataTransform. public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) @@ -162,7 +145,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV cols[i] = new ColumnInfo(item.Source, item.Name); }; } - return new KeyToBinaryVectorTransform(env, input, cols).MakeDataTransform(input); + return new KeyToBinaryVectorTransform(env, cols).MakeDataTransform(input); } // Factory method for SignatureLoadDataTransform. @@ -177,6 +160,20 @@ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISch private sealed class Mapper : MapperBase { + private sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } + private readonly KeyToBinaryVectorTransform _parent; private readonly ColInfo[] _infos; private readonly VectorType[] _types; @@ -186,7 +183,7 @@ public Mapper(KeyToBinaryVectorTransform parent, ISchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; - _infos = _parent.CreateInfos(inputSchema); + _infos = CreateInfos(inputSchema); _types = new VectorType[_parent.ColumnPairs.Length]; _bitsPerKey = new int[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) @@ -202,6 +199,20 @@ public Mapper(KeyToBinaryVectorTransform parent, ISchema inputSchema) _types[i] = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _bitsPerKey[i]); } } + private ColInfo[] CreateInfos(ISchema inputSchema) + { + Host.AssertValue(inputSchema); + var infos = new ColInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + var type = inputSchema.GetColumnType(colSrc); + + infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + } + return infos; + } public override RowMapperColumnInfo[] GetOutputColumns() { @@ -421,33 +432,34 @@ private void EncodeValueToBinary(BufferBuilder bldr, uint value, int bits } } - public sealed class KeyToBinaryVectorEstimator : IEstimator + public sealed class KeyToBinaryVectorEstimator : TrivialEstimator { - private readonly IHost _host; - private readonly KeyToBinaryVectorTransform.ColumnInfo[] _columns; public KeyToBinaryVectorEstimator(IHostEnvironment env, params KeyToBinaryVectorTransform.ColumnInfo[] columns) + : this(env, new KeyToBinaryVectorTransform(env, columns)) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(KeyToBinaryVectorEstimator)); - _columns = columns; } - public KeyToBinaryVectorEstimator(IHostEnvironment env, string name, string source = null) : - this(env, new KeyToBinaryVectorTransform.ColumnInfo(source ?? name, name)) + public KeyToBinaryVectorEstimator(IHostEnvironment env, string name, string source = null) + : this(env, new KeyToBinaryVectorTransform(env, new KeyToBinaryVectorTransform.ColumnInfo(source ?? name, name))) { } - public SchemaShape GetOutputSchema(SchemaShape inputSchema) + public KeyToBinaryVectorEstimator(IHostEnvironment env, KeyToBinaryVectorTransform transformer) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToBinaryVectorEstimator)), transformer) { - _host.CheckValue(inputSchema, nameof(inputSchema)); + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); var result = inputSchema.Columns.ToDictionary(x => x.Name); - foreach (var colInfo in _columns) + foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) @@ -460,8 +472,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } - - public KeyToBinaryVectorTransform Fit(IDataView input) => new KeyToBinaryVectorTransform(_host, input, _columns); } } From f71b1591747e4115dec8769d99a8771d9fa63564 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 10 Sep 2018 15:57:39 -0700 Subject: [PATCH 10/17] I ain't afraid of no pigs! --- .../Transforms/KeyToVectorTransform.cs | 98 +++++++++++++++++++ .../KeyToBinaryVectorTransform.cs | 68 +++++++++++++ .../KeyToBinaryVectorEstimatorTest.cs | 40 +++++++- .../Transformers/KeyToVectorEstimatorTests.cs | 32 ++++++ 4 files changed, 236 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 704b7dafe6..6f753311a2 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Text; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -764,4 +765,101 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } } + + /// + /// Extension methods for the static-pipeline over objects. + /// + public static class KeyToVectorExtensions + { + private const bool DefaultBag = KeyToVectorEstimator.Defaults.Bag; + private struct Config + { + public readonly bool Bag; + public Config(bool bag) + { + Bag = bag; + } + } + + private interface IColInput + { + PipelineColumn Input { get; } + Config Config { get; } + } + + private sealed class OutKeyColumn : Key, IColInput + { + public PipelineColumn Input { get; } + public Config Config { get; } + + public OutKeyColumn(PipelineColumn input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } + + } + + private sealed class OutVectorColumn : Vector, IColInput + { + public PipelineColumn Input { get; } + public Config Config { get; } + + public OutVectorColumn(Vector> input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } + + public OutVectorColumn(Key input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } + } + + private sealed class Reconciler : EstimatorReconciler + { + public static Reconciler Inst = new Reconciler(); + + private Reconciler() { } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var infos = new KeyToVectorTransform.ColumnInfo[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var col = (IColInput)toOutput[i]; + infos[i] = new KeyToVectorTransform.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]], col.Config.Bag); + } + return new KeyToVectorEstimator(env, infos); + } + } + + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// + public static Vector ToVector(this Key input, bool bag = DefaultBag) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(bag)); + } + + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// + public static Vector ToVector(this Vector> input, bool bag = DefaultBag) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(bag)); + } + + } } diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index 5cbd4a9385..bc6750da3b 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Text; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -473,5 +474,72 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } } + /// + /// Extension methods for the static-pipeline over objects. + /// + public static class KeyToVectorExtensions + { + private interface IColInput + { + PipelineColumn Input { get; } + } + + private sealed class OutVectorColumn : Vector, IColInput + { + public PipelineColumn Input { get; } + + public OutVectorColumn(Vector> input) + : base(Reconciler.Inst, input) + { + Input = input; + } + + public OutVectorColumn(Key input) + : base(Reconciler.Inst, input) + { + Input = input; + } + } + + private sealed class Reconciler : EstimatorReconciler + { + public static Reconciler Inst = new Reconciler(); + + private Reconciler() { } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var infos = new KeyToBinaryVectorTransform.ColumnInfo[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var col = (IColInput)toOutput[i]; + infos[i] = new KeyToBinaryVectorTransform.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]]); + } + return new KeyToBinaryVectorEstimator(env, infos); + } + } + /// + /// Takes a column of key type of known cardinality and produces a binary encoded indicator vector of zeros and ones. + /// + public static Vector ToBinaryVector(this Key input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } + + /// + /// Takes a column of key type of known cardinality and produces a binary encoded indicator vector of zeros and ones. + /// + public static Vector ToBinaryVector(this Vector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } + + } } diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs index 35338cf263..98aa0b0841 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -1,4 +1,9 @@ -using Microsoft.ML.Runtime.Api; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data.StaticPipe; +using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.RunTests; @@ -34,7 +39,7 @@ private class TestMeta } [Fact] - public void KeyToVectorWorkout() + public void KeyToBinaryVectorWorkout() { var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; @@ -51,6 +56,37 @@ public void KeyToVectorWorkout() Done(); } + [Fact] + public void KeyToBinaryVectorPigsty() + { + string dataPath = GetDataPath("breast-cancer.txt"); + var reader = TextLoader.CreateReader(Env, ctx => ( + ScalarString: ctx.LoadText(1), + VectorString: ctx.LoadText(1, 4) + )); + + var data = reader.Read(new MultiFileSource(dataPath)); + + // Non-pigsty Term. + var dynamicData = new TermEstimator(Env, + new TermTransform.ColumnInfo("ScalarString", "A"), + new TermTransform.ColumnInfo("VectorString", "B")) + .Fit(data.AsDynamic).Transform(data.AsDynamic); + + var data2 = dynamicData.AssertStatic(Env, ctx => ( + A: ctx.KeyU4.TextValues.Scalar, + B: ctx.KeyU4.TextValues.Vector)); + + var est = data2.MakeNewEstimator() + .Append(row => ( + ScalarString: row.A.ToBinaryVector(), + VectorString: row.B.ToBinaryVector())); + + TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic); + + Done(); + } + [Fact] public void TestMetadataPropagation() { diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index 5f8f77ae73..e8a23c1adf 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Data.StaticPipe; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Model; @@ -63,6 +64,37 @@ public void KeyToVectorWorkout() Done(); } + [Fact] + public void KeyToVectorPigsty() + { + string dataPath = GetDataPath("breast-cancer.txt"); + var reader = TextLoader.CreateReader(Env, ctx => ( + ScalarString: ctx.LoadText(1), + VectorString: ctx.LoadText(1, 4) + )); + + var data = reader.Read(new MultiFileSource(dataPath)); + + // Non-pigsty Term. + var dynamicData = new TermEstimator(Env, + new TermTransform.ColumnInfo("ScalarString", "A"), + new TermTransform.ColumnInfo("VectorString", "B")) + .Fit(data.AsDynamic).Transform(data.AsDynamic); + + var data2 = dynamicData.AssertStatic(Env, ctx => ( + A: ctx.KeyU4.TextValues.Scalar, + B: ctx.KeyU4.TextValues.Vector)); + + var est = data2.MakeNewEstimator() + .Append(row => ( + ScalarString: row.A.ToVector(bag: true), + VectorString: row.B.ToVector())); + + TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic); + + Done(); + } + [Fact] public void TestMetadataPropagation() { From bbbc7145f0ce1b33e8837adb957eba43f86b6ff6 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 10 Sep 2018 16:12:59 -0700 Subject: [PATCH 11/17] small fixes --- src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs | 3 +-- src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs | 2 +- .../Transformers/KeyToVectorEstimatorTests.cs | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 6f753311a2..c20501e3c5 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -31,7 +31,6 @@ namespace Microsoft.ML.Runtime.Data { - public sealed class KeyToVectorTransform : OneToOneTransformerBase { public abstract class ColumnBase : OneToOneColumn @@ -219,7 +218,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV { var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source, + cols[i] = new ColumnInfo(item.Source??item.Name, item.Name, item.Bag ?? args.Bag); }; diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index bc6750da3b..4209a704e5 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -143,7 +143,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source, item.Name); + cols[i] = new ColumnInfo(item.Source ?? item.Name, item.Name); }; } return new KeyToBinaryVectorTransform(env, cols).MakeDataTransform(input); diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index e8a23c1adf..62242d01f8 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -88,6 +88,7 @@ public void KeyToVectorPigsty() var est = data2.MakeNewEstimator() .Append(row => ( ScalarString: row.A.ToVector(bag: true), + ScalarStringWithBag: row.A.ToVector(bag: true), VectorString: row.B.ToVector())); TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic); From a21832ea9387241c105d1b8ce69d6a0133e412b0 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 11 Sep 2018 10:43:46 -0700 Subject: [PATCH 12/17] address comments --- .../Transforms/KeyToVectorTransform.cs | 22 ++++--------------- .../KeyToBinaryVectorTransform.cs | 12 +++++----- .../KeyToBinaryVectorEstimatorTest.cs | 3 ++- .../Transformers/KeyToVectorEstimatorTests.cs | 3 ++- 4 files changed, 14 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index c20501e3c5..a799fd8864 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -218,7 +218,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV { var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source??item.Name, + cols[i] = new ColumnInfo(item.Source ?? item.Name, item.Name, item.Bag ?? args.Bag); }; @@ -786,21 +786,7 @@ private interface IColInput Config Config { get; } } - private sealed class OutKeyColumn : Key, IColInput - { - public PipelineColumn Input { get; } - public Config Config { get; } - - public OutKeyColumn(PipelineColumn input, Config config) - : base(Reconciler.Inst, input) - { - Input = input; - Config = config; - } - - } - - private sealed class OutVectorColumn : Vector, IColInput + private sealed class OutVectorColumn : Vector, IColInput { public PipelineColumn Input { get; } public Config Config { get; } @@ -845,7 +831,7 @@ public override IEstimator Reconcile(IHostEnvironment env, /// /// Takes a column of key type of known cardinality and produces an indicator vector of floats. /// - public static Vector ToVector(this Key input, bool bag = DefaultBag) + public static Vector ToVector(this Key input, bool bag = DefaultBag) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(bag)); @@ -854,7 +840,7 @@ public static Vector ToVector(this Key input /// /// Takes a column of key type of known cardinality and produces an indicator vector of floats. /// - public static Vector ToVector(this Vector> input, bool bag = DefaultBag) + public static Vector ToVector(this Vector> input, bool bag = DefaultBag) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(bag)); diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index 4209a704e5..28652316ef 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -477,14 +477,14 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) /// /// Extension methods for the static-pipeline over objects. /// - public static class KeyToVectorExtensions + public static class KeyToBinaryVectorExtensions { private interface IColInput { PipelineColumn Input { get; } } - private sealed class OutVectorColumn : Vector, IColInput + private sealed class OutVectorColumn : Vector, IColInput { public PipelineColumn Input { get; } @@ -524,18 +524,18 @@ public override IEstimator Reconcile(IHostEnvironment env, } /// - /// Takes a column of key type of known cardinality and produces a binary encoded indicator vector of zeros and ones. + /// Takes a column of key type of known cardinality and produces a vector of bits representing the key in binary form. /// - public static Vector ToBinaryVector(this Key input) + public static Vector ToBinaryVector(this Key input) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input); } /// - /// Takes a column of key type of known cardinality and produces a binary encoded indicator vector of zeros and ones. + /// Takes a column of key type of known cardinality and produces a vector of bits representing the key in binary form. /// - public static Vector ToBinaryVector(this Vector> input) + public static Vector ToBinaryVector(this Vector> input) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input); diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs index 98aa0b0841..f3d7d149bf 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -57,7 +57,7 @@ public void KeyToBinaryVectorWorkout() } [Fact] - public void KeyToBinaryVectorPigsty() + public void KeyToBinaryVectorStatic() { string dataPath = GetDataPath("breast-cancer.txt"); var reader = TextLoader.CreateReader(Env, ctx => ( @@ -113,6 +113,7 @@ public void TestMetadataPropagation() var result = pipe.Fit(dataView).Transform(dataView); ValidateMetadata(result); + Done(); } private void ValidateMetadata(IDataView result) diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index 62242d01f8..4872ad1bdc 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -65,7 +65,7 @@ public void KeyToVectorWorkout() } [Fact] - public void KeyToVectorPigsty() + public void KeyToVectorStatic() { string dataPath = GetDataPath("breast-cancer.txt"); var reader = TextLoader.CreateReader(Env, ctx => ( @@ -131,6 +131,7 @@ public void TestMetadataPropagation() var result = pipe.Fit(dataView).Transform(dataView); ValidateMetadata(result); + Done(); } private void ValidateMetadata(IDataView result) From ac641d61199fb54e1a3783c89f3c4b0124a09fc1 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 11 Sep 2018 13:51:56 -0700 Subject: [PATCH 13/17] address some Tom comments --- .../Transforms/KeyToVectorTransform.cs | 59 +++++++++++++++++-- .../KeyToBinaryVectorTransform.cs | 23 ++++++++ .../StaticPipeFakes.cs | 9 --- .../BaseTestBaseline.cs | 6 +- .../Transformers/KeyToVectorEstimatorTests.cs | 7 ++- 5 files changed, 84 insertions(+), 20 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index a799fd8864..22e5805cbc 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -264,7 +264,7 @@ public Mapper(KeyToVectorTransform parent, ISchema inputSchema) _types = new VectorType[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (_infos[i].TypeSrc.ValueCount == 1) + if (_parent._columns[i].Bag|| _infos[i].TypeSrc.ValueCount == 1) _types[i] = new VectorType(NumberType.Float, _infos[i].TypeSrc.ItemType.KeyCount); else _types[i] = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); @@ -804,6 +804,13 @@ public OutVectorColumn(Key input, Config config) Input = input; Config = config; } + + public OutVectorColumn(VarVector> input, Config config) + : base(Reconciler.Inst, input) + { + Input = input; + Config = config; + } } private sealed class Reconciler : EstimatorReconciler @@ -830,20 +837,62 @@ public override IEstimator Reconcile(IHostEnvironment env, /// /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// + public static Vector ToVector(this Key input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(false)); + } + + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. /// - public static Vector ToVector(this Key input, bool bag = DefaultBag) + public static Vector ToVector(this Vector> input) { Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, new Config(bag)); + return new OutVectorColumn (input, new Config(false)); } /// /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// In this case then the indicator vectors for all values in the column will be simply added together, + /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, + /// the output column will be a vector type of length equal to that cardinality. /// - public static Vector ToVector(this Vector> input, bool bag = DefaultBag) + public static Vector ToVector(this VarVector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn (input, new Config(false)); + } + + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// In this case then the indicator vectors for all values in the column will be simply added together, + /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, + /// the output column will be a vector type of length equal to that cardinality. + /// + public static Vector ToBaggedVector(this Vector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, new Config(true)); + } + + public static Vector ToBaggedVector(this VarVector> input) { Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, new Config(bag)); + return new OutVectorColumn(input, new Config(true)); } } diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index 28652316ef..17b30eb9cb 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -499,6 +499,12 @@ public OutVectorColumn(Key input) { Input = input; } + + public OutVectorColumn(VarVector> input) + : base(Reconciler.Inst, input) + { + Input = input; + } } private sealed class Reconciler : EstimatorReconciler @@ -525,6 +531,9 @@ public override IEstimator Reconcile(IHostEnvironment env, /// /// Takes a column of key type of known cardinality and produces a vector of bits representing the key in binary form. + /// The first value is encoded as all zeros and missing values are encoded as all ones. + /// In the case where a vector has multiple keys, the encoded values are concatenated. + /// Number of bits per key is determined as the number of bits needed to represent the cardinality of the keys plus one. /// public static Vector ToBinaryVector(this Key input) { @@ -534,6 +543,9 @@ public static Vector ToBinaryVector(this Key /// /// Takes a column of key type of known cardinality and produces a vector of bits representing the key in binary form. + /// The first value is encoded as all zeros and missing values are encoded as all ones. + /// In the case where a vector has multiple keys, the encoded values are concatenated. + /// Number of bits per key is determined as the number of bits needed to represent the cardinality of the keys plus one. /// public static Vector ToBinaryVector(this Vector> input) { @@ -541,5 +553,16 @@ public static Vector ToBinaryVector(this Vector(input); } + /// + /// Takes a column of key type of known cardinality and produces a vector of bits representing the key in binary form. + /// The first value is encoded as all zeros and missing values are encoded as all ones. + /// In the case where a vector has multiple keys, the encoded values are concatenated. + /// Number of bits per key is determined as the number of bits needed to represent the cardinality of the keys plus one. + /// + public static Vector ToBinaryVector(this VarVector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } } } diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs index 8313d94e7f..403cde05ea 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs @@ -160,15 +160,6 @@ public static VarVector> Dictionarize(this VarVector me) => _rec.VarVector>(me); } - public static class KeyToVectorTransformExtensions - { - private static FakeTransformReconciler _rec = new FakeTransformReconciler("KeyToVector"); - - public static Vector BagVectorize(this VarVector> me) - => _rec.Vector(me); - public static Vector BagVectorize(this VarVector> me) - => _rec.Vector(me); - } public static class TextTransformExtensions { diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index f57b50bdab..0861104814 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -26,11 +26,11 @@ public abstract partial class BaseTestBaseline : BaseTestClass, IDisposable { private readonly ITestOutputHelper _output; - protected BaseTestBaseline(ITestOutputHelper helper): base(helper) + protected BaseTestBaseline(ITestOutputHelper helper) : base(helper) { _output = helper; - ITest test = (ITest)helper.GetType().GetField("test", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(helper); - TestName = test.TestCase.TestMethod.Method.Name; + ITest test = (ITest)helper.GetType().GetField("test", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(helper); + TestName = test.TestCase.TestMethod.TestClass.Class.Name + "." + test.TestCase.TestMethod.Method.Name; Init(); } diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index 4872ad1bdc..9520966be3 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -87,9 +87,10 @@ public void KeyToVectorStatic() var est = data2.MakeNewEstimator() .Append(row => ( - ScalarString: row.A.ToVector(bag: true), - ScalarStringWithBag: row.A.ToVector(bag: true), - VectorString: row.B.ToVector())); + ScalarString: row.A.ToVector(), + VectorString: row.B.ToVector(), + VectorBaggedString: row.B.ToBaggedVector() + )); TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic); From 27c20e0837546a921e919bf74cdad44408425ccb Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 11 Sep 2018 14:13:50 -0700 Subject: [PATCH 14/17] some more work --- .../Transforms/KeyToVectorTransform.cs | 66 +++++++++++-------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 22e5805cbc..d79ac779fe 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -264,7 +264,7 @@ public Mapper(KeyToVectorTransform parent, ISchema inputSchema) _types = new VectorType[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (_parent._columns[i].Bag|| _infos[i].TypeSrc.ValueCount == 1) + if (_parent._columns[i].Bag || _infos[i].TypeSrc.ValueCount == 1) _types[i] = new VectorType(NumberType.Float, _infos[i].TypeSrc.ItemType.KeyCount); else _types[i] = new VectorType(NumberType.Float, _infos[i].TypeSrc.ValueCount, _infos[i].TypeSrc.ItemType.KeyCount); @@ -770,46 +770,49 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) /// public static class KeyToVectorExtensions { - private const bool DefaultBag = KeyToVectorEstimator.Defaults.Bag; - private struct Config - { - public readonly bool Bag; - public Config(bool bag) - { - Bag = bag; - } - } - private interface IColInput { PipelineColumn Input { get; } - Config Config { get; } + bool Bag { get; } } private sealed class OutVectorColumn : Vector, IColInput { public PipelineColumn Input { get; } - public Config Config { get; } + public bool Bag { get; } - public OutVectorColumn(Vector> input, Config config) - : base(Reconciler.Inst, input) + public OutVectorColumn(Key input) + : base(Reconciler.Inst, input) { Input = input; - Config = config; + Bag = false; } - public OutVectorColumn(Key input, Config config) - : base(Reconciler.Inst, input) + public OutVectorColumn(Vector> input, bool bag) + : base(Reconciler.Inst, input) { Input = input; - Config = config; + Bag = bag; } - public OutVectorColumn(VarVector> input, Config config) + public OutVectorColumn(VarVector> input) : base(Reconciler.Inst, input) { Input = input; - Config = config; + Bag = true; + } + } + + private sealed class OutVarVectorColumn : VarVector, IColInput + { + public PipelineColumn Input { get; } + public bool Bag { get; } + + public OutVarVectorColumn(VarVector> input) + : base(Reconciler.Inst, input) + { + Input = input; + Bag = false; } } @@ -844,7 +847,7 @@ public override IEstimator Reconcile(IHostEnvironment env, public static Vector ToVector(this Key input) { Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, new Config(false)); + return new OutVectorColumn(input); } /// @@ -856,7 +859,7 @@ public static Vector ToVector(this Key input) public static Vector ToVector(this Vector> input) { Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn (input, new Config(false)); + return new OutVectorColumn(input, false); } /// @@ -868,10 +871,10 @@ public static Vector ToVector(this Vector /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, /// the output column will be a vector type of length equal to that cardinality. /// - public static Vector ToVector(this VarVector> input) + public static VarVector ToVector(this VarVector> input) { Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn (input, new Config(false)); + return new OutVarVectorColumn(input); } /// @@ -886,13 +889,22 @@ public static Vector ToVector(this VarVector ToBaggedVector(this Vector> input) { Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, new Config(true)); + return new OutVectorColumn(input, true); } + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// In this case then the indicator vectors for all values in the column will be simply added together, + /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, + /// the output column will be a vector type of length equal to that cardinality. + /// public static Vector ToBaggedVector(this VarVector> input) { Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input, new Config(true)); + return new OutVectorColumn(input); } } From d4abb9b301c1917ad5309375f298ebfc980ae84a Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 11 Sep 2018 15:47:02 -0700 Subject: [PATCH 15/17] Key and Key --- .../Transforms/KeyToVectorTransform.cs | 110 +++++++++++++++++- .../KeyToBinaryVectorTransform.cs | 75 +++++++++++- 2 files changed, 180 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index d79ac779fe..efa519e05e 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -816,6 +816,46 @@ public OutVarVectorColumn(VarVector> input) } } + private sealed class OutVectorColumn : Vector, IColInput + { + public PipelineColumn Input { get; } + public bool Bag { get; } + + public OutVectorColumn(Key input) + : base(Reconciler.Inst, input) + { + Input = input; + Bag = false; + } + + public OutVectorColumn(Vector> input, bool bag) + : base(Reconciler.Inst, input) + { + Input = input; + Bag = bag; + } + + public OutVectorColumn(VarVector> input) + : base(Reconciler.Inst, input) + { + Input = input; + Bag = true; + } + } + + private sealed class OutVarVectorColumn : VarVector, IColInput + { + public PipelineColumn Input { get; } + public bool Bag { get; } + + public OutVarVectorColumn(VarVector> input) + : base(Reconciler.Inst, input) + { + Input = input; + Bag = false; + } + } + private sealed class Reconciler : EstimatorReconciler { public static Reconciler Inst = new Reconciler(); @@ -832,7 +872,7 @@ public override IEstimator Reconcile(IHostEnvironment env, for (int i = 0; i < toOutput.Length; ++i) { var col = (IColInput)toOutput[i]; - infos[i] = new KeyToVectorTransform.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]], col.Config.Bag); + infos[i] = new KeyToVectorTransform.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]], col.Bag); } return new KeyToVectorEstimator(env, infos); } @@ -907,5 +947,73 @@ public static Vector ToBaggedVector(this VarVector(input); } + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// + public static Vector ToVector(this Key input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } + + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// + public static Vector ToVector(this Vector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, false); + } + + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// In this case then the indicator vectors for all values in the column will be simply added together, + /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, + /// the output column will be a vector type of length equal to that cardinality. + /// + public static VarVector ToVector(this VarVector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input); + } + + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// In this case then the indicator vectors for all values in the column will be simply added together, + /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, + /// the output column will be a vector type of length equal to that cardinality. + /// + public static Vector ToBaggedVector(this Vector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input, true); + } + + /// + /// Takes a column of key type of known cardinality and produces an indicator vector of floats. + /// Each key value of the input is used to create an indicator vector: the indicator vector is the length of the key cardinality, + /// where all values are 0, except for the entry corresponding to the value of the key, which is 1. + /// If the key value is missing, then all values are 0. Naturally this tends to generate very sparse vectors. + /// In this case then the indicator vectors for all values in the column will be simply added together, + /// to produce the final vector with type equal to the key cardinality; so, in all cases, whether vector or scalar, + /// the output column will be a vector type of length equal to that cardinality. + /// + public static Vector ToBaggedVector(this VarVector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } } } diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index 17b30eb9cb..52a49b1563 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -499,9 +499,40 @@ public OutVectorColumn(Key input) { Input = input; } + } - public OutVectorColumn(VarVector> input) - : base(Reconciler.Inst, input) + private sealed class OutVarVectorColumn : VarVector, IColInput + { + public PipelineColumn Input { get; } + public OutVarVectorColumn(VarVector> input) + : base(Reconciler.Inst, input) + { + Input = input; + } + } + + private sealed class OutVectorColumn : Vector, IColInput + { + public PipelineColumn Input { get; } + + public OutVectorColumn(Vector> input) + : base(Reconciler.Inst, input) + { + Input = input; + } + + public OutVectorColumn(Key input) + : base(Reconciler.Inst, input) + { + Input = input; + } + } + + private sealed class OutVarVectorColumn : VarVector, IColInput + { + public PipelineColumn Input { get; } + public OutVarVectorColumn(VarVector> input) + : base(Reconciler.Inst, input) { Input = input; } @@ -559,10 +590,46 @@ public static Vector ToBinaryVector(this Vector - public static Vector ToBinaryVector(this VarVector> input) + public static VarVector ToBinaryVector(this VarVector> input) { Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input); + return new OutVarVectorColumn(input); + } + + /// + /// Takes a column of key type of known cardinality and produces a vector of bits representing the key in binary form. + /// The first value is encoded as all zeros and missing values are encoded as all ones. + /// In the case where a vector has multiple keys, the encoded values are concatenated. + /// Number of bits per key is determined as the number of bits needed to represent the cardinality of the keys plus one. + /// + public static Vector ToBinaryVector(this Key input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } + + /// + /// Takes a column of key type of known cardinality and produces a vector of bits representing the key in binary form. + /// The first value is encoded as all zeros and missing values are encoded as all ones. + /// In the case where a vector has multiple keys, the encoded values are concatenated. + /// Number of bits per key is determined as the number of bits needed to represent the cardinality of the keys plus one. + /// + public static Vector ToBinaryVector(this Vector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } + + /// + /// Takes a column of key type of known cardinality and produces a vector of bits representing the key in binary form. + /// The first value is encoded as all zeros and missing values are encoded as all ones. + /// In the case where a vector has multiple keys, the encoded values are concatenated. + /// Number of bits per key is determined as the number of bits needed to represent the cardinality of the keys plus one. + /// + public static VarVector ToBinaryVector(this VarVector> input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input); } } } From bbc2dc264208492f50c7b67d040be678769c1109 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 11 Sep 2018 16:50:32 -0700 Subject: [PATCH 16/17] now proper merge --- src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs index f8a8ba7152..5b88cf960f 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs @@ -93,7 +93,7 @@ private static IDataView ApplyKeyToVec(List ktv // when the user has slightly different key values between the training and testing set. // The solution is to apply KeyToValue, then Term using the terms from the key metadata of the original key column // and finally the KeyToVector transform. - viewTrain = new KeyToValueTransform(host, ktv.Select(x => (x.Source ?? x.Name, x.Name)).ToArray()) + viewTrain = new KeyToValueTransform(host, ktv.Select(x => (x.Input , x.Output)).ToArray()) .Transform(viewTrain); viewTrain = TermTransform.Create(host, From 191e5ce2df54356ae634a8ea82e51a59f6125c3e Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 11 Sep 2018 17:07:49 -0700 Subject: [PATCH 17/17] Fix tests --- test/Microsoft.ML.TestFramework/BaseTestBaseline.cs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index 0861104814..f26c40b7e0 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -30,7 +30,8 @@ protected BaseTestBaseline(ITestOutputHelper helper) : base(helper) { _output = helper; ITest test = (ITest)helper.GetType().GetField("test", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(helper); - TestName = test.TestCase.TestMethod.TestClass.Class.Name + "." + test.TestCase.TestMethod.Method.Name; + FullTestName = test.TestCase.TestMethod.TestClass.Class.Name + "." + test.TestCase.TestMethod.Method.Name; + TestName = test.TestCase.TestMethod.Method.Name; Init(); } @@ -92,6 +93,7 @@ void IDisposable.Dispose() private bool _passed; public string TestName { get; set; } + public string FullTestName { get; set; } public void Init() { @@ -103,7 +105,7 @@ public void Init() // Find the sample data and baselines. _baseDir = Path.Combine(RootDir, _baselineRootRelPath); - string logPath = Path.Combine(logDir, TestName + LogSuffix); + string logPath = Path.Combine(logDir, FullTestName + LogSuffix); LogWriter = OpenWriter(logPath); _passed = true; Env = new TlcEnvironment(42, outWriter: LogWriter, errWriter: LogWriter);