diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 70afe195b0..5352040958 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -2,10 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; using Microsoft.ML.Core.Data; using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; @@ -16,11 +12,15 @@ using Microsoft.ML.Runtime.Model.Onnx; using Microsoft.ML.Runtime.Model.Pfa; using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; [assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToVectorTransform), typeof(KeyToVectorTransform.Arguments), typeof(SignatureDataTransform), "Key To Vector Transform", KeyToVectorTransform.UserName, "KeyToVector", "ToVector", DocName = "transform/KeyToVectorTransform.md")] -[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataView), typeof(KeyToVectorTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToVectorTransform), null, typeof(SignatureLoadDataTransform), "Key To Vector Transform", KeyToVectorTransform.LoaderSignature)] [assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(KeyToVectorTransform), null, typeof(SignatureLoadModel), @@ -733,7 +733,7 @@ public KeyToVectorEstimator(IHostEnvironment env, string name, string source = n { } - public KeyToVectorEstimator(IHostEnvironment env, KeyToVectorTransform transformer) + private KeyToVectorEstimator(IHostEnvironment env, KeyToVectorTransform transformer) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToVectorEstimator)), transformer) { } diff --git a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs index 44e873e9e7..446140573c 100644 --- a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs @@ -6,16 +6,31 @@ using Microsoft.ML.Data.StaticPipe.Runtime; using System; using System.Collections.Generic; -using System.Collections.Immutable; using System.Linq; namespace Microsoft.ML.Runtime.Data { public sealed class TermEstimator : IEstimator { + public static class Defaults + { + public const int MaxNumTerms = 1000000; + public const TermTransform.SortOrder Sort = TermTransform.SortOrder.Occurrence; + } + private readonly IHost _host; private readonly TermTransform.ColumnInfo[] _columns; - public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = TermTransform.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermTransform.Defaults.Sort) : + + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. + /// Maximum number of terms to keep per column when auto-training. + /// How items should be ordered when vectorized. By default, they will be in the order encountered. + /// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). + public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) : this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort)) { } @@ -47,7 +62,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector) { kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, - col.ItemType, col.IsKey); + colInfo.TextKeyValues ? TextType.Instance : col.ItemType, col.IsKey); } Contracts.AssertValue(kv); @@ -90,7 +105,7 @@ public sealed class ToKeyFitResult // At the moment this is empty. Once PR #863 clears, we can change this class to hold the output // key-values metadata. - internal ToKeyFitResult(TermTransform.TermMap map) + public ToKeyFitResult(TermTransform.TermMap map) { } } @@ -101,8 +116,8 @@ public static partial class TermStaticExtensions // Raw generics would allow illegal possible inputs, e.g., Scalar. So, this is a partial // class, and all the public facing extension methods for each possible type are in a T4 generated result. - private const KeyValueOrder DefSort = (KeyValueOrder)TermTransform.Defaults.Sort; - private const int DefMax = TermTransform.Defaults.MaxNumTerms; + private const KeyValueOrder DefSort = (KeyValueOrder)TermEstimator.Defaults.Sort; + private const int DefMax = TermEstimator.Defaults.MaxNumTerms; private struct Config { @@ -176,7 +191,7 @@ public override IEstimator Reconcile(IHostEnvironment env, Pipelin { var infos = new TermTransform.ColumnInfo[toOutput.Length]; Action onFit = null; - for (int i=0; i public static IDataView Create(IHostEnvironment env, IDataView input, string name, string source = null, - int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) => + int maxNumTerms = TermEstimator.Defaults.MaxNumTerms, SortOrder sort = TermEstimator.Defaults.Sort) => new TermTransform(env, input, new[] { new ColumnInfo(source ?? name, name, maxNumTerms, sort) }).MakeDataTransform(input); public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input) @@ -710,7 +704,7 @@ public override void Save(ModelSaveContext ctx) }); } - internal TermMap GetTermMap(int iinfo) + public TermMap GetTermMap(int iinfo) { Contracts.Assert(0 <= iinfo && iinfo < _unboundMaps.Length); return _unboundMaps[iinfo]; diff --git a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs index 8a63045c68..fcc0155617 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs @@ -470,7 +470,7 @@ private static BoundTermMap Bind(IHostEnvironment env, ISchema schema, TermMap u /// These are the immutable and serializable analogs to the used in /// training. /// - internal abstract class TermMap + public abstract class TermMap { /// /// The item type of the input type, that is, either the input type or, @@ -501,9 +501,9 @@ protected TermMap(PrimitiveType type, int count) OutputType = new KeyType(DataKind.U4, 0, Count == 0 ? 1 : Count); } - public abstract void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory); + internal abstract void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory); - public static TermMap Load(ModelLoadContext ctx, IHostEnvironment ectx, CodecFactory codecFactory) + internal static TermMap Load(ModelLoadContext ctx, IHostEnvironment ectx, CodecFactory codecFactory) { // *** Binary format *** // byte: map type code @@ -610,7 +610,7 @@ public static TextImpl Create(ModelLoadContext ctx, IExceptionContext ectx) return new TextImpl(pool); } - public override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory) + internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory) { // *** Binary format *** // byte: map type code, in this case 'Text' (0) @@ -685,7 +685,7 @@ public HashArrayImpl(PrimitiveType itemType, HashArray values) _values = values; } - public override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory) + internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory) { // *** Binary format *** // byte: map type code, in this case 'Codec' @@ -757,7 +757,7 @@ public override void WriteTextTerms(TextWriter writer) } } - internal abstract class TermMap : TermMap + public abstract class TermMap : TermMap { protected TermMap(PrimitiveType type, int count) : base(type, count) diff --git a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs index 42d1f4d310..582586b9a0 100644 --- a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - -using System; +using System.Collections.Generic; using System.Linq; using System.Text; using Microsoft.ML.Runtime; @@ -12,7 +10,6 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Internal.Internallearn; [assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform), CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")] @@ -62,14 +59,11 @@ protected override bool TryParse(string str) // We accept N:B:S where N is the new column name, B is the number of bits, // and S is source column names. - string extra; - if (!base.TryParse(str, out extra)) + if (!TryParse(str, out string extra)) return false; if (extra == null) return true; - - int bits; - if (!int.TryParse(extra, out bits)) + if (!int.TryParse(extra, out int bits)) return false; HashBits = bits; return true; @@ -201,14 +195,81 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV }; } - return CategoricalTransform.CreateTransformCore( + return CreateTransformCore( args.OutputKind, args.Column, args.Column.Select(col => col.OutputKind).ToList(), new HashTransform(h, hashArgs, input), h, - env, args); } } + + private static IDataTransform CreateTransformCore(CategoricalTransform.OutputKind argsOutputKind, OneToOneColumn[] columns, + List columnOutputKinds, IDataTransform input, IHost h, Arguments catHashArgs = null) + { + Contracts.CheckValue(columns, nameof(columns)); + Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds)); + Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns)); + + using (var ch = h.Start("Create Transform Core")) + { + // Create the KeyToVectorTransform, if needed. + var cols = new List(); + bool binaryEncoding = argsOutputKind == CategoricalTransform.OutputKind.Bin; + for (int i = 0; i < columns.Length; i++) + { + var column = columns[i]; + if (!column.TrySanitize()) + throw h.ExceptUserArg(nameof(Column.Name)); + + bool? bag; + CategoricalTransform.OutputKind kind = columnOutputKinds[i] ?? argsOutputKind; + switch (kind) + { + default: + throw ch.ExceptUserArg(nameof(Column.OutputKind)); + case CategoricalTransform.OutputKind.Key: + continue; + case CategoricalTransform.OutputKind.Bin: + binaryEncoding = true; + bag = false; + break; + case CategoricalTransform.OutputKind.Ind: + bag = false; + break; + case CategoricalTransform.OutputKind.Bag: + bag = true; + break; + } + var col = new KeyToVectorTransform.Column(); + col.Name = column.Name; + col.Source = column.Name; + col.Bag = bag; + cols.Add(col); + } + + if (cols.Count == 0) + return input; + + IDataTransform transform; + if (binaryEncoding) + { + if ((catHashArgs?.InvertHash ?? 0) != 0) + ch.Warning("Invert hashing is being used with binary encoding."); + + var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray(); + transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols); + } + else + { + var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == CategoricalTransform.OutputKind.Bag)).ToArray(); + + transform = KeyToVectorTransform.Create(h, input, keyToVecCols); + } + + ch.Done(); + return transform; + } + } } } diff --git a/src/Microsoft.ML.Transforms/CategoricalTransform.cs b/src/Microsoft.ML.Transforms/CategoricalTransform.cs index 1e1ba6d1a4..6ad99c2a32 100644 --- a/src/Microsoft.ML.Transforms/CategoricalTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalTransform.cs @@ -2,18 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(CategoricalTransform.Summary, typeof(IDataTransform), typeof(CategoricalTransform), typeof(CategoricalTransform.Arguments), typeof(SignatureDataTransform), CategoricalTransform.UserName, "CategoricalTransform", "CatTransform", "Categorical", "Cat")] @@ -22,7 +23,7 @@ namespace Microsoft.ML.Runtime.Data { /// - public static class CategoricalTransform + public sealed class CategoricalTransform : ITransformer, ICanSaveModel { public enum OutputKind : byte { @@ -70,14 +71,11 @@ protected override bool TryParse(string str) // We accept N:K:S where N is the new column name, K is the output kind, // and S is source column names. - string extra; - if (!base.TryParse(str, out extra)) + if (!TryParse(str, out string extra)) return false; if (extra == null) return true; - - OutputKind kind; - if (!Enum.TryParse(extra, true, out kind)) + if (!Enum.TryParse(extra, true, out OutputKind kind)) return false; OutputKind = kind; return true; @@ -96,11 +94,6 @@ public bool TryUnparse(StringBuilder sb) } } - private static class Defaults - { - public const OutputKind OutKind = OutputKind.Ind; - } - public sealed class Arguments : TermTransform.ArgumentsBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] @@ -108,7 +101,7 @@ public sealed class Arguments : TermTransform.ArgumentsBase [Argument(ArgumentType.AtMostOnce, HelpText = "Output kind: Bag (multi-set vector), Ind (indicator vector), or Key (index)", ShortName = "kind", SortOrder = 102)] - public OutputKind OutputKind = Defaults.OutKind; + public OutputKind OutputKind = CategoricalEstimator.Defaults.OutKind; public Arguments() { @@ -124,25 +117,17 @@ public Arguments() public const string UserName = "Categorical Transform"; /// - /// A helper method to create for public facing API. + /// A helper method to create . /// /// Host Environment. /// Input . This is the output from previous transform or loader. /// Name of the output column. /// Name of the column to be transformed. If this is null '' will be used. /// The type of output expected. - public static IDataTransform Create(IHostEnvironment env, IDataView input, string name, string source = null, OutputKind outputKind = Defaults.OutKind) + public static IDataView Create(IHostEnvironment env, IDataView input, string name, + string source = null, OutputKind outputKind = CategoricalEstimator.Defaults.OutKind) { - var args = new Arguments() - { - Column = new[] { new Column(){ - Source = source ?? name, - Name = name - } - }, - OutputKind = outputKind - }; - return Create(env, args, input); + return new CategoricalEstimator(env, name, source, outputKind).Fit(input).Transform(input) as IDataView; } public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) @@ -152,88 +137,128 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV h.CheckValue(args, nameof(args)); h.CheckValue(input, nameof(input)); h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); - return CreateTransformCore( - args.OutputKind, - args.Column, - args.Column.Select(col => col.OutputKind).ToList(), - TermTransform.Create(h, args, args.Column, input), - h, - env); + + var columns = new List(); + foreach (var column in args.Column) + { + var col = new CategoricalEstimator.ColumnInfo( + column.Source ?? column.Name, + column.Name, + column.OutputKind ?? args.OutputKind, + column.MaxNumTerms ?? args.MaxNumTerms, + column.Sort ?? args.Sort, + column.Term ?? args.Term); + col.SetTerms(column.Terms); + columns.Add(col); + } + return new CategoricalEstimator(env, columns.ToArray()).Fit(input).Transform(input) as IDataTransform; + } + + private readonly TransformerChain _transformer; + + public CategoricalTransform(TermEstimator term, IEstimator keyToVector, IDataView input) + { + var chain = term.Append(keyToVector); + _transformer = chain.Fit(input); + } + + public ISchema GetOutputSchema(ISchema inputSchema) => _transformer.GetOutputSchema(inputSchema); + + public IDataView Transform(IDataView input) => _transformer.Transform(input); + + public void Save(ModelSaveContext ctx) => _transformer.Save(ctx); + } + + public sealed class CategoricalEstimator : IEstimator + { + public static class Defaults + { + public const CategoricalTransform.OutputKind OutKind = CategoricalTransform.OutputKind.Ind; } - public static IDataTransform CreateTransformCore( - OutputKind argsOutputKind, - OneToOneColumn[] columns, - List columnOutputKinds, - IDataTransform input, - IHost h, - IHostEnvironment env, - CategoricalHashTransform.Arguments catHashArgs = null) + public class ColumnInfo : TermTransform.ColumnInfo { - Contracts.CheckValue(columns, nameof(columns)); - Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds)); - Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns)); + public readonly CategoricalTransform.OutputKind OutputKind; + public ColumnInfo(string input, string output, CategoricalTransform.OutputKind outputKind = Defaults.OutKind, + int maxNumTerms = TermEstimator.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermEstimator.Defaults.Sort, + string[] term = null) + : base(input, output, maxNumTerms, sort, term, true) + { + OutputKind = outputKind; + } - using (var ch = h.Start("Create Transform Core")) + internal void SetTerms(string terms) { - // Create the KeyToVectorTransform, if needed. - var cols = new List(); - bool binaryEncoding = argsOutputKind == OutputKind.Bin; - for (int i = 0; i < columns.Length; i++) - { - var column = columns[i]; - if (!column.TrySanitize()) - throw h.ExceptUserArg(nameof(Column.Name)); + Terms = terms; + } - bool? bag; - OutputKind kind = columnOutputKinds[i].HasValue ? columnOutputKinds[i].Value : argsOutputKind; - switch (kind) - { - default: - throw env.ExceptUserArg(nameof(Column.OutputKind)); - case OutputKind.Key: - continue; - case OutputKind.Bin: - binaryEncoding = true; - bag = false; - break; - case OutputKind.Ind: - bag = false; - break; - case OutputKind.Bag: - bag = true; - break; - } - var col = new KeyToVectorTransform.Column(); - col.Name = column.Name; - col.Source = column.Name; - col.Bag = bag; - cols.Add(col); - } + } - if (cols.Count == 0) - return input; + private readonly IHost _host; + private readonly IEstimator _keyToSomething; + private TermEstimator _term; - IDataTransform transform; + /// A helper method to create for public facing API. + /// Host Environment. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. + /// The type of output expected. + public CategoricalEstimator(IHostEnvironment env, string name, + string source = null, CategoricalTransform.OutputKind outputKind = Defaults.OutKind) + : this(env, new ColumnInfo(source ?? name, name, outputKind)) + { + } + + public CategoricalEstimator(IHostEnvironment env, params ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(TermEstimator)); + _term = new TermEstimator(_host, columns); + + var cols = new List<(string input, string output, bool bag)>(); + bool binaryEncoding = false; + for (int i = 0; i < columns.Length; i++) + { + var column = columns[i]; + bool bag; + CategoricalTransform.OutputKind kind = columns[i].OutputKind; + switch (kind) + { + default: + throw _host.ExceptUserArg(nameof(column.OutputKind)); + case CategoricalTransform.OutputKind.Key: + continue; + case CategoricalTransform.OutputKind.Bin: + binaryEncoding = true; + bag = false; + break; + case CategoricalTransform.OutputKind.Ind: + bag = false; + break; + case CategoricalTransform.OutputKind.Bag: + bag = true; + break; + } + cols.Add((column.Output, column.Output, bag)); if (binaryEncoding) { - if ((catHashArgs?.InvertHash ?? 0) != 0) - ch.Warning("Invert hashing is being used with binary encoding."); - - var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray(); - transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols); + _keyToSomething = new KeyToBinaryVectorEstimator(_host, cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.input, x.output)).ToArray()); } else { - var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == OutputKind.Bag)).ToArray(); - - transform = KeyToVectorTransform.Create(h, input, keyToVecCols); + _keyToSomething = new KeyToVectorEstimator(_host, cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.input, x.output, x.bag)).ToArray()); } - - ch.Done(); - return transform; } } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) => _term.Append(_keyToSomething).GetOutputSchema(inputSchema); + + public CategoricalTransform Fit(IDataView input) => new CategoricalTransform(_term, _keyToSomething, input); + + internal void WrapTermWithDelegate(Action onFit) + { + _term = (TermEstimator)_term.WithOnFitDelegate(onFit); + } } public static class Categorical @@ -301,4 +326,153 @@ public static CommonOutputs.TransformOutput KeyToText(IHostEnvironment env, KeyT return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; } } + + public static class CategoricalStaticExtensions + { + public enum OneHotVectorOutputKind : byte + { + /// + /// Output is a bag (multi-set) vector + /// + Bag = 1, + + /// + /// Output is an indicator vector + /// + Ind = 2, + + /// + /// Output is binary encoded + /// + Bin = 4, + } + + public enum OneHotScalarOutputKind : byte + { + /// + /// Output is an indicator vector + /// + Ind = 2, + + /// + /// Output is binary encoded + /// + Bin = 4, + } + + private const KeyValueOrder DefSort = (KeyValueOrder)TermEstimator.Defaults.Sort; + private const int DefMax = TermEstimator.Defaults.MaxNumTerms; + private const OneHotVectorOutputKind DefOut = (OneHotVectorOutputKind)CategoricalEstimator.Defaults.OutKind; + + private struct Config + { + public readonly KeyValueOrder Order; + public readonly int Max; + public readonly OneHotVectorOutputKind OutputKind; + public readonly Action OnFit; + + public Config(OneHotVectorOutputKind outputKind, KeyValueOrder order, int max, Action onFit) + { + OutputKind = outputKind; + Order = order; + Max = max; + OnFit = onFit; + } + } + + private static Action Wrap(ToKeyFitResult.OnFit onFit) + { + if (onFit == null) + return null; + // The type T asociated with the delegate will be the actual value type once #863 goes in. + // However, until such time as #863 goes in, it would be too awkward to attempt to extract the metadata. + // For now construct the useless object then pass it into the delegate. + return map => onFit(new ToKeyFitResult(map)); + } + + private interface ICategoricalCol + { + PipelineColumn Input { get; } + Config Config { get; } + } + + private sealed class ImplScalar : Vector, ICategoricalCol + { + public PipelineColumn Input { get; } + public Config Config { get; } + public ImplScalar(PipelineColumn input, Config config) : base(Rec.Inst, input) + { + Input = input; + Config = config; + } + } + + private sealed class ImplVector : Vector, ICategoricalCol + { + public PipelineColumn Input { get; } + public Config Config { get; } + public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input) + { + Input = input; + Config = config; + } + } + + private sealed class Rec : EstimatorReconciler + { + public static readonly Rec Inst = new Rec(); + + public override IEstimator Reconcile(IHostEnvironment env, PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) + { + var infos = new CategoricalEstimator.ColumnInfo[toOutput.Length]; + Action onFit = null; + for (int i = 0; i < toOutput.Length; ++i) + { + var tcol = (ICategoricalCol)toOutput[i]; + infos[i] = new CategoricalEstimator.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], (CategoricalTransform.OutputKind)tcol.Config.OutputKind, + tcol.Config.Max, (TermTransform.SortOrder)tcol.Config.Order); + if (tcol.Config.OnFit != null) + { + int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call. + onFit += tt => tcol.Config.OnFit(tt.GetTermMap(ii)); + } + } + var est = new CategoricalEstimator(env, infos); + if (onFit != null) + est.WrapTermWithDelegate(onFit); + return est; + } + } + + /// + /// Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array + /// + /// Incoming data. + /// Specify output type of indicator array: array or binary encoded data. + /// How Id for each value would be assigined: by occurrence or by value. + /// Maximum number of ids to keep during data scanning. + /// /// Called upon fitting with the learnt enumeration on the dataset. + public static Vector OneHotEncoding(this Scalar input, OneHotScalarOutputKind outputKind = (OneHotScalarOutputKind)DefOut, KeyValueOrder order = DefSort, + int maxItems = DefMax, ToKeyFitResult>.OnFit onFit = null) + { + Contracts.CheckValue(input, nameof(input)); + return new ImplScalar(input, new Config((OneHotVectorOutputKind)outputKind, order, maxItems, Wrap(onFit))); + } + + /// + /// Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array + /// + /// Incoming data. + /// Specify output type of indicator array: Multiarray, array or binary encoded data. + /// How Id for each value would be assigined: by occurrence or by value. + /// Maximum number of ids to keep during data scanning. + /// Called upon fitting with the learnt enumeration on the dataset. + public static Vector OneHotEncoding(this Vector input, OneHotVectorOutputKind outputKind = DefOut, KeyValueOrder order = DefSort, int maxItems = DefMax, + ToKeyFitResult>.OnFit onFit = null) + { + Contracts.CheckValue(input, nameof(input)); + return new ImplVector(input, new Config(outputKind, order, maxItems, Wrap(onFit))); + } + } } diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index ab33f5761f..cde6ae9aa1 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -2,10 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; using Microsoft.ML.Core.Data; using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; @@ -13,11 +9,15 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; [assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToBinaryVectorTransform), typeof(KeyToBinaryVectorTransform.Arguments), typeof(SignatureDataTransform), - "Key To Binary Vector Transform", KeyToBinaryVectorTransform.UserName, "KeyToBinary", "ToVector", DocName = "transform/KeyToBinaryVectorTransform.md")] + "Key To Binary Vector Transform", KeyToBinaryVectorTransform.UserName, "KeyToBinary", "ToBinaryVector", DocName = "transform/KeyToBinaryVectorTransform.md")] -[assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(IDataView), typeof(KeyToBinaryVectorTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToBinaryVectorTransform), null, typeof(SignatureLoadDataTransform), "Key To Binary Vector Transform", KeyToBinaryVectorTransform.LoaderSignature)] [assembly: LoadableClass(KeyToBinaryVectorTransform.Summary, typeof(KeyToBinaryVectorTransform), null, typeof(SignatureLoadModel), @@ -446,7 +446,7 @@ public KeyToBinaryVectorEstimator(IHostEnvironment env, string name, string sour { } - public KeyToBinaryVectorEstimator(IHostEnvironment env, KeyToBinaryVectorTransform transformer) + private KeyToBinaryVectorEstimator(IHostEnvironment env, KeyToBinaryVectorTransform transformer) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToBinaryVectorEstimator)), transformer) { } diff --git a/test/BaselineOutput/SingleDebug/Categorical/featurized.tsv b/test/BaselineOutput/SingleDebug/Categorical/featurized.tsv new file mode 100644 index 0000000000..cd3a602966 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/Categorical/featurized.tsv @@ -0,0 +1,14 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=A:R4:0-4 +#@ col=B:R4:5-24 +#@ col=C:R4:25-44 +#@ col=D:R4:45-49 +#@ col=E:R4:50-69 +#@ } +Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 +70 14:1 19:1 24:1 34:1 39:1 44:1 59:1 64:1 69:1 +70 13:1 18:1 33:1 38:1 58:1 63:1 +70 4:1 8:1 9:1 14:1 19:1 24:1 28:1 29:1 34:1 39:1 44:1 49:1 53:1 54:1 59:1 64:1 69:1 +70 3:1 7:1 12:1 14:1 17:1 19:1 24:1 27:1 32:1 34:1 37:1 39:1 44:1 48:1 52:1 57:1 59:1 62:1 64:1 69:1 diff --git a/test/BaselineOutput/SingleRelease/Categorical/featurized.tsv b/test/BaselineOutput/SingleRelease/Categorical/featurized.tsv new file mode 100644 index 0000000000..cd3a602966 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/Categorical/featurized.tsv @@ -0,0 +1,14 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=A:R4:0-4 +#@ col=B:R4:5-24 +#@ col=C:R4:25-44 +#@ col=D:R4:45-49 +#@ col=E:R4:50-69 +#@ } +Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 +70 14:1 19:1 24:1 34:1 39:1 44:1 59:1 64:1 69:1 +70 13:1 18:1 33:1 38:1 58:1 63:1 +70 4:1 8:1 9:1 14:1 19:1 24:1 28:1 29:1 34:1 39:1 44:1 49:1 53:1 54:1 59:1 64:1 69:1 +70 3:1 7:1 12:1 14:1 17:1 19:1 24:1 27:1 32:1 34:1 37:1 39:1 44:1 48:1 52:1 57:1 59:1 62:1 64:1 69:1 diff --git a/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs b/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs index 234995d1e5..b08061416e 100644 --- a/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs +++ b/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs @@ -45,13 +45,7 @@ public ParameterMixingCalibratedPredictor TrainKMeansAndLR() } }, new MultiFileSource(_dataPath)); - IDataView trans = CategoricalTransform.Create(env, new CategoricalTransform.Arguments - { - Column = new[] - { - new CategoricalTransform.Column { Name = "CatFeatures", Source = "CatFeatures" } - } - }, loader); + IDataView trans = CategoricalTransform.Create(env, loader, "CatFeatures"); trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "NumFeatures"); trans = new ConcatTransform(env, "Features", "NumFeatures", "CatFeatures").Transform(trans); diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index bde0cc20fc..b3f9938dfd 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -237,7 +237,7 @@ public void BinaryClassifierLogisticRegressionTest() public void BinaryClassifierSymSgdTest() { //Results sometimes go out of error tolerance on OS X. - if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) return; RunOneAllTests(TestLearners.symSGD, TestDatasets.breastCancer, summary: true); @@ -582,14 +582,7 @@ public void TestTreeEnsembleCombinerWithCategoricalSplits() var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data; #pragma warning restore 0618 - var cat = CategoricalTransform.Create(Env, - new CategoricalTransform.Arguments() - { - Column = new[] - { - new CategoricalTransform.Column() { Name = "Features", Source = "Categories" } - } - }, dataView); + var cat = CategoricalTransform.Create(Env, dataView, "Features", "Categories"); var fastTrees = new IPredictorModel[3]; for (int i = 0; i < 3; i++) { diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs new file mode 100644 index 0000000000..89e95ec08e --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs @@ -0,0 +1,324 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System.IO; +using System.Linq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class CategoricalTests : TestDataPipeBase + { + public CategoricalTests(ITestOutputHelper output) : base(output) + { + } + + private class TestClass + { + public int A; + public int B; + public int C; + } + + private class TestMeta + { + [VectorType(2)] + public string[] A; + public string B; + [VectorType(2)] + public int[] C; + public int D; + [VectorType(2)] + public float[] E; + public float F; + [VectorType(2)] + public string[] G; + public string H; + } + + [Fact] + public void CategoricalWorkout() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new CategoricalEstimator(Env, new[]{ + new CategoricalEstimator.ColumnInfo("A", "CatA", CategoricalTransform.OutputKind.Bag), + new CategoricalEstimator.ColumnInfo("A", "CatB", CategoricalTransform.OutputKind.Bin), + new CategoricalEstimator.ColumnInfo("A", "CatC", CategoricalTransform.OutputKind.Ind), + new CategoricalEstimator.ColumnInfo("A", "CatD", CategoricalTransform.OutputKind.Key), + }); + + TestEstimatorCore(pipe, dataView); + Done(); + } + + [Fact] + public void CategoricalStatic() + { + string dataPath = GetDataPath("breast-cancer.txt"); + var reader = TextLoader.CreateReader(Env, ctx => ( + ScalarString: ctx.LoadText(1), + VectorString: ctx.LoadText(1, 4))); + var data = reader.Read(new MultiFileSource(dataPath)); + var wrongCollection = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + + var invalidData = ComponentCreation.CreateDataView(Env, wrongCollection); + var est = data.MakeNewEstimator(). + Append(row => ( + A: row.ScalarString.OneHotEncoding(outputKind: CategoricalStaticExtensions.OneHotScalarOutputKind.Ind), + B: row.VectorString.OneHotEncoding(outputKind: CategoricalStaticExtensions.OneHotVectorOutputKind.Ind), + C: row.VectorString.OneHotEncoding(outputKind: CategoricalStaticExtensions.OneHotVectorOutputKind.Bag), + D: row.ScalarString.OneHotEncoding(outputKind: CategoricalStaticExtensions.OneHotScalarOutputKind.Bin), + E: row.VectorString.OneHotEncoding(outputKind: CategoricalStaticExtensions.OneHotVectorOutputKind.Bin) + )); + + TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); + + var outputPath = GetOutputPath("Categorical", "featurized.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); + IDataView savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); + savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D", "E"); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + } + + CheckEquality("Categorical", "featurized.tsv"); + Done(); + } + + [Fact] + public void TestMetadataPropagation() + { + var data = new[] { + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E= new float[2] { 1.0f,2.0f}, F = 1.0f , G= new string[2]{ "A","D"}, H="D"}, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 5,3}, D= 1, E=new float[2] { 3.0f,4.0f}, F = -1.0f ,G= new string[2]{"E", "A"}, H="E"}, + new TestMeta() { A=new string[2] { "A", "B"}, B="C", C=new int[2] { 3,5}, D= 6, E=new float[2] { 5.0f,6.0f}, F = 1.0f ,G= new string[2]{ "D", "E"}, H="D"} }; + + + var dataView = ComponentCreation.CreateDataView(Env, data); + var bagPipe = new CategoricalEstimator(Env, + new CategoricalEstimator.ColumnInfo("A", "CatA", CategoricalTransform.OutputKind.Bag), + new CategoricalEstimator.ColumnInfo("B", "CatB", CategoricalTransform.OutputKind.Bag), + new CategoricalEstimator.ColumnInfo("C", "CatC", CategoricalTransform.OutputKind.Bag), + new CategoricalEstimator.ColumnInfo("D", "CatD", CategoricalTransform.OutputKind.Bag), + new CategoricalEstimator.ColumnInfo("E", "CatE", CategoricalTransform.OutputKind.Ind), + new CategoricalEstimator.ColumnInfo("F", "CatF", CategoricalTransform.OutputKind.Ind), + new CategoricalEstimator.ColumnInfo("G", "CatG", CategoricalTransform.OutputKind.Key), + new CategoricalEstimator.ColumnInfo("H", "CatH", CategoricalTransform.OutputKind.Key)); + + var binPipe = new CategoricalEstimator(Env, + new CategoricalEstimator.ColumnInfo("A", "CatA", CategoricalTransform.OutputKind.Bin), + new CategoricalEstimator.ColumnInfo("B", "CatB", CategoricalTransform.OutputKind.Bin), + new CategoricalEstimator.ColumnInfo("C", "CatC", CategoricalTransform.OutputKind.Bin), + new CategoricalEstimator.ColumnInfo("D", "CatD", CategoricalTransform.OutputKind.Bin), + new CategoricalEstimator.ColumnInfo("E", "CatE", CategoricalTransform.OutputKind.Ind), + new CategoricalEstimator.ColumnInfo("F", "CatF", CategoricalTransform.OutputKind.Ind), + new CategoricalEstimator.ColumnInfo("G", "CatG", CategoricalTransform.OutputKind.Key), + new CategoricalEstimator.ColumnInfo("H", "CatH", CategoricalTransform.OutputKind.Key)); + + var bagResult = bagPipe.Fit(dataView).Transform(dataView); + var binResult = binPipe.Fit(dataView).Transform(dataView); + + ValidateBagMetadata(bagResult); + ValidateBinMetadata(binResult); + Done(); + } + + private void ValidateBinMetadata(IDataView result) + { + Assert.True(result.Schema.TryGetColumnIndex("CatA", out int colA)); + Assert.True(result.Schema.TryGetColumnIndex("CatB", out int colB)); + Assert.True(result.Schema.TryGetColumnIndex("CatC", out int colC)); + Assert.True(result.Schema.TryGetColumnIndex("CatD", out int colD)); + Assert.True(result.Schema.TryGetColumnIndex("CatE", out int colE)); + Assert.True(result.Schema.TryGetColumnIndex("CatF", out int colF)); + Assert.True(result.Schema.TryGetColumnIndex("CatE", out int colG)); + Assert.True(result.Schema.TryGetColumnIndex("CatF", out int colH)); + var types = result.Schema.GetMetadataTypes(colA); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames }); + VBuffer slots = default; + DvBool normalized = default; + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colA, ref slots); + Assert.True(slots.Length == 6); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[6] { "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit2", "[1].Bit1", "[1].Bit0" }); + + types = result.Schema.GetMetadataTypes(colB); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colB, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "Bit1", "Bit0" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colB, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colC); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colC, ref slots); + Assert.True(slots.Length == 6); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[6] { "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit2", "[1].Bit1", "[1].Bit0" }); + + + types = result.Schema.GetMetadataTypes(colD); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colD, ref slots); + Assert.True(slots.Length == 3); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[3] { "Bit2", "Bit1", "Bit0" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colD, ref normalized); + Assert.True(normalized.IsTrue); + + + types = result.Schema.GetMetadataTypes(colE); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colE, ref slots); + Assert.True(slots.Length == 8); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[8] { "[0].Bit3", "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit3", "[1].Bit2", "[1].Bit1", "[1].Bit0" }); + + types = result.Schema.GetMetadataTypes(colF); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colE, ref slots); + Assert.True(slots.Length == 8); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[8] { "[0].Bit3", "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit3", "[1].Bit2", "[1].Bit1", "[1].Bit0" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colF, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colG); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colG, ref slots); + Assert.True(slots.Length == 8); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[8] { "[0].Bit3", "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit3", "[1].Bit2", "[1].Bit1", "[1].Bit0" }); + + + types = result.Schema.GetMetadataTypes(colH); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colH, ref slots); + Assert.True(slots.Length == 3); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[3] { "Bit2", "Bit1", "Bit0" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colH, ref normalized); + Assert.True(normalized.IsTrue); + } + + private void ValidateBagMetadata(IDataView result) + { + Assert.True(result.Schema.TryGetColumnIndex("CatA", out int colA)); + Assert.True(result.Schema.TryGetColumnIndex("CatB", out int colB)); + Assert.True(result.Schema.TryGetColumnIndex("CatC", out int colC)); + Assert.True(result.Schema.TryGetColumnIndex("CatD", out int colD)); + Assert.True(result.Schema.TryGetColumnIndex("CatE", out int colE)); + Assert.True(result.Schema.TryGetColumnIndex("CatF", out int colF)); + Assert.True(result.Schema.TryGetColumnIndex("CatE", out int colG)); + Assert.True(result.Schema.TryGetColumnIndex("CatF", out int colH)); + var types = result.Schema.GetMetadataTypes(colA); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames }); + VBuffer slots = default; + VBuffer slotRanges = default; + DvBool normalized = default; + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colA, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "A", "B" }); + + types = result.Schema.GetMetadataTypes(colB); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colB, ref slots); + Assert.True(slots.Length == 1); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[1] { "C" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colB, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colC); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colC, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "3", "5" }); + + + types = result.Schema.GetMetadataTypes(colD); + Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colD, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "6", "1" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colD, ref normalized); + Assert.True(normalized.IsTrue); + + + types = result.Schema.GetMetadataTypes(colE); + Assert.Equal(types.Select(x => x.Key), new string[3] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colE, ref slots); + Assert.True(slots.Length == 12); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[12] { "[0].1", "[0].2", "[0].3", "[0].4", "[0].5", "[0].6", "[1].1", "[1].2", "[1].3", "[1].4", "[1].5", "[1].6" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colE, ref slotRanges); + Assert.True(slotRanges.Length == 4); + Assert.Equal(slotRanges.Items().Select(x => x.Value.ToString()), new string[4] { "0", "5", "6", "11" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colE, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colF); + Assert.Equal(types.Select(x => x.Key), new string[3] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colF, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "1", "-1" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colF, ref slotRanges); + Assert.True(slotRanges.Length == 2); + Assert.Equal(slotRanges.Items().Select(x => x.Value.ToString()), new string[2] { "0", "1" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colF, ref normalized); + Assert.True(normalized.IsTrue); + + types = result.Schema.GetMetadataTypes(colG); + Assert.Equal(types.Select(x => x.Key), new string[3] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colG, ref slots); + Assert.True(slots.Length == 12); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[12] { "[0].1", "[0].2", "[0].3", "[0].4", "[0].5", "[0].6", "[1].1", "[1].2", "[1].3", "[1].4", "[1].5", "[1].6" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colG, ref slotRanges); + Assert.True(slotRanges.Length == 4); + Assert.Equal(slotRanges.Items().Select(x => x.Value.ToString()), new string[4] { "0", "5", "6", "11" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colG, ref normalized); + + + types = result.Schema.GetMetadataTypes(colH); + Assert.Equal(types.Select(x => x.Key), new string[3] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized }); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colH, ref slots); + Assert.True(slots.Length == 2); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "1", "-1" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colH, ref slotRanges); + Assert.True(slotRanges.Length == 2); + Assert.Equal(slotRanges.Items().Select(x => x.Value.ToString()), new string[2] { "0", "1" }); + result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colH, ref normalized); + } + + [Fact] + public void TestCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Cat{col=B:A} in=f:\2.txt" }), (int)0); + } + + [Fact] + public void TestOldSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new CategoricalEstimator(Env, new[]{ + new CategoricalEstimator.ColumnInfo("A", "TermA"), + new CategoricalEstimator.ColumnInfo("B", "TermB"), + new CategoricalEstimator.ColumnInfo("C", "TermC") + }); + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + } + } + + } +} diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs index d7b57c232a..fdd240c514 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -128,7 +128,7 @@ private void ValidateMetadata(IDataView result) DvBool normalized = default; result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colA, ref slots); Assert.True(slots.Length == 6); - Assert.Equal(slots.Values.Select(x => x.ToString()), new string[6] { "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit2", "[1].Bit1", "[1].Bit0" }); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[6] { "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit2", "[1].Bit1", "[1].Bit0" }); types = result.Schema.GetMetadataTypes(colB); Assert.Equal(types.Select(x => x.Key), new string[2] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.IsNormalized }); diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index 1701c17083..45c6f938c5 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -152,7 +152,7 @@ private void ValidateMetadata(IDataView result) DvBool normalized = default; result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colA, ref slots); Assert.True(slots.Length == 2); - Assert.Equal(slots.Values.Select(x => x.ToString()), new string[2] { "A", "B" }); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "A", "B" }); types = result.Schema.GetMetadataTypes(colB); Assert.Equal(types.Select(x => x.Key), new string[3] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.Kinds.IsNormalized });