diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index d69379d5da..e3e58885d6 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -9,6 +9,8 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Model.Onnx; +using Microsoft.ML.Runtime.Model.Pfa; [assembly: LoadableClass(typeof(RowToRowMapperTransform), null, typeof(SignatureLoadDataTransform), "", RowToRowMapperTransform.LoaderSignature)] @@ -110,7 +112,7 @@ public Dictionary Infos() /// It does so with the help of an , that is given a schema in its constructor, and has methods /// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns. /// - public sealed class RowToRowMapperTransform : RowToRowTransformBase + public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper, ITransformCanSaveOnnx, ITransformCanSavePfa { private sealed class Bindings : ColumnBindingsBase { @@ -209,6 +211,10 @@ private static VersionInfo GetVersionInfo() public override ISchema Schema { get { return _bindings; } } + public bool CanSaveOnnx => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx : false; + + public bool CanSavePfa => _mapper is ICanSavePfa pfaMapper ? 
pfaMapper.CanSavePfa : false; + public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper) : base(env, RegistrationName, input) { @@ -318,6 +324,101 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid return cursors; } + public void SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + if (_mapper is ISaveAsOnnx onnx) + { + Host.Check(onnx.CanSaveOnnx, "Cannot be saved as ONNX."); + onnx.SaveAsOnnx(ctx); + } + } + + public void SaveAsPfa(BoundPfaContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + if (_mapper is ISaveAsPfa pfa) + { + Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA."); + pfa.SaveAsPfa(ctx); + } + } + + public Func GetDependencies(Func predicate) + { + Func predicateInput; + _bindings.GetActive(predicate, out predicateInput); + return predicateInput; + } + + public IRow GetRow(IRow input, Func active, out Action disposer) + { + Host.CheckValue(input, nameof(input)); + Host.CheckValue(active, nameof(active)); + Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); + + disposer = null; + using (var ch = Host.Start("GetEntireRow")) + { + Action disp; + var activeArr = new bool[Schema.ColumnCount]; + for (int i = 0; i < Schema.ColumnCount; i++) + activeArr[i] = active(i); + var pred = _bindings.GetActiveOutputColumns(activeArr); + var getters = _mapper.CreateGetters(input, pred, out disp); + disposer += disp; + ch.Done(); + return new Row(input, this, Schema, getters); + } + } + + private sealed class Row : IRow + { + private readonly IRow _input; + private readonly Delegate[] _getters; + + private readonly RowToRowMapperTransform _parent; + + public long Batch { get { return _input.Batch; } } + + public long Position { get { return _input.Position; } } + + public ISchema Schema { get; } + + public Row(IRow input, RowToRowMapperTransform parent, ISchema schema, Delegate[] getters) + { + _input = input; + 
_parent = parent; + Schema = schema; + _getters = getters; + } + + public ValueGetter GetGetter(int col) + { + bool isSrc; + int index = _parent._bindings.MapColumnIndex(out isSrc, col); + if (isSrc) + return _input.GetGetter(index); + + Contracts.Assert(_getters[index] != null); + var fn = _getters[index] as ValueGetter; + if (fn == null) + throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); + return fn; + } + + public ValueGetter GetIdGetter() => _input.GetIdGetter(); + + public bool IsColumnActive(int col) + { + bool isSrc; + int index = _parent._bindings.MapColumnIndex(out isSrc, col); + if (isSrc) + return _input.IsColumnActive((index)); + return _getters[index] != null; + } + } + private sealed class RowCursor : SynchronizedCursorBase, IRowCursor { private readonly Delegate[] _getters; diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs index 36d839b93d..103f2efd9f 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs @@ -19,9 +19,9 @@ public interface ICanSaveOnnx } /// - /// This data model component is savable as ONNX. + /// This component knows how to save itself in ONNX format. /// - public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform + public interface ISaveAsOnnx : ICanSaveOnnx { /// /// Save as ONNX. @@ -30,6 +30,13 @@ public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform void SaveAsOnnx(OnnxContext ctx); } + /// + /// This data model component is savable as ONNX. + /// + public interface ITransformCanSaveOnnx : ISaveAsOnnx, IDataTransform + { + } + /// /// This is savable in ONNX. 
Note that this is /// typically called within an that is wrapping diff --git a/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs b/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs index 94505f36db..4bd62eacb9 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs @@ -20,9 +20,9 @@ public interface ICanSavePfa } /// - /// This data model component is savable as PFA. See http://dmg.org/pfa/ . + /// This component knows how to save itself in PFA format. /// - public interface ITransformCanSavePfa : ICanSavePfa, IDataTransform + public interface ISaveAsPfa : ICanSavePfa { /// /// Save as PFA. For any columns that are output, this interface should use @@ -34,6 +34,14 @@ public interface ITransformCanSavePfa : ICanSavePfa, IDataTransform void SaveAsPfa(BoundPfaContext ctx); } + /// + /// This data model component is savable as PFA. See http://dmg.org/pfa/ . + /// + public interface ITransformCanSavePfa : ISaveAsPfa, IDataTransform + { + + } + /// /// This is savable as a PFA. 
Note that this is /// typically called within an that is wrapping diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 3c47f84faa..9a287dd386 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -217,7 +217,7 @@ internal sealed class CopyColumnsRowMapper : IRowMapper { private readonly ISchema _schema; private readonly Dictionary _colNewToOldMapping; - private (string Source, string Name)[] _columns; + private readonly (string Source, string Name)[] _columns; private readonly IHost _host; public const string LoaderSignature = "CopyColumnsRowMapper"; diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index b301eac793..6773dfbf9c 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -172,6 +172,18 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac } protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer); + + protected int AddMetaGetter(ColumnMetadataInfo colMetaInfo, ISchema schema, string kind, ColumnType ct, Dictionary colMap) + { + MetadataUtils.MetadataGetter getter = (int col, ref T dst) => + { + var originalCol = colMap[col]; + schema.GetMetadata(kind, originalCol, ref dst); + }; + var info = new MetadataInfo(ct, getter); + colMetaInfo.Add(kind, info); + return 0; + } } } } diff --git a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs new file mode 100644 index 0000000000..33d23b9e8c --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs @@ -0,0 +1,52 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using System.Linq; + +namespace Microsoft.ML.Runtime.Data +{ + public sealed class TermEstimator : IEstimator + { + private readonly IHost _host; + private readonly TermTransform.ColumnInfo[] _columns; + public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = TermTransform.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermTransform.Defaults.Sort) : + this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort)) + { + } + + public TermEstimator(IHostEnvironment env, params TermTransform.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(TermEstimator)); + _columns = columns; + } + + public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns); + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + var col = inputSchema.FindColumn(colInfo.Input); + + if (col == null) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + + if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + string[] metadata; + if (col.MetadataKinds.Contains(MetadataUtils.Kinds.SlotNames)) + metadata = new[] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.KeyValues }; + else + metadata = new[] { MetadataUtils.Kinds.KeyValues }; + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata); + } + + return new SchemaShape(result.Values); + } + } +} diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index 4af45555fe..d124433797 
100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -1,15 +1,7 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -#pragma warning disable 420 // volatile with Interlocked.CompareExchange - -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text; -using System.Threading; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -20,16 +12,28 @@ using Microsoft.ML.Runtime.Model.Onnx; using Microsoft.ML.Runtime.Model.Pfa; using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; -[assembly: LoadableClass(TermTransform.Summary, typeof(TermTransform), typeof(TermTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(TermTransform.Summary, typeof(IDataTransform), typeof(TermTransform), + typeof(TermTransform.Arguments), typeof(SignatureDataTransform), TermTransform.UserName, "Term", "AutoLabel", "TermTransform", "AutoLabelTransform", DocName = "transform/TermTransform.md")] -[assembly: LoadableClass(TermTransform.Summary, typeof(TermTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(TermTransform.Summary, typeof(IDataView), typeof(TermTransform), null, typeof(SignatureLoadDataTransform), + TermTransform.UserName, TermTransform.LoaderSignature)] + +[assembly: LoadableClass(TermTransform.Summary, typeof(TermTransform), null, typeof(SignatureLoadModel), + TermTransform.UserName, TermTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(TermTransform), null, typeof(SignatureLoadRowMapper), 
TermTransform.UserName, TermTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.Data { - // TermTransform builds up term vocabularies (dictionaries). // Notes: // * Each column builds/uses exactly one "vocabulary" (dictionary). @@ -37,7 +41,7 @@ namespace Microsoft.ML.Runtime.Data // * The Key value is the one-based index of the item in the dictionary. // * Not found is assigned the value zero. /// - public sealed partial class TermTransform : OneToOneTransformBase, ITransformTemplate + public sealed partial class TermTransform : OneToOneTransformerBase { public abstract class ColumnBase : OneToOneColumn { @@ -97,7 +101,7 @@ public enum SortOrder : byte // other things, like case insensitive (where appropriate), culturally aware, etc.? } - private static class Defaults + internal static class Defaults { public const int MaxNumTerms = 1000000; public const SortOrder Sort = SortOrder.Occurrence; @@ -144,6 +148,42 @@ public sealed class Arguments : ArgumentsBase public Column[] Column; } + internal sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } + + public class ColumnInfo + { + public ColumnInfo(string input, string output, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort, string[] term = null, bool textKeyValues = false) + { + Input = input; + Output = output; + Sort = sort; + MaxNumTerms = maxNumTerms; + Term = term; + TextKeyValues = textKeyValues; + } + + public readonly string Input; + public readonly string Output; + public readonly SortOrder Sort; + public readonly int MaxNumTerms; + public readonly string[] Term; + public readonly bool TextKeyValues; + + internal string Terms { get; set; } + } + public const string Summary = "Converts input values (words, numbers, etc.) 
to index in a dictionary."; public const string UserName = "Term Transform"; public const string LoaderSignature = "TermTransform"; @@ -164,6 +204,22 @@ private static VersionInfo GetVersionInfo() private const uint VerManagerNonTextTypesSupported = 0x00010002; public const string TermManagerLoaderSignature = "TermManager"; + private static volatile MemoryStreamPool _codecFactoryPool; + private volatile CodecFactory _codecFactory; + + private CodecFactory CodecFactory + { + get + { + if (_codecFactory == null) + { + Interlocked.CompareExchange(ref _codecFactoryPool, new MemoryStreamPool(), null); + Interlocked.CompareExchange(ref _codecFactory, new CodecFactory(Host, _codecFactoryPool), null); + } + Host.Assert(_codecFactory != null); + return _codecFactory; + } + } private static VersionInfo GetTermManagerVersionInfo() { return new VersionInfo( @@ -175,32 +231,159 @@ private static VersionInfo GetTermManagerVersionInfo() loaderSignature: TermManagerLoaderSignature); } - // These are parallel to Infos. 
- private readonly ColumnType[] _types; - private readonly BoundTermMap[] _termMap; + private readonly TermMap[] _unboundMaps; private readonly bool[] _textMetadata; - private const string RegistrationName = "Term"; - private static volatile MemoryStreamPool _codecFactoryPool; - private volatile CodecFactory _codecFactory; + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } - private CodecFactory CodecFactory + private ColInfo[] CreateInfos(ISchema schema) { - get + Host.AssertValue(schema); + var infos = new ColInfo[ColumnPairs.Length]; + for (int i = 0; i < ColumnPairs.Length; i++) { - if (_codecFactory == null) + if (!schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input); + var type = schema.GetColumnType(colSrc); + string reason = TestIsKnownDataKind(type); + if (reason != null) + throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason); + infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); + } + return infos; + } + + public TermTransform(IHostEnvironment env, IDataView input, + params ColumnInfo[] columns) : + this(env, input, columns, null, null, null) + { } + + private TermTransform(IHostEnvironment env, IDataView input, + ColumnInfo[] columns, + string file = null, string termsColumn = null, + IComponentFactory loaderFactory = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) + { + using (var ch = Host.Start("Training")) + { + var infos = CreateInfos(Host, ColumnPairs, input.Schema, TestIsKnownDataKind); + _unboundMaps = Train(Host, ch, infos, file, termsColumn, loaderFactory, columns, input); + _textMetadata = new bool[_unboundMaps.Length]; + for (int iinfo = 0; iinfo < 
columns.Length; ++iinfo) { - Interlocked.CompareExchange(ref _codecFactoryPool, new MemoryStreamPool(), null); - Interlocked.CompareExchange(ref _codecFactory, new CodecFactory(Host, _codecFactoryPool), null); + _textMetadata[iinfo] = columns[iinfo].TextKeyValues; } - Host.Assert(_codecFactory != null); - return _codecFactory; + ch.Assert(_unboundMaps.Length == columns.Length); + ch.Done(); + } + } + + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + using (var ch = env.Start("ValidateArgs")) + { + if ((args.Term != null || !string.IsNullOrEmpty(args.Terms)) && + (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader != null || + !string.IsNullOrWhiteSpace(args.TermsColumn))) + { + ch.Warning("Explicit term list specified. Data file arguments will be ignored"); + } + if (!Enum.IsDefined(typeof(SortOrder), args.Sort)) + throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected", args.Sort); + + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; + var sortOrder = item.Sort ?? args.Sort; + if (!Enum.IsDefined(typeof(SortOrder), sortOrder)) + throw env.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected for column '{1}'", sortOrder, item.Name); + + cols[i] = new ColumnInfo(item.Source, + item.Name, + item.MaxNumTerms ?? args.MaxNumTerms, + sortOrder, + item.Term, + item.TextKeyValues ?? args.TextKeyValues); + cols[i].Terms = item.Terms; + }; } + return new TermTransform(env, input, cols, args.DataFile, args.TermsColumn, args.Loader).MakeDataTransform(input); + } + + // Factory method for SignatureLoadModel. 
+ public static TermTransform Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); + + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new TermTransform(host, ctx); + } + + private TermTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { + var columnsLength = ColumnPairs.Length; + + if (ctx.Header.ModelVerWritten >= VerNonTextTypesSupported) + _textMetadata = ctx.Reader.ReadBoolArray(columnsLength); + else + _textMetadata = new bool[columnsLength]; // No need to set in this case. They're all text. + + const string dir = "Vocabulary"; + var termMap = new TermMap[columnsLength]; + bool b = ctx.TryProcessSubModel(dir, + c => + { + // *** Binary format *** + // int: number of term maps (should equal number of columns) + // for each term map: + // byte: code identifying the term map type (0 text, 1 codec) + // : type specific format, see TermMap save/load methods + + host.CheckValue(c, nameof(ctx)); + c.CheckAtModel(GetTermManagerVersionInfo()); + int cmap = c.Reader.ReadInt32(); + host.CheckDecode(cmap == columnsLength); + if (c.Header.ModelVerWritten >= VerManagerNonTextTypesSupported) + { + for (int i = 0; i < columnsLength; ++i) + termMap[i] = TermMap.Load(c, host, CodecFactory); + } + else + { + for (int i = 0; i < columnsLength; ++i) + termMap[i] = TermMap.TextImpl.Create(c, host); + } + }); +#pragma warning disable MSML_NoMessagesForLoadContext // Vaguely useful. + if (!b) + throw host.ExceptDecode("Missing {0} model", dir); +#pragma warning restore MSML_NoMessagesForLoadContext + _unboundMaps = termMap; } - public override bool CanSavePfa => true; - public override bool CanSaveOnnx => true; + // Factory method for SignatureLoadDataTransform. 
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); /// /// Convenience constructor for public facing API. @@ -212,106 +395,85 @@ private CodecFactory CodecFactory /// Maximum number of terms to keep per column when auto-training. /// How items should be ordered when vectorized. By default, they will be in the order encountered. /// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). - public TermTransform(IHostEnvironment env, - IDataView input, - string name, - string source = null, - int maxNumTerms = Defaults.MaxNumTerms, - SortOrder sort = Defaults.Sort) - : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, MaxNumTerms = maxNumTerms, Sort = sort }, input) - { - } + public static IDataView Create(IHostEnvironment env, + IDataView input, string name, string source = null, + int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) => + new TermTransform(env, input, new[] { new ColumnInfo(source ?? name, name, maxNumTerms, sort) }).MakeDataTransform(input); - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public TermTransform(IHostEnvironment env, Arguments args, IDataView input) - : this(args, Contracts.CheckRef(args, nameof(args)).Column, env, input) - { - } + //REVIEW: This and static method below need to go to base class as it get created. + private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}."; - /// - /// Re-apply constructor. 
- /// - private TermTransform(IHostEnvironment env, TermTransform transform, IDataView newSource) - : base(env, RegistrationName, transform, newSource, TestIsKnownDataKind) + private static ColInfo[] CreateInfos(IHostEnvironment env, (string source, string name)[] columns, ISchema schema, Func testType) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == transform.Infos.Length); + env.CheckUserArg(Utils.Size(columns) > 0, nameof(columns)); + env.AssertValue(schema); + env.AssertValueOrNull(testType); - _textMetadata = transform._textMetadata; - _termMap = new BoundTermMap[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; ++iinfo) + var infos = new ColInfo[columns.Length]; + for (int i = 0; i < columns.Length; i++) { - TermMap map = transform._termMap[iinfo].Map; - if (!map.ItemType.Equals(Infos[iinfo].TypeSrc.ItemType)) + if (!schema.TryGetColumnIndex(columns[i].source, out int colSrc)) + throw env.ExceptUserArg(nameof(columns), "Source column '{0}' not found", columns[i].source); + var type = schema.GetColumnType(colSrc); + if (testType != null) { - // Column with the same name, but different types. - throw Host.Except( - "For column '{0}', term map was trained on items of type '{1}' but being applied to type '{2}'", - Infos[iinfo].Name, map.ItemType, Infos[iinfo].TypeSrc.ItemType); + string reason = testType(type); + if (reason != null) + throw env.ExceptUserArg(nameof(columns), InvalidTypeErrorFormat, columns[i].source, type, reason); } - _termMap[iinfo] = map.Bind(this, iinfo); + infos[i] = new ColInfo(columns[i].name, columns[i].source, type); } - _types = ComputeTypesAndMetadata(); + return infos; } - public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) + public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input) { - return new TermTransform(env, this, newSource); - } - - /// - /// Public constructor for compositional forms. 
- /// - public TermTransform(ArgumentsBase args, ColumnBase[] column, IHostEnvironment env, IDataView input) - : base(env, RegistrationName, column, input, TestIsKnownDataKind) - { - Host.CheckValue(args, nameof(args)); - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(column)); - - using (var ch = Host.Start("Training")) + return Create(env, new Arguments() { - TermMap[] unboundMaps = Train(Host, ch, Infos, args, column, Source); - ch.Assert(unboundMaps.Length == Infos.Length); - _textMetadata = new bool[unboundMaps.Length]; - _termMap = new BoundTermMap[unboundMaps.Length]; - for (int iinfo = 0; iinfo < Infos.Length; ++iinfo) + Column = column.Select(x => new Column() { - _textMetadata[iinfo] = column[iinfo].TextKeyValues ?? args.TextKeyValues; - _termMap[iinfo] = unboundMaps[iinfo].Bind(this, iinfo); - } - _types = ComputeTypesAndMetadata(); - ch.Done(); - } + MaxNumTerms = x.MaxNumTerms, + Name = x.Name, + Sort = x.Sort, + Source = x.Source, + Term = x.Term, + Terms = x.Terms, + TextKeyValues = x.TextKeyValues + }).ToArray(), + Data = args.Data, + DataFile = args.DataFile, + Loader = args.Loader, + MaxNumTerms = args.MaxNumTerms, + Sort = args.Sort, + Term = args.Term, + Terms = args.Terms, + TermsColumn = args.TermsColumn, + TextKeyValues = args.TextKeyValues + }, input); } - private static string TestIsKnownDataKind(ColumnType type) + internal static string TestIsKnownDataKind(ColumnType type) { - if (type.ItemType.RawKind != default(DataKind) && (type.IsVector || type.IsPrimitive)) + if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive)) return null; return "Expected standard type or a vector of standard type"; } /// - /// Utility method to create the file-based if the - /// argument of was present. + /// Utility method to create the file-based . 
/// - private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, ArgumentsBase args, Builder bldr) + private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, string file, string termsColumn, + IComponentFactory loaderFactory, Builder bldr) { Contracts.AssertValue(ch); ch.AssertValue(env); - ch.AssertValue(args); - ch.Assert(!string.IsNullOrWhiteSpace(args.DataFile)); + ch.Assert(!string.IsNullOrWhiteSpace(file)); ch.AssertValue(bldr); - string file = args.DataFile; // First column using the file. - string src = args.TermsColumn; + string src = termsColumn; IMultiStreamSource fileSource = new MultiFileSource(file); - var loaderFactory = args.Loader; // If the user manually specifies a loader, or this is already a pre-processed binary // file, then we assume the user knows what they're doing and do not attempt to convert // to the desired type ourselves. @@ -330,7 +492,7 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu if (isBinary || isTranspose) { ch.Assert(isBinary != isTranspose); - ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(args.TermsColumn), + ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(termsColumn), "Must be specified"); if (isBinary) termData = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource); @@ -363,10 +525,10 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu int colSrc; if (!termData.Schema.TryGetColumnIndex(src, out colSrc)) - throw ch.ExceptUserArg(nameof(args.TermsColumn), "Unknown column '{0}'", src); + throw ch.ExceptUserArg(nameof(termsColumn), "Unknown column '{0}'", src); var typeSrc = termData.Schema.GetColumnType(colSrc); if (!autoConvert && !typeSrc.Equals(bldr.ItemType)) - throw ch.ExceptUserArg(nameof(args.TermsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc); + throw ch.ExceptUserArg(nameof(termsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc); using (var 
cursor = termData.GetRowCursor(col => col == colSrc)) using (var pch = env.StartProgressChannel("Building term dictionary from file")) @@ -396,25 +558,15 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu /// This builds the instances per column. /// private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] infos, - ArgumentsBase args, ColumnBase[] column, IDataView trainingData) + string file, string termsColumn, + IComponentFactory loaderFactory, ColumnInfo[] columns, IDataView trainingData) { Contracts.AssertValue(env); env.AssertValue(ch); ch.AssertValue(infos); - ch.AssertValue(args); - ch.AssertValue(column); + ch.AssertValue(columns); ch.AssertValue(trainingData); - if ((args.Term != null || !string.IsNullOrEmpty(args.Terms)) && - (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader != null || - !string.IsNullOrWhiteSpace(args.TermsColumn))) - { - ch.Warning("Explicit term list specified. Data file arguments will be ignored"); - } - - if (!Enum.IsDefined(typeof(SortOrder), args.Sort)) - throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected", args.Sort); - TermMap termsFromFile = null; var termMap = new TermMap[infos.Length]; int[] lims = new int[infos.Length]; @@ -424,36 +576,28 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info for (int iinfo = 0; iinfo < infos.Length; iinfo++) { // First check whether we have a terms argument, and handle it appropriately. - var terms = new DvText(column[iinfo].Terms); - var termsArray = column[iinfo].Term; - if (!terms.HasChars && termsArray == null) - { - terms = new DvText(args.Terms); - termsArray = args.Term; - } + var terms = new DvText(columns[iinfo].Terms); + var termsArray = columns[iinfo].Term; terms = terms.Trim(); if (terms.HasChars || (termsArray != null && termsArray.Length > 0)) { // We have terms! Pass it in. - var sortOrder = column[iinfo].Sort ?? 
args.Sort; - if (!Enum.IsDefined(typeof(SortOrder), sortOrder)) - throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected for column '{1}'", sortOrder, infos[iinfo].Name); - + var sortOrder = columns[iinfo].Sort; var bldr = Builder.Create(infos[iinfo].TypeSrc, sortOrder); - if(terms.HasChars) + if (terms.HasChars) bldr.ParseAddTermArg(ref terms, ch); else bldr.ParseAddTermArg(termsArray, ch); termMap[iinfo] = bldr.Finish(); } - else if (!string.IsNullOrWhiteSpace(args.DataFile)) + else if (!string.IsNullOrWhiteSpace(file)) { // First column using this file. if (termsFromFile == null) { - var bldr = Builder.Create(infos[iinfo].TypeSrc, column[iinfo].Sort ?? args.Sort); - termsFromFile = CreateFileTermMap(env, ch, args, bldr); + var bldr = Builder.Create(infos[iinfo].TypeSrc, columns[iinfo].Sort); + termsFromFile = CreateFileTermMap(env, ch, file, termsColumn, loaderFactory, bldr); } if (!termsFromFile.ItemType.Equals(infos[iinfo].TypeSrc.ItemType)) { @@ -462,7 +606,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info // a complicated feature would be, and also because it's difficult to see how we // can logically reconcile "reinterpretation" for different types with the resulting // data view having an actual type. - throw ch.ExceptUserArg(nameof(args.DataFile), "Data file terms loaded as type '{0}' but mismatches column '{1}' item type '{2}'", + throw ch.ExceptUserArg(nameof(file), "Data file terms loaded as type '{0}' but mismatches column '{1}' item type '{2}'", termsFromFile.ItemType, infos[iinfo].Name, infos[iinfo].TypeSrc.ItemType); } termMap[iinfo] = termsFromFile; @@ -470,9 +614,10 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info else { // Auto train this column. Leave the term map null for now, but set the lim appropriately. - lims[iinfo] = column[iinfo].MaxNumTerms ?? 
args.MaxNumTerms; + lims[iinfo] = columns[iinfo].MaxNumTerms; ch.CheckUserArg(lims[iinfo] > 0, nameof(Column.MaxNumTerms), "Must be positive"); - Utils.Add(ref toTrain, infos[iinfo].Source); + Contracts.Check(trainingData.Schema.TryGetColumnIndex(infos[iinfo].Source, out int colIndex)); + Utils.Add(ref toTrain, colIndex); ++trainsNeeded; } } @@ -497,9 +642,10 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info { if (termMap[iinfo] != null) continue; - var bldr = Builder.Create(infos[iinfo].TypeSrc, column[iinfo].Sort ?? args.Sort); + var bldr = Builder.Create(infos[iinfo].TypeSrc, columns[iinfo].Sort); trainerInfo[itrainer] = iinfo; - trainer[itrainer++] = Trainer.Create(cursor, infos[iinfo].Source, false, lims[iinfo], bldr); + trainingData.Schema.TryGetColumnIndex(infos[iinfo].Source, out int colIndex); + trainer[itrainer++] = Trainer.Create(cursor, colIndex, false, lims[iinfo], bldr); } ch.Assert(itrainer == trainer.Length); pch.SetHeader(header, @@ -548,116 +694,17 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info return termMap; } - // Computes the types of the columns. - private ColumnType[] ComputeTypesAndMetadata() - { - Contracts.Assert(Utils.Size(Infos) > 0); - Contracts.Assert(Utils.Size(Infos) == Utils.Size(_termMap)); - - var md = Metadata; - var types = new ColumnType[Infos.Length]; - for (int iinfo = 0; iinfo < types.Length; iinfo++) - { - Contracts.Assert(types[iinfo] == null); - - var info = Infos[iinfo]; - KeyType keyType = _termMap[iinfo].Map.OutputType; - Host.Assert(keyType.KeyCount > 0); - if (info.TypeSrc.IsVector) - types[iinfo] = new VectorType(keyType, info.TypeSrc.AsVector); - else - types[iinfo] = keyType; - - // Inherit slot names from source. - using (var bldr = md.BuildMetadata(iinfo, Source.Schema, info.Source, MetadataUtils.Kinds.SlotNames)) - { - // Add key values metadata. 
It is legal to not add anything, in which case - // this builder performs no operations except passing slot names. - _termMap[iinfo].AddMetadata(bldr); - } - } - md.Seal(); - return types; - } - - private TermTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestIsKnownDataKind) - { - Host.AssertValue(ctx); - - // *** Binary format *** - // for each term map: - // bool(byte): whether this column should present key value metadata as text - - int cinfo = Infos.Length; - Host.Assert(cinfo > 0); - - if (ctx.Header.ModelVerWritten >= VerNonTextTypesSupported) - _textMetadata = ctx.Reader.ReadBoolArray(cinfo); - else - _textMetadata = new bool[cinfo]; // No need to set in this case. They're all text. - - const string dir = "Vocabulary"; - TermMap[] termMap = new TermMap[cinfo]; - bool b = ctx.TryProcessSubModel(dir, - c => - { - // *** Binary format *** - // int: number of term maps (should equal number of columns) - // for each term map: - // byte: code identifying the term map type (0 text, 1 codec) - // : type specific format, see TermMap save/load methods - - Host.CheckValue(c, nameof(ctx)); - c.CheckAtModel(GetTermManagerVersionInfo()); - int cmap = c.Reader.ReadInt32(); - Host.CheckDecode(cmap == cinfo); - if (c.Header.ModelVerWritten >= VerManagerNonTextTypesSupported) - { - for (int i = 0; i < cinfo; ++i) - termMap[i] = TermMap.Load(c, host, this); - } - else - { - for (int i = 0; i < cinfo; ++i) - termMap[i] = TermMap.TextImpl.Create(c, host); - } - }); -#pragma warning disable MSML_NoMessagesForLoadContext // Vaguely useful. 
- if (!b) - throw Host.ExceptDecode("Missing {0} model", dir); -#pragma warning restore MSML_NoMessagesForLoadContext - _termMap = new BoundTermMap[cinfo]; - for (int i = 0; i < cinfo; ++i) - _termMap[i] = termMap[i].Bind(this, i); - - _types = ComputeTypesAndMetadata(); - } - - public static TermTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - env.CheckValue(input, nameof(input)); - env.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - return h.Apply("Loading Model", ch => new TermTransform(h, ctx, input)); - } - public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - // *** Binary format *** - // for each term map: - // bool(byte): whether this column should present key value metadata as text - SaveBase(ctx); + base.SaveColumns(ctx); - Host.Assert(_termMap.Length == Infos.Length); - Host.Assert(_textMetadata.Length == Infos.Length); + Host.Assert(_unboundMaps.Length == _textMetadata.Length); + Host.Assert(_textMetadata.Length == ColumnPairs.Length); ctx.Writer.WriteBoolBytesNoCount(_textMetadata, _textMetadata.Length); // REVIEW: Should we do separate sub models for each dictionary? 
@@ -674,85 +721,194 @@ public override void Save(ModelSaveContext ctx) Host.CheckValue(c, nameof(ctx)); c.CheckAtModel(); c.SetVersionInfo(GetTermManagerVersionInfo()); - c.Writer.Write(_termMap.Length); - foreach (var term in _termMap) - term.Map.Save(c, this); + c.Writer.Write(_unboundMaps.Length); + foreach (var term in _unboundMaps) + term.Save(c, Host, CodecFactory); c.SaveTextStream("Terms.txt", writer => { - foreach (var map in _termMap) + foreach (var map in _unboundMaps) map.WriteTextTerms(writer); }); }); } - protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToken srcToken) + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); + + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Contracts.AssertValue(ctx); - Contracts.Assert(0 <= iinfo && iinfo < Infos.Length); - Contracts.Assert(Infos[iinfo] == info); - Contracts.AssertValue(srcToken); - Contracts.Assert(CanSavePfa); + if ((inputSchema.GetColumnType(srcCol).ItemType.RawKind == default)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema.GetColumnType(srcCol).ToString()); + } - if (!info.TypeSrc.ItemType.IsText) - return null; - var terms = default(VBuffer); - TermMap map = (TermMap)_termMap[iinfo].Map; - map.GetTerms(ref terms); - var jsonMap = new JObject(); - foreach (var kv in terms.Items()) - jsonMap[kv.Value.ToString()] = kv.Key; - string cellName = ctx.DeclareCell( - "TermMap", PfaUtils.Type.Map(PfaUtils.Type.Int), jsonMap); - JObject cellRef = PfaUtils.Cell(cellName); - - if (info.TypeSrc.IsVector) + private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa + { + private readonly ColumnType[] _types; + private readonly TermTransform _parent; + private readonly ColInfo[] _infos; + + private readonly BoundTermMap[] _termMap; + + public bool CanSaveOnnx => true; + + public bool CanSavePfa => true; + + public 
Mapper(TermTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { - var funcName = ctx.GetFreeFunctionName("mapTerm"); - ctx.Pfa.AddFunc(funcName, new JArray(PfaUtils.Param("term", PfaUtils.Type.String)), - PfaUtils.Type.Int, PfaUtils.If(PfaUtils.Call("map.containsKey", cellRef, "term"), PfaUtils.Index(cellRef, "term"), -1)); - var funcRef = PfaUtils.FuncRef("u." + funcName); - return PfaUtils.Call("a.map", srcToken, funcRef); + _parent = parent; + _infos = _parent.CreateInfos(inputSchema); + _types = new ColumnType[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + var type = _infos[i].TypeSrc; + KeyType keyType = _parent._unboundMaps[i].OutputType; + ColumnType colType; + if (type.IsVector) + colType = new VectorType(keyType, type.AsVector); + else + colType = keyType; + _types[i] = colType; + } + _termMap = new BoundTermMap[_parent.ColumnPairs.Length]; + for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo) + { + _termMap[iinfo] = _parent._unboundMaps[iinfo].Bind(Host, inputSchema, _infos, _parent._textMetadata, iinfo); + } } - return PfaUtils.If(PfaUtils.Call("map.containsKey", cellRef, srcToken), PfaUtils.Index(cellRef, srcToken), -1); - } - protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) - { - if (!info.TypeSrc.ItemType.IsText) - return false; - - var terms = default(VBuffer); - TermMap map = (TermMap)_termMap[iinfo].Map; - map.GetTerms(ref terms); - string opType = "LabelEncoder"; - var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - node.AddAttribute("classes_strings", terms.DenseValues()); - node.AddAttribute("default_int64", -1); - //default_string needs to be an empty string but there is a BUG in Lotus that - //throws a validation error when default_string is empty. As a work around, set - //default_string to a space. 
- node.AddAttribute("default_string", " "); - return true; - } + public override RowMapperColumnInfo[] GetOutputColumns() + { + var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + Host.Assert(colIndex >= 0); + var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); + _termMap[i].AddMetadata(colMetaInfo); - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < _types.Length); - var type = _types[iinfo]; - Host.Assert(type != null); - return type; - } + foreach (var type in InputSchema.GetMetadataTypes(colIndex).Where(x => x.Key == MetadataUtils.Kinds.SlotNames)) + { + Utils.MarshalInvoke(AddMetaGetter, type.Value.RawType, colMetaInfo, InputSchema, type.Key, type.Value, ColMapNewToOld); + } + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], colMetaInfo); + } + return result; + } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) - { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); + disposer = null; + var type = _termMap[iinfo].Map.OutputType; + return Utils.MarshalInvoke(MakeGetter, type.RawType, input, iinfo); + } + + private Delegate MakeGetter(IRow row, int src) => _termMap[src].GetMappingGetter(row); + + private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + { + if (!info.TypeSrc.ItemType.IsText) + return false; - return _termMap[iinfo].GetMappingGetter(input); + var terms = default(VBuffer); + TermMap map = (TermMap)_termMap[iinfo].Map; + 
map.GetTerms(ref terms); + string opType = "LabelEncoder"; + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("classes_strings", terms.DenseValues()); + node.AddAttribute("default_int64", -1); + //default_string needs to be an empty string but there is a BUG in Lotus that + //throws a validation error when default_string is empty. As a work around, set + //default_string to a space. + node.AddAttribute("default_string", " "); + return true; + } + + public void SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + + for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) + { + ColInfo info = _infos[iinfo]; + string sourceColumnName = info.Source; + if (!ctx.ContainsColumn(sourceColumnName)) + { + ctx.RemoveColumn(info.Name, false); + continue; + } + + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), + ctx.AddIntermediateVariable(_types[iinfo], info.Name))) + { + ctx.RemoveColumn(info.Name, true); + } + } + } + + public void SaveAsPfa(BoundPfaContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + + var toHide = new List(); + var toDeclare = new List>(); + + for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) + { + var info = _infos[iinfo]; + var srcName = info.Source; + string srcToken = ctx.TokenOrNullForName(srcName); + if (srcToken == null) + { + toHide.Add(info.Name); + continue; + } + var result = SaveAsPfaCore(ctx, iinfo, info, srcToken); + if (result == null) + { + toHide.Add(info.Name); + continue; + } + toDeclare.Add(new KeyValuePair(info.Name, result)); + } + ctx.Hide(toHide.ToArray()); + ctx.DeclareVar(toDeclare.ToArray()); + } + + private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToken srcToken) + { + Contracts.AssertValue(ctx); + Contracts.Assert(0 <= iinfo && iinfo < _infos.Length); + Contracts.Assert(_infos[iinfo] == info); + Contracts.AssertValue(srcToken); + //Contracts.Assert(CanSavePfa); + + if 
(!info.TypeSrc.ItemType.IsText) + return null; + var terms = default(VBuffer); + TermMap map = (TermMap)_termMap[iinfo].Map; + map.GetTerms(ref terms); + var jsonMap = new JObject(); + foreach (var kv in terms.Items()) + jsonMap[kv.Value.ToString()] = kv.Key; + string cellName = ctx.DeclareCell( + "TermMap", PfaUtils.Type.Map(PfaUtils.Type.Int), jsonMap); + JObject cellRef = PfaUtils.Cell(cellName); + + if (info.TypeSrc.IsVector) + { + var funcName = ctx.GetFreeFunctionName("mapTerm"); + ctx.Pfa.AddFunc(funcName, new JArray(PfaUtils.Param("term", PfaUtils.Type.String)), + PfaUtils.Type.Int, PfaUtils.If(PfaUtils.Call("map.containsKey", cellRef, "term"), PfaUtils.Index(cellRef, "term"), -1)); + var funcRef = PfaUtils.FuncRef("u." + funcName); + return PfaUtils.Call("a.map", srcToken, funcRef); + } + return PfaUtils.If(PfaUtils.Call("map.containsKey", cellRef, srcToken), PfaUtils.Index(cellRef, srcToken), -1); + } } } } diff --git a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs index 9a43dc5517..d428d06489 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs @@ -1,21 +1,18 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -#pragma warning disable 420 // volatile with Interlocked.CompareExchange - using System; using System.IO; using System.Text; +using System.Threading; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; namespace Microsoft.ML.Runtime.Data { - // Implementations of the helper objects for term transform. 
- - public sealed partial class TermTransform : OneToOneTransformBase, ITransformTemplate + public sealed partial class TermTransform { /// /// These are objects shared by both the scalar and vector implementations of @@ -67,7 +64,7 @@ private static Builder CreateCore(PrimitiveType type, bool sorted) // of building our term dictionary. For the other types (practically, only the UX types), // we should ignore nothing. RefPredicate mapsToMissing; - if (!Conversion.Conversions.Instance.TryGetIsNAPredicate(type, out mapsToMissing)) + if (!Runtime.Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(type, out mapsToMissing)) mapsToMissing = (ref T val) => false; return new Impl(type, mapsToMissing, sorted); } @@ -208,7 +205,7 @@ public override void ParseAddTermArg(ref DvText terms, IChannel ch) { T val; var tryParse = Conversion.Conversions.Instance.GetParseConversion(ItemType); - for (bool more = true; more; ) + for (bool more = true; more;) { DvText term; more = terms.SplitOne(',', out term, out terms); @@ -485,9 +482,9 @@ protected TermMap(PrimitiveType type, int count) OutputType = new KeyType(DataKind.U4, 0, Count == 0 ? 1 : Count); } - public abstract void Save(ModelSaveContext ctx, TermTransform trans); + public abstract void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory); - public static TermMap Load(ModelLoadContext ctx, IExceptionContext ectx, TermTransform trans) + public static TermMap Load(ModelLoadContext ctx, IHostEnvironment ectx, CodecFactory codecFactory) { // *** Binary format *** // byte: map type code @@ -497,24 +494,24 @@ public static TermMap Load(ModelLoadContext ctx, IExceptionContext ectx, TermTra ectx.CheckDecode(Enum.IsDefined(typeof(MapType), mtype)); switch (mtype) { - case MapType.Text: - // Binary format defined by this method. 
- return TextImpl.Create(ctx, ectx); - case MapType.Codec: - // *** Binary format *** - // codec parameterization: the codec - // int: number of terms - // value codec block: the terms written in the codec-defined binary format - IValueCodec codec; - if (!trans.CodecFactory.TryReadCodec(ctx.Reader.BaseStream, out codec)) - throw ectx.Except("Unrecognized codec read"); - ectx.CheckDecode(codec.Type.IsPrimitive); - int count = ctx.Reader.ReadInt32(); - ectx.CheckDecode(count >= 0); - return Utils.MarshalInvoke(LoadCodecCore, codec.Type.RawType, ctx, ectx, codec, count); - default: - ectx.Assert(false); - throw ectx.Except("Unrecognized type '{0}'", mtype); + case MapType.Text: + // Binary format defined by this method. + return TextImpl.Create(ctx, ectx); + case MapType.Codec: + // *** Binary format *** + // codec parameterization: the codec + // int: number of terms + // value codec block: the terms written in the codec-defined binary format + IValueCodec codec; + if (!codecFactory.TryReadCodec(ctx.Reader.BaseStream, out codec)) + throw ectx.Except("Unrecognized codec read"); + ectx.CheckDecode(codec.Type.IsPrimitive); + int count = ctx.Reader.ReadInt32(); + ectx.CheckDecode(count >= 0); + return Utils.MarshalInvoke(LoadCodecCore, codec.Type.RawType, ctx, ectx, codec, count); + default: + ectx.Assert(false); + throw ectx.Except("Unrecognized type '{0}'", mtype); } } @@ -556,19 +553,18 @@ private static TermMap LoadCodecCore(ModelLoadContext ctx, IExceptionContext /// requests on the input dataset. This should throw an error if we attempt to bind this /// to the wrong type of item. 
/// - public BoundTermMap Bind(TermTransform trans, int iinfo) + public BoundTermMap Bind(IHostEnvironment env, ISchema schema, ColInfo[] infos, bool[] textMetadata, int iinfo) { - Contracts.AssertValue(trans); - trans.Host.Assert(0 <= iinfo && iinfo < trans.Infos.Length); + env.Assert(0 <= iinfo && iinfo < infos.Length); - var info = trans.Infos[iinfo]; + var info = infos[iinfo]; var inType = info.TypeSrc.ItemType; if (!inType.Equals(ItemType)) { - throw trans.Host.Except("Could not apply a map over type '{0}' to column '{1}' since it has type '{2}'", + throw env.Except("Could not apply a map over type '{0}' to column '{1}' since it has type '{2}'", ItemType, info.Name, inType); } - return BoundTermMap.Create(this, trans, iinfo); + return BoundTermMap.Create(env, schema, this, infos, textMetadata, iinfo); } public abstract void WriteTextTerms(TextWriter writer); @@ -614,7 +610,7 @@ public static TextImpl Create(ModelLoadContext ctx, IExceptionContext ectx) return new TextImpl(pool); } - public override void Save(ModelSaveContext ctx, TermTransform trans) + public override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory) { // *** Binary format *** // byte: map type code, in this case 'Text' (0) @@ -622,14 +618,14 @@ public override void Save(ModelSaveContext ctx, TermTransform trans) // int[]: term string ids ctx.Writer.Write((byte)MapType.Text); - trans.Host.Assert(_pool.Count >= 0); - trans.Host.CheckDecode(_pool.Get("") == null); + host.Assert(_pool.Count >= 0); + host.CheckDecode(_pool.Get("") == null); ctx.Writer.Write(_pool.Count); int id = 0; foreach (var nstr in _pool) { - trans.Host.Assert(nstr.Id == id); + host.Assert(nstr.Id == id); ctx.SaveNonEmptyString(nstr.Value); id++; } @@ -689,7 +685,7 @@ public HashArrayImpl(PrimitiveType itemType, HashArray values) _values = values; } - public override void Save(ModelSaveContext ctx, TermTransform trans) + public override void Save(ModelSaveContext ctx, IHostEnvironment host, 
CodecFactory codecFactory) { // *** Binary format *** // byte: map type code, in this case 'Codec' @@ -698,12 +694,12 @@ public override void Save(ModelSaveContext ctx, TermTransform trans) // value codec block: the terms written in the codec-defined binary format IValueCodec codec; - if (!trans.CodecFactory.TryGetCodec(ItemType, out codec)) - throw trans.Host.Except("We do not know how to serialize terms of type '{0}'", ItemType); + if (!codecFactory.TryGetCodec(ItemType, out codec)) + throw host.Except("We do not know how to serialize terms of type '{0}'", ItemType); ctx.Writer.Write((byte)MapType.Codec); - trans.Host.Assert(codec.Type.Equals(ItemType)); - trans.Host.Assert(codec.Type.IsPrimitive); - trans.CodecFactory.WriteCodec(ctx.Writer.BaseStream, codec); + host.Assert(codec.Type.Equals(ItemType)); + host.Assert(codec.Type.IsPrimitive); + codecFactory.WriteCodec(ctx.Writer.BaseStream, codec); IValueCodec codecT = (IValueCodec)codec; ctx.Writer.Write(_values.Count); using (var writer = codecT.OpenWriter(ctx.Writer.BaseStream)) @@ -810,48 +806,48 @@ private abstract class BoundTermMap { public readonly TermMap Map; - private readonly TermTransform _parent; private readonly int _iinfo; private readonly bool _inputIsVector; + private readonly IHostEnvironment _host; + private readonly bool[] _textMetadata; + private readonly ColInfo[] _infos; + private readonly ISchema _schema; - private IHost Host { get { return _parent.Host; } } + private bool IsTextMetadata { get { return _textMetadata[_iinfo]; } } - private bool IsTextMetadata { get { return _parent._textMetadata[_iinfo]; } } - - private BoundTermMap(TermMap map, TermTransform trans, int iinfo) + private BoundTermMap(IHostEnvironment env, ISchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo) { - Contracts.AssertValue(trans); - _parent = trans; - - Host.AssertValue(map); - Host.Assert(0 <= iinfo && iinfo < trans.Infos.Length); - ColInfo info = trans.Infos[iinfo]; - 
Host.Assert(info.TypeSrc.ItemType.Equals(map.ItemType)); + _host = env; + //assert me. + _textMetadata = textMetadata; + _infos = infos; + _schema = schema; + _host.AssertValue(map); + _host.Assert(0 <= iinfo && iinfo < infos.Length); + var info = infos[iinfo]; + _host.Assert(info.TypeSrc.ItemType.Equals(map.ItemType)); Map = map; _iinfo = iinfo; _inputIsVector = info.TypeSrc.IsVector; } - public static BoundTermMap Create(TermMap map, TermTransform trans, int iinfo) + public static BoundTermMap Create(IHostEnvironment host, ISchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo) { - Contracts.AssertValue(trans); - var host = trans.Host; - host.AssertValue(map); - host.Assert(0 <= iinfo && iinfo < trans.Infos.Length); - ColInfo info = trans.Infos[iinfo]; + host.Assert(0 <= iinfo && iinfo < infos.Length); + var info = infos[iinfo]; host.Assert(info.TypeSrc.ItemType.Equals(map.ItemType)); - return Utils.MarshalInvoke(CreateCore, map.ItemType.RawType, map, trans, iinfo); + return Utils.MarshalInvoke(CreateCore, map.ItemType.RawType, host, schema, map, infos, textMetadata, iinfo); } - public static BoundTermMap CreateCore(TermMap map, TermTransform trans, int iinfo) + public static BoundTermMap CreateCore(IHostEnvironment env, ISchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo) { TermMap mapT = (TermMap)map; if (mapT.ItemType.IsKey) - return new KeyImpl(mapT, trans, iinfo); - return new Impl(mapT, trans, iinfo); + return new KeyImpl(env, schema, mapT, infos, textMetadata, iinfo); + return new Impl(env, schema, mapT, infos, textMetadata, iinfo); } public abstract Delegate GetMappingGetter(IRow row); @@ -860,7 +856,7 @@ public static BoundTermMap CreateCore(TermMap map, TermTransform trans, int i /// Allows us to optionally register metadata. It is also perfectly legal for /// this to do nothing, which corresponds to there being no metadata. 
/// - public abstract void AddMetadata(MetadataDispatcher.Builder bldr); + public abstract void AddMetadata(ColumnMetadataInfo colMetaInfo); /// /// Writes out all terms we map to a text writer, with one line per mapped term. @@ -878,8 +874,8 @@ private abstract class Base : BoundTermMap { protected readonly TermMap TypedMap; - public Base(TermMap map, TermTransform trans, int iinfo) - : base(map, trans, iinfo) + public Base(IHostEnvironment env, ISchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo) + : base(env, schema, map, infos, textMetadata, iinfo) { TypedMap = map; } @@ -908,10 +904,12 @@ public override Delegate GetMappingGetter(IRow input) if (!_inputIsVector) { ValueMapper map = TypedMap.GetKeyMapper(); - var info = _parent.Infos[_iinfo]; + var info = _infos[_iinfo]; T src = default(T); Contracts.Assert(!info.TypeSrc.IsVector); - ValueGetter getSrc = _parent.GetSrcGetter(input, _iinfo); + input.Schema.TryGetColumnIndex(info.Source, out int colIndex); + _host.Assert(input.IsColumnActive(colIndex)); + var getSrc = input.GetGetter(colIndex); ValueGetter retVal = (ref uint dst) => { @@ -928,9 +926,11 @@ public override Delegate GetMappingGetter(IRow input) // will have an indirect wrapping class to hold "map" and "info". This is // bad, especially since "map" is very frequently called. ValueMapper map = TypedMap.GetKeyMapper(); - var info = _parent.Infos[_iinfo]; + var info = _infos[_iinfo]; // First test whether default maps to default. If so this is sparsity preserving. - ValueGetter> getSrc = _parent.GetSrcGetter>(input, _iinfo); + input.Schema.TryGetColumnIndex(info.Source, out int colIndex); + _host.Assert(input.IsColumnActive(colIndex)); + var getSrc = input.GetGetter>(colIndex); VBuffer src = default(VBuffer); ValueGetter> retVal; // REVIEW: Consider whether possible or reasonable to not use a builder here. 
@@ -948,7 +948,7 @@ public override Delegate GetMappingGetter(IRow input) getSrc(ref src); int cval = src.Length; if (cv != 0 && cval != cv) - throw Host.Except("Column '{0}': TermTransform expects {1} slots, but got {2}", info.Name, cv, cval); + throw _host.Except("Column '{0}': TermTransform expects {1} slots, but got {2}", info.Name, cv, cval); if (cval == 0) { // REVIEW: Should the VBufferBuilder be changed so that it can @@ -983,7 +983,7 @@ public override Delegate GetMappingGetter(IRow input) getSrc(ref src); int cval = src.Length; if (cv != 0 && cval != cv) - throw Host.Except("Column '{0}': TermTransform expects {1} slots, but got {2}", info.Name, cv, cval); + throw _host.Except("Column '{0}': TermTransform expects {1} slots, but got {2}", info.Name, cv, cval); if (cval == 0) { // REVIEW: Should the VBufferBuilder be changed so that it can @@ -1017,7 +1017,7 @@ public override Delegate GetMappingGetter(IRow input) if (nextExplicitSlot == slot) { // This was an explicitly defined value. - Host.Assert(islot < src.Count); + _host.Assert(islot < src.Count); map(ref values[islot], ref dstItem); if (dstItem != 0) bldr.AddFeature(slot, dstItem); @@ -1025,7 +1025,7 @@ public override Delegate GetMappingGetter(IRow input) } else { - Host.Assert(slot < nextExplicitSlot); + _host.Assert(slot < nextExplicitSlot); // This is a non-defined implicit default value. No need to attempt a remap // since we already have it. 
bldr.AddFeature(slot, defaultMapValue); @@ -1039,11 +1039,10 @@ public override Delegate GetMappingGetter(IRow input) } } - public override void AddMetadata(MetadataDispatcher.Builder bldr) + public override void AddMetadata(ColumnMetadataInfo colMetaInfo) { if (TypedMap.Count == 0) return; - if (IsTextMetadata && !TypedMap.ItemType.IsText) { var conv = Conversion.Conversions.Instance; @@ -1052,25 +1051,27 @@ public override void AddMetadata(MetadataDispatcher.Builder bldr) MetadataUtils.MetadataGetter> getter = (int iinfo, ref VBuffer dst) => { - Host.Assert(iinfo == _iinfo); + _host.Assert(iinfo == _iinfo); // No buffer sharing convenient here. VBuffer dstT = default(VBuffer); TypedMap.GetTerms(ref dstT); GetTextTerms(ref dstT, stringMapper, ref dst); }; - bldr.AddGetter>(MetadataUtils.Kinds.KeyValues, - new VectorType(TextType.Instance, TypedMap.OutputType.KeyCount), getter); + var columnType = new VectorType(TextType.Instance, TypedMap.OutputType.KeyCount); + var info = new MetadataInfo>(columnType, getter); + colMetaInfo.Add(MetadataUtils.Kinds.KeyValues, info); } else { MetadataUtils.MetadataGetter> getter = (int iinfo, ref VBuffer dst) => { - Host.Assert(iinfo == _iinfo); + _host.Assert(iinfo == _iinfo); TypedMap.GetTerms(ref dst); }; - bldr.AddGetter>(MetadataUtils.Kinds.KeyValues, - new VectorType(TypedMap.ItemType, TypedMap.OutputType.KeyCount), getter); + var columnType = new VectorType(TypedMap.ItemType, TypedMap.OutputType.KeyCount); + var info = new MetadataInfo>(columnType, getter); + colMetaInfo.Add(MetadataUtils.Kinds.KeyValues, info); } } } @@ -1081,34 +1082,34 @@ public override void AddMetadata(MetadataDispatcher.Builder bldr) /// private sealed class KeyImpl : Base { - public KeyImpl(TermMap map, TermTransform trans, int iinfo) - : base(map, trans, iinfo) + public KeyImpl(IHostEnvironment env, ISchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo) + : base(env, schema, map, infos, textMetadata, iinfo) { - 
Host.Assert(TypedMap.ItemType.IsKey); + _host.Assert(TypedMap.ItemType.IsKey); } - public override void AddMetadata(MetadataDispatcher.Builder bldr) + public override void AddMetadata(ColumnMetadataInfo colMetaInfo) { if (TypedMap.Count == 0) return; - int srcCol = _parent.Infos[_iinfo].Source; - ColumnType srcMetaType = _parent.Source.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); + _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); + ColumnType srcMetaType = _schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); if (srcMetaType == null || srcMetaType.VectorSize != TypedMap.ItemType.KeyCount || - TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(AddMetadataCore, srcMetaType.ItemType.RawType, srcMetaType.ItemType, bldr)) + TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(AddMetadataCore, srcMetaType.ItemType.RawType, srcMetaType.ItemType, colMetaInfo)) { // No valid input key-value metadata. Back off to the base implementation. - base.AddMetadata(bldr); + base.AddMetadata(colMetaInfo); } } - private bool AddMetadataCore(ColumnType srcMetaType, MetadataDispatcher.Builder bldr) + private bool AddMetadataCore(ColumnType srcMetaType, ColumnMetadataInfo colMetaInfo) { - Host.AssertValue(srcMetaType); - Host.Assert(srcMetaType.RawType == typeof(TMeta)); - Host.AssertValue(bldr); + _host.AssertValue(srcMetaType); + _host.Assert(srcMetaType.RawType == typeof(TMeta)); + _host.AssertValue(colMetaInfo); var srcType = TypedMap.ItemType.AsKey; - Host.AssertValue(srcType); + _host.AssertValue(srcType); var dstType = new KeyType(DataKind.U4, srcType.Min, srcType.Count); var convInst = Conversion.Conversions.Instance; ValueMapper conv; @@ -1116,14 +1117,14 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataDispatcher.B // If we can't convert this type to U4, don't try to pass along the metadata. 
if (!convInst.TryGetStandardConversion(srcType, dstType, out conv, out identity)) return false; - int srcCol = _parent.Infos[_iinfo].Source; + _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); ValueGetter> getter = (ref VBuffer dst) => { VBuffer srcMeta = default(VBuffer); - _parent.Source.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref srcMeta); - Host.Assert(srcMeta.Length == srcType.Count); + _schema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref srcMeta); + _host.Assert(srcMeta.Length == srcType.Count); VBuffer keyVals = default(VBuffer); TypedMap.GetTerms(ref keyVals); @@ -1136,7 +1137,7 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataDispatcher.B T keyVal = pair.Value; conv(ref keyVal, ref convKeyVal); // The builder for the key values should not have any missings. - Host.Assert(0 < convKeyVal && convKeyVal <= srcMeta.Length); + _host.Assert(0 < convKeyVal && convKeyVal <= srcMeta.Length); srcMeta.GetItemOrDefault((int)(convKeyVal - 1), ref values[pair.Key]); } dst = new VBuffer(TypedMap.OutputType.KeyCount, values, dst.Indices); @@ -1148,29 +1149,29 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataDispatcher.B MetadataUtils.MetadataGetter> mgetter = (int iinfo, ref VBuffer dst) => { - Host.Assert(iinfo == _iinfo); + _host.Assert(iinfo == _iinfo); var tempMeta = default(VBuffer); getter(ref tempMeta); Contracts.Assert(tempMeta.IsDense); GetTextTerms(ref tempMeta, stringMapper, ref dst); - Host.Assert(dst.Length == TypedMap.OutputType.KeyCount); + _host.Assert(dst.Length == TypedMap.OutputType.KeyCount); }; - - bldr.AddGetter>(MetadataUtils.Kinds.KeyValues, - new VectorType(TextType.Instance, TypedMap.OutputType.KeyCount), mgetter); + var columnType = new VectorType(TextType.Instance, TypedMap.OutputType.KeyCount); + var info = new MetadataInfo>(columnType, mgetter); + colMetaInfo.Add(MetadataUtils.Kinds.KeyValues, info); } else { MetadataUtils.MetadataGetter> mgetter = (int iinfo, ref 
VBuffer dst) => { - Host.Assert(iinfo == _iinfo); + _host.Assert(iinfo == _iinfo); getter(ref dst); - Host.Assert(dst.Length == TypedMap.OutputType.KeyCount); + _host.Assert(dst.Length == TypedMap.OutputType.KeyCount); }; - - bldr.AddGetter>(MetadataUtils.Kinds.KeyValues, - new VectorType(srcMetaType.ItemType.AsPrimitive, TypedMap.OutputType.KeyCount), mgetter); + var columnType = new VectorType(srcMetaType.ItemType.AsPrimitive, TypedMap.OutputType.KeyCount); + var info = new MetadataInfo>(columnType, mgetter); + colMetaInfo.Add(MetadataUtils.Kinds.KeyValues, info); } return true; } @@ -1180,8 +1181,8 @@ public override void WriteTextTerms(TextWriter writer) if (TypedMap.Count == 0) return; - int srcCol = _parent.Infos[_iinfo].Source; - ColumnType srcMetaType = _parent.Source.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); + _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); + ColumnType srcMetaType = _schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); if (srcMetaType == null || srcMetaType.VectorSize != TypedMap.ItemType.KeyCount || TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, srcMetaType.AsVector.ItemType, writer)) { @@ -1192,10 +1193,10 @@ public override void WriteTextTerms(TextWriter writer) private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter writer) { - Host.AssertValue(srcMetaType); - Host.Assert(srcMetaType.RawType == typeof(TMeta)); + _host.AssertValue(srcMetaType); + _host.Assert(srcMetaType.RawType == typeof(TMeta)); var srcType = TypedMap.ItemType.AsKey; - Host.AssertValue(srcType); + _host.AssertValue(srcType); var dstType = new KeyType(DataKind.U4, srcType.Min, srcType.Count); var convInst = Conversion.Conversions.Instance; ValueMapper conv; @@ -1203,10 +1204,10 @@ private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter wri // If we can't convert this type to U4, don't try. 
if (!convInst.TryGetStandardConversion(srcType, dstType, out conv, out identity)) return false; - int srcCol = _parent.Infos[_iinfo].Source; + _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); VBuffer srcMeta = default(VBuffer); - _parent.Source.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref srcMeta); + _schema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref srcMeta); if (srcMeta.Length != srcType.Count) return false; @@ -1225,7 +1226,7 @@ private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter wri T keyVal = pair.Value; conv(ref keyVal, ref convKeyVal); // The key mapping will not have admitted missing keys. - Host.Assert(0 < convKeyVal && convKeyVal <= srcMeta.Length); + _host.Assert(0 < convKeyVal && convKeyVal <= srcMeta.Length); srcMeta.GetItemOrDefault((int)(convKeyVal - 1), ref metaVal); keyStringMapper(ref keyVal, ref sb); writer.Write("{0}\t{1}", pair.Key, sb.ToString()); @@ -1238,8 +1239,8 @@ private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter wri private sealed class Impl : Base { - public Impl(TermMap map, TermTransform trans, int iinfo) - : base(map, trans, iinfo) + public Impl(IHostEnvironment env, ISchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo) + : base(env, schema, map, infos, textMetadata, iinfo) { } } diff --git a/src/Microsoft.ML.Transforms/CategoricalTransform.cs b/src/Microsoft.ML.Transforms/CategoricalTransform.cs index 5c045e4a48..b19ebc7bfc 100644 --- a/src/Microsoft.ML.Transforms/CategoricalTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalTransform.cs @@ -156,7 +156,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV args.OutputKind, args.Column, args.Column.Select(col => col.OutputKind).ToList(), - new TermTransform(args, args.Column, h, input), + TermTransform.Create(h, args, args.Column, input), h, env); } @@ -261,7 +261,7 @@ public static CommonOutputs.TransformOutput 
CatTransformDict(IHostEnvironment en [TlcModule.EntryPoint(Name = "Transforms.CategoricalHashOneHotVectorizer", Desc = CategoricalHashTransform.Summary, - UserName = CategoricalHashTransform.UserName , + UserName = CategoricalHashTransform.UserName, XmlInclude = new[] { @"", @""})] public static CommonOutputs.TransformOutput CatTransformHash(IHostEnvironment env, CategoricalHashTransform.Arguments input) @@ -287,7 +287,7 @@ public static CommonOutputs.TransformOutput TextToKey(IHostEnvironment env, Term host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - var xf = new TermTransform(host, input, input.Data); + var xf = TermTransform.Create(host, input, input.Data); return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; } diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 3d29bce613..7bf05df50c 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -74,7 +74,7 @@ public static CommonOutputs.TransformOutput NGramTransform(IHostEnvironment env, public static CommonOutputs.TransformOutput TermTransform(IHostEnvironment env, TermTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "TermTransform", input); - var xf = new TermTransform(h, input, input.Data); + var xf = Data.TermTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index 605649fe6c..4f61a2d9f9 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -360,7 +360,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV 
naDropArgs.Column[iinfo] = new NADropTransform.Column { Name = column.Name, Source = column.Name }; } - view = new TermTransform(h, termArgs, view); + view = TermTransform.Create(h, termArgs, view); if (naDropArgs != null) view = new NADropTransform(h, naDropArgs, view); } diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs index c62274961b..9417ea6260 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs @@ -419,7 +419,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV Sort = termLoaderArgs.Sort, Column = termCols.ToArray() }; - view = new TermTransform(h, termArgs, view); + view = TermTransform.Create(h, termArgs, view); if (termLoaderArgs.DropUnknowns) { diff --git a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs index 6502fc2afa..fb91b03ebb 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs @@ -101,11 +101,11 @@ private static IDataView ApplyKeyToVec(List ktv, ID .ToArray() }, viewTrain); - viewTrain = new Data.TermTransform(host, - new Data.TermTransform.Arguments() + viewTrain = TermTransform.Create(host, + new TermTransform.Arguments() { Column = ktv - .Select(c => new Data.TermTransform.Column() { Name = c.Name, Source = c.Name, Terms = GetTerms(viewTrain, c.Source) }) + .Select(c => new TermTransform.Column() { Name = c.Name, Source = c.Name, Terms = GetTerms(viewTrain, c.Source) }) .ToArray(), TextKeyValues = true }, @@ -255,20 +255,20 @@ public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvi return new CommonOutputs.TransformOutput { Model = new TransformModel(env, nop, input.Data), OutputData = nop }; } - var args = new Data.TermTransform.Arguments() + var args = new TermTransform.Arguments() { 
Column = new[] { - new Data.TermTransform.Column() + new TermTransform.Column() { Name = input.LabelColumn, Source = input.LabelColumn, TextKeyValues = input.TextKeyValues, - Sort = Data.TermTransform.SortOrder.Value + Sort = TermTransform.SortOrder.Value } } }; - var xf = new Data.TermTransform(host, args, input.Data); + var xf = TermTransform.Create(host, args, input.Data); return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; } diff --git a/test/BaselineOutput/SingleDebug/Term/Term.tsv b/test/BaselineOutput/SingleDebug/Term/Term.tsv new file mode 100644 index 0000000000..ace0b7b175 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/Term/Term.tsv @@ -0,0 +1,29 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=float1:R4:0 +#@ col=float4:R4:1-4 +#@ col=double1:R8:5 +#@ col=double4:R8:6-9 +#@ col=int1:I4:10 +#@ col=text1:TX:11 +#@ col=text2:TX:12-13 +#@ col=TermFloat1:U4[0-9]:14 +#@ col=TermFloat4:U4[0-29]:15-18 +#@ col=TermDouble1:U4[0-9]:19 +#@ col=TermDouble4:U4[0-29]:20-23 +#@ col=TermInt1:U4[0-9]:24 +#@ col=TermText1:U4[0-3]:25 +#@ col=TermText2:U4[0-10]:26-27 +#@ } +float1 age fnlwgt education-num capital-gain double1 age fnlwgt education-num capital-gain int1 text1 workclass education TermFloat1 age fnlwgt education-num capital-gain TermDouble1 age fnlwgt education-num capital-gain TermInt1 TermText1 workclass education +25 25 226802 7 0 25 25 226802 7 0 25 Private Private 11th 0 0 1 2 3 0 0 1 2 3 0 0 0 1 +38 38 89814 9 0 38 38 89814 9 0 38 Private Private HS-grad 1 4 5 6 3 1 4 5 6 3 1 0 0 2 +28 28 336951 12 0 28 28 336951 12 0 28 Local-gov Local-gov Assoc-acdm 2 7 8 9 3 2 7 8 9 3 2 1 3 4 +44 44 160323 10 7688 44 44 160323 10 7688 44 Private Private Some-college 3 10 11 12 13 3 10 11 12 13 3 0 0 5 +18 18 103497 10 0 18 18 103497 10 0 18 ? ? 
Some-college 4 14 15 12 3 4 14 15 12 3 4 2 6 5 +34 34 198693 6 0 34 34 198693 6 0 34 Private Private 10th 5 16 17 18 3 5 16 17 18 3 5 0 0 7 +29 29 227026 9 0 29 29 227026 9 0 29 ? ? HS-grad 6 19 20 6 3 6 19 20 6 3 6 2 6 2 +63 63 104626 15 3103 63 63 104626 15 3103 63 Self-emp-not-inc Self-emp-not-inc Prof-school 7 21 22 23 24 7 21 22 23 24 7 3 8 9 +24 24 369667 10 0 24 24 369667 10 0 24 Private Private Some-college 8 25 26 12 3 8 25 26 12 3 8 0 0 5 +55 55 104996 4 0 55 55 104996 4 0 55 Private Private 7th-8th 9 27 28 29 3 9 27 28 29 3 9 0 0 10 diff --git a/test/BaselineOutput/SingleRelease/Term/Term.tsv b/test/BaselineOutput/SingleRelease/Term/Term.tsv new file mode 100644 index 0000000000..ace0b7b175 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/Term/Term.tsv @@ -0,0 +1,29 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=float1:R4:0 +#@ col=float4:R4:1-4 +#@ col=double1:R8:5 +#@ col=double4:R8:6-9 +#@ col=int1:I4:10 +#@ col=text1:TX:11 +#@ col=text2:TX:12-13 +#@ col=TermFloat1:U4[0-9]:14 +#@ col=TermFloat4:U4[0-29]:15-18 +#@ col=TermDouble1:U4[0-9]:19 +#@ col=TermDouble4:U4[0-29]:20-23 +#@ col=TermInt1:U4[0-9]:24 +#@ col=TermText1:U4[0-3]:25 +#@ col=TermText2:U4[0-10]:26-27 +#@ } +float1 age fnlwgt education-num capital-gain double1 age fnlwgt education-num capital-gain int1 text1 workclass education TermFloat1 age fnlwgt education-num capital-gain TermDouble1 age fnlwgt education-num capital-gain TermInt1 TermText1 workclass education +25 25 226802 7 0 25 25 226802 7 0 25 Private Private 11th 0 0 1 2 3 0 0 1 2 3 0 0 0 1 +38 38 89814 9 0 38 38 89814 9 0 38 Private Private HS-grad 1 4 5 6 3 1 4 5 6 3 1 0 0 2 +28 28 336951 12 0 28 28 336951 12 0 28 Local-gov Local-gov Assoc-acdm 2 7 8 9 3 2 7 8 9 3 2 1 3 4 +44 44 160323 10 7688 44 44 160323 10 7688 44 Private Private Some-college 3 10 11 12 13 3 10 11 12 13 3 0 0 5 +18 18 103497 10 0 18 18 103497 10 0 18 ? ? 
Some-college 4 14 15 12 3 4 14 15 12 3 4 2 6 5 +34 34 198693 6 0 34 34 198693 6 0 34 Private Private 10th 5 16 17 18 3 5 16 17 18 3 5 0 0 7 +29 29 227026 9 0 29 29 227026 9 0 29 ? ? HS-grad 6 19 20 6 3 6 19 20 6 3 6 2 6 2 +63 63 104626 15 3103 63 63 104626 15 3103 63 Self-emp-not-inc Self-emp-not-inc Prof-school 7 21 22 23 24 7 21 22 23 24 7 3 8 9 +24 24 369667 10 0 24 24 369667 10 0 24 Private Private Some-college 8 25 26 12 3 8 25 26 12 3 8 0 0 5 +55 55 104996 4 0 55 55 104996 4 0 55 Private Private 7th-8th 9 27 28 29 3 9 27 28 29 3 9 0 0 10 diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 5f4a579914..0a1acadfd5 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -44,7 +44,7 @@ private IDataView GetBreastCancerDataView() Column = new[] { new TextLoader.Column("Label", DataKind.R4, 0), - new TextLoader.Column("Features", DataKind.R4, + new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(1, 9) }) } }, @@ -736,7 +736,7 @@ public void EntryPointPipelineEnsemble() Column = new[] { new ConcatTransform.Column() { Name = "Features", Source = new[] { "Features1", "Features2" } } } }, data); - data = new TermTransform(Env, new TermTransform.Arguments() + data = TermTransform.Create(Env, new TermTransform.Arguments() { Column = new[] { diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index d189e627f5..a402e9dd38 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -22,13 +22,13 @@ public abstract partial class TestDataPipeBase : TestDataViewBase /// /// 'Workout test' for an estimator. 
/// Checks the following traits: - /// - the estimator is applicable to the validFitInput, and not applicable to validTransformInput and invalidInput; - /// - the fitted transformer is applicable to validFitInput and validTransformInput, and not applicable to invalidInput; + /// - the estimator is applicable to the validFitInput and validForFitNotValidForTransformInput, and not applicable to validTransformInput and invalidInput; + /// - the fitted transformer is applicable to validFitInput and validTransformInput, and not applicable to invalidInput and validForFitNotValidForTransformInput; /// - fitted transformer can be saved and re-loaded into the transformer with the same behavior. /// - schema propagation for fitted transformer conforms to schema propagation of estimator. /// protected void TestEstimatorCore(IEstimator estimator, - IDataView validFitInput, IDataView validTransformInput = null, IDataView invalidInput = null) + IDataView validFitInput, IDataView validTransformInput = null, IDataView invalidInput = null, IDataView validForFitNotValidForTransformInput = null) { Contracts.AssertValue(estimator); Contracts.AssertValue(validFitInput); @@ -59,6 +59,12 @@ protected void TestEstimatorCore(IEstimator estimator, mustFail(() => estimator.Fit(invalidInput)); } + if (validForFitNotValidForTransformInput != null) + { + estimator.GetOutputSchema(SchemaShape.Create(validForFitNotValidForTransformInput.Schema)); + estimator.Fit(validForFitNotValidForTransformInput); + } + var transformer = estimator.Fit(validFitInput); // Save and reload. 
string modelPath = GetOutputPath(TestName + "-model.zip"); @@ -104,6 +110,13 @@ protected void TestEstimatorCore(IEstimator estimator, mustFail(() => loadedTransformer.GetOutputSchema(invalidInput.Schema)); mustFail(() => loadedTransformer.Transform(invalidInput)); } + if (validForFitNotValidForTransformInput != null) + { + mustFail(() => transformer.GetOutputSchema(validForFitNotValidForTransformInput.Schema)); + mustFail(() => transformer.Transform(validForFitNotValidForTransformInput)); + mustFail(() => loadedTransformer.GetOutputSchema(validForFitNotValidForTransformInput.Schema)); + mustFail(() => loadedTransformer.Transform(validForFitNotValidForTransformInput)); + } // Schema verification between estimator and transformer. var scoredTrainSchemaShape = SchemaShape.Create(transformer.GetOutputSchema(validFitInput.Schema)); diff --git a/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs index 6b0a2adc38..60d2dc2fb1 100644 --- a/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs @@ -138,7 +138,7 @@ void TestMetadataCopy() using (var env = new TlcEnvironment()) { var dataView = ComponentCreation.CreateDataView(env, data); - var term = new TermTransform(env, new TermTransform.Arguments() + var term = TermTransform.Create(env, new TermTransform.Arguments() { Column = new[] { new TermTransform.Column() { Source = "Term", Name = "T" } } }, dataView); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs index daa3148e85..ed78f1fc84 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs @@ -28,7 +28,7 @@ void DecomposableTrainAndPredict() using (var env = new TlcEnvironment()) { var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new 
MultiFileSource(dataPath)); - var term = new TermTransform(env, loader, "Label"); + var term = TermTransform.Create(env, loader, "Label"); var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs index 87ceb18f42..56a213dcf9 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -30,7 +30,7 @@ void New_DecomposableTrainAndPredict() .Read(new MultiFileSource(dataPath)); var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") - .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest) + .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) .Append(new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label")) .Append(new MyKeyToValueTransform(env, "PredictedLabel")); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs index d2ebf51650..93654ab9ac 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs @@ -38,7 +38,7 @@ void New_Extensibility() }; var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new MyLambdaTransform(env, action)) - .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest) + .Append(new TermEstimator(env, "Label"), 
TransformerScope.TrainTest) .Append(new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label")) .Append(new MyKeyToValueTransform(env, "PredictedLabel")); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs index 3624cfa3b5..18726c09bc 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs @@ -29,7 +29,7 @@ public void New_Metacomponents() var sdcaTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"); var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") - .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest) + .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) .Append(new MyOva(env, sdcaTrainer)) .Append(new MyKeyToValueTransform(env, "PredictedLabel")); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index 961cfad58e..ed4283f233 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -325,33 +325,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) } } - public class MyTermTransform : IEstimator - { - private readonly IHostEnvironment _env; - private readonly string _column; - private readonly string _srcColumn; - - public MyTermTransform(IHostEnvironment env, string column, string srcColumn = null) - { - _env = env; - _column = column; - _srcColumn = srcColumn; - } - - public TransformWrapper Fit(IDataView input) - { - var xf = new TermTransform(_env, input, _column, _srcColumn); - var empty = new 
EmptyDataView(_env, input.Schema); - var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input); - return new TransformWrapper(_env, chunk); - } - - public SchemaShape GetOutputSchema(SchemaShape inputSchema) - { - throw new NotImplementedException(); - } - } - public class MyConcatTransform : IEstimator { private readonly IHostEnvironment _env; @@ -587,7 +560,7 @@ public MyOva(IHostEnvironment env, ITrainerEstimator PredictionKind.MultiClassClassification; - private static TrainerInfo MakeTrainerInfo(ITrainerEstimator, TScalarPredictor> estimator) + private static TrainerInfo MakeTrainerInfo(ITrainerEstimator, TScalarPredictor> estimator) => new TrainerInfo(estimator.Info.NeedNormalization, estimator.Info.NeedCalibration, false); protected override ScorerWrapper MakeScorer(OvaPredictor predictor, RoleMappedData data) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs index daf257e9a8..c8240e69b9 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs @@ -31,7 +31,7 @@ void Extensibility() j.SepalWidth = i.SepalWidth; }; var lambda = LambdaTransform.CreateMap(env, loader, action); - var term = new TermTransform(env, lambda, "Label"); + var term = TermTransform.Create(env, lambda, "Label"); var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs index 11e8dd196c..0fb4dec56d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs @@ -25,7 +25,7 @@ public void Metacomponents() using (var env = new 
TlcEnvironment()) { var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); - var term = new TermTransform(env, loader, "Label"); + var term = TermTransform.Create(env, loader, "Label"); var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); var trainer = new Ova(env, new Ova.Arguments { diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs new file mode 100644 index 0000000000..3a1a08ac38 --- /dev/null +++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs @@ -0,0 +1,196 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests +{ + public class TermEstimatorTests : TestDataPipeBase + { + public TermEstimatorTests(ITestOutputHelper output) : base(output) + { + } + + class TestClass + { + public int A; + public int B; + public int C; + } + + class TestClassXY + { + public int X; + public int Y; + } + + class TestClassDifferentTypes + { + public string A; + public string B; + public string C; + } + + + class TestMetaClass + { + public int NotUsed; + public string Term; + } + + [Fact] + void TestDifferntTypes() + { + string dataPath = GetDataPath("adult.test"); + + var loader = new TextLoader(Env, new TextLoader.Arguments + { + Column = new[]{ + new TextLoader.Column("float1", DataKind.R4, 0), + new TextLoader.Column("float4", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }), + new 
TextLoader.Column("double1", DataKind.R8, 0), + new TextLoader.Column("double4", DataKind.R8, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }), + new TextLoader.Column("int1", DataKind.I4, 0), + new TextLoader.Column("text1", DataKind.TX, 1), + new TextLoader.Column("text2", DataKind.TX, new[]{new TextLoader.Range(1), new TextLoader.Range(3)}), + }, + Separator = ",", + HasHeader = true + }, new MultiFileSource(dataPath)); + + var pipe = new TermEstimator(Env, new[]{ + new TermTransform.ColumnInfo("float1", "TermFloat1"), + new TermTransform.ColumnInfo("float4", "TermFloat4"), + new TermTransform.ColumnInfo("double1", "TermDouble1"), + new TermTransform.ColumnInfo("double4", "TermDouble4"), + new TermTransform.ColumnInfo("int1", "TermInt1"), + new TermTransform.ColumnInfo("text1", "TermText1"), + new TermTransform.ColumnInfo("text2", "TermText2") + }); + var data = loader.Read(new MultiFileSource(dataPath)); + data = TakeFilter.Create(Env, data, 10); + var outputPath = GetOutputPath("Term", "Term.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, pipe.Fit(data).Transform(data), fs, keepHidden: true); + } + + CheckEquality("Term", "Term.tsv"); + Done(); + } + + [Fact] + void TestSimpleCase() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + + var xydata = new[] { new TestClassXY() { X = 10, Y = 100 }, new TestClassXY() { X = -1, Y = -100 } }; + var stringData = new[] { new TestClassDifferentTypes { A = "1", B = "c", C = "b" } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new TermEstimator(Env, new[]{ + new TermTransform.ColumnInfo("A", "TermA"), + new TermTransform.ColumnInfo("B", "TermB"), + new TermTransform.ColumnInfo("C", "TermC") + }); + var invalidData = 
ComponentCreation.CreateDataView(Env, xydata); + var validFitNotValidTransformData = ComponentCreation.CreateDataView(Env, stringData); + TestEstimatorCore(pipe, dataView, null, invalidData, validFitNotValidTransformData); + } + + [Fact] + void TestOldSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var est = new TermEstimator(env, new[]{ + new TermTransform.ColumnInfo("A", "TermA"), + new TermTransform.ColumnInfo("B", "TermB"), + new TermTransform.ColumnInfo("C", "TermC") + }); + var transformer = est.Fit(dataView); + var result = transformer.Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(env, env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(env, dataView, ms); + ValidateTermTransformer(loadedView); + } + } + } + + [Fact] + void TestMetadataCopy() + { + var data = new[] { new TestMetaClass() { Term = "A", NotUsed = 1 }, new TestMetaClass() { Term = "B" }, new TestMetaClass() { Term = "C" } }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var termEst = new TermEstimator(env, new[] { + new TermTransform.ColumnInfo("Term" ,"T") }); + var termTransformer = termEst.Fit(dataView); + var result = termTransformer.Transform(dataView); + + result.Schema.TryGetColumnIndex("T", out int termIndex); + var names1 = default(VBuffer); + var type1 = result.Schema.GetColumnType(termIndex); + int size = type1.ItemType.IsKey ? 
type1.ItemType.KeyCount : -1; + result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, termIndex, ref names1); + Assert.True(names1.Count > 0); + } + } + + [Fact] + void TestCommandLine() + { + using (var env = new TlcEnvironment()) + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt" }), (int)0); + } + } + + private void ValidateTermTransformer(IDataView result) + { + result.Schema.TryGetColumnIndex("TermA", out int ColA); + result.Schema.TryGetColumnIndex("TermB", out int ColB); + result.Schema.TryGetColumnIndex("TermC", out int ColC); + using (var cursor = result.GetRowCursor(x => true)) + { + uint avalue = 0; + uint bvalue = 0; + uint cvalue = 0; + + var aGetter = cursor.GetGetter(ColA); + var bGetter = cursor.GetGetter(ColB); + var cGetter = cursor.GetGetter(ColC); + uint i = 1; + while (cursor.MoveNext()) + { + aGetter(ref avalue); + bGetter(ref bvalue); + cGetter(ref cvalue); + Assert.Equal(i, avalue); + Assert.Equal(i, bvalue); + Assert.Equal(i, cvalue); + i++; + } + } + } + } +} +